Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ julia = "~1.9"
[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"

[targets]
test = ["Test", "SCS", "Zygote", "Ipopt"]
41 changes: 14 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,43 +130,30 @@ policy = Chain(
Dense(64, length(initial_state)),
)

windows = DecisionRules.setup_shooting_windows(
subproblems,
state_params_in,
state_params_out,
Float64.(initial_state),
uncertainty_samples;
window_size=24,
model_factory=() -> DiffOpt.nonlinear_diff_model(optimizer_with_attributes(
SCS.Optimizer,
"verbose" => 0,
)),
)

DecisionRules.train_multiple_shooting(
policy,
initial_state,
subproblems,
windows,
state_params_in,
state_params_out,
uncertainty_sampler;
window_size=24, # e.g., 6, 24, ...
num_batches=100,
num_train_per_batch=32,
optimizer=Flux.Adam(1e-3),
optimizer_factory=() -> DiffOpt.diff_optimizer(SCS.Optimizer),
)
```

If you want lower-level control, you can pre-build the window models once:

```julia
windows = DecisionRules.setup_shooting_windows(
subproblems,
state_params_in,
state_params_out,
Float64.(initial_state),
uncertainties_structure;
window_size=24,
optimizer_factory=() -> DiffOpt.diff_optimizer(SCS.Optimizer),
)

uncertainty_sample = uncertainty_sampler()
uncertainties_vec = [[Float32(u[2]) for u in stage_u] for stage_u in uncertainty_sample]

obj = DecisionRules.simulate_multiple_shooting(
windows,
policy,
Float32.(initial_state),
uncertainty_sample,
uncertainties_vec,
)
```

Expand Down
1 change: 1 addition & 0 deletions examples/Atlas/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
MathProgIncidence = "892fab00-3092-4bd0-9c46-66676a93f84e"
MeshCat = "283c5d60-a78f-5afe-a0af-af636b173e11"
MeshCatMechanisms = "6ad125db-dd91-5488-b820-c1df6aab299d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
RigidBodyDynamics = "366cf18f-59d5-5db9-a4de-86a9f6786172"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
Expand Down
208 changes: 12 additions & 196 deletions examples/Atlas/build_atlas_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end
Build a multi-stage stochastic optimization problem for Atlas robot balancing.

The problem minimizes deviation from a reference state while subject to:
- Discrete-time dynamics with stochastic perturbations
- Discrete-time dynamics with stochastic perturbations (applied every `perturbation_frequency` stages)
- Torque limits on all joints

# Penalty arguments
Expand All @@ -53,9 +53,10 @@ function build_atlas_subproblems(;
h::Float64 = 0.01,
N::Int = 100,
perturbation_scale::Float64 = 0.1,
perturbation_frequency::Int = 1,
perturbation_indices::Union{Nothing, Vector{Int}} = nothing,
num_scenarios::Int = 10,
penalty::Float64 = 1e3,
penalty::Float64 = 10.0,
penalty_l1::Union{Nothing, Float64} = nothing,
penalty_l2::Union{Nothing, Float64} = nothing,
optimizer = nothing,
Expand All @@ -78,6 +79,10 @@ function build_atlas_subproblems(;
"linear_solver" => "ma27"
))
end

if perturbation_frequency < 1
error("perturbation_frequency must be >= 1")
end

# Default perturbation indices: perturb velocity states
if isnothing(perturbation_indices)
Expand Down Expand Up @@ -240,204 +245,15 @@ function build_atlas_subproblems(;
state_params_out[t] = [(target_x[i], x[i]) for i in 1:nx]

# Generate uncertainty samples (random perturbations)
uncertainty_samples[t] = [(w[i], perturbation_scale * randn(num_scenarios)) for i in 1:n_perturb]
if (t - 1) % perturbation_frequency == 0
uncertainty_samples[t] = [(w[i], perturbation_scale * randn(num_scenarios)) for i in 1:n_perturb]
else
uncertainty_samples[t] = [(w[i], zeros(num_scenarios)) for i in 1:n_perturb]
end
end

initial_state = copy(x_ref)

return subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples,
X_vars, U_vars, x_ref, u_ref, atlas
end


"""
build_atlas_deterministic_equivalent(; kwargs...)

Build a deterministic equivalent formulation for the Atlas balancing problem.
This creates a single large optimization problem instead of decomposed subproblems.

# Penalty arguments
- `penalty`: Legacy L1 norm penalty, kept for backwards compatibility. It is only used (as `penalty_l1`) when neither `penalty_l1` nor `penalty_l2` is provided.
- `penalty_l1`: Penalty for L1 norm (NormOneCone). If provided, creates L1 deviation constraint.
- `penalty_l2`: Penalty for the squared L2 norm (sum of squared deviations, imposed as a quadratic constraint). If provided, creates the squared-L2 deviation constraint.
- If both `penalty_l1` and `penalty_l2` are provided, both norms are used with separate penalties.
"""
function build_atlas_deterministic_equivalent(;
atlas::Atlas = Atlas(),
x_ref::Union{Nothing, Vector{Float64}} = nothing,
u_ref::Union{Nothing, Vector{Float64}} = nothing,
h::Float64 = 0.01,
N::Int = 100,
perturbation_scale::Float64 = 0.1,
perturbation_indices::Union{Nothing, Vector{Int}} = nothing,
num_scenarios::Int = 10,
penalty::Float64 = 1e3,
penalty_l1::Union{Nothing, Float64} = nothing,
penalty_l2::Union{Nothing, Float64} = nothing,
)
# Overview: builds ONE model that spans all N stages (states X, controls U,
# perturbed states, policy targets and perturbation parameters), attaches the
# dynamics as a vector nonlinear oracle per stage, and returns the model together
# with the per-stage bookkeeping vectors, the X/U containers, x_ref, u_ref and atlas.

# Handle penalty arguments: if penalty_l1/penalty_l2 not specified, use legacy penalty for L1
if isnothing(penalty_l1) && isnothing(penalty_l2)
penalty_l1 = penalty
end

# Load reference state if not provided
# NOTE(review): if only one of x_ref/u_ref is supplied, this branch still runs and
# @load appears to assign both names from the file, discarding the supplied one —
# confirm this is intended.
if isnothing(x_ref) || isnothing(u_ref)
@load joinpath(@__DIR__, "atlas_ref.jld2") x_ref u_ref
end

# Default perturbation indices
# A single state index, nq + 5 (presumably a velocity coordinate — TODO confirm
# against the Atlas state layout).
if isnothing(perturbation_indices)
perturbation_indices = [atlas.nq + 5]
end

nx = atlas.nx
nu = atlas.nu
n_perturb = length(perturbation_indices)

# Create model
# Ipopt with the HSL ma27 linear solver, solver output silenced (print_level 0).
det_equivalent = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer,
"print_level" => 0,
"hsllib" => HSL_jll.libhsl_path,
"linear_solver" => "ma27"
))

# State and control variables for all stages
# X is N x nx (started at x_ref); U is (N-1) x nu, boxed by +/- torque_limits
# and started at u_ref.
@variable(det_equivalent, X[t=1:N, i=1:nx], start = x_ref[i])
@variable(det_equivalent, -atlas.torque_limits[i] <= U[t=1:N-1, i=1:nu] <= atlas.torque_limits[i], start = u_ref[i])

# Perturbed state variables
@variable(det_equivalent, X_perturbed[t=1:N-1, i=1:nx], start = x_ref[i])

# Target parameters (policy outputs)
@variable(det_equivalent, target[t=1:N-1, i=1:nx] ∈ MOI.Parameter(x_ref[i]))

# Perturbation parameters
# One parameter per stage and per perturbed index, initialised to 0.0.
@variable(det_equivalent, w[t=1:N-1, i=1:n_perturb] ∈ MOI.Parameter(0.0))

# Deviation variable
@variable(det_equivalent, norm_deficit >= 0)

# Fix initial condition
for i in 1:nx
fix(X[1, i], x_ref[i]; force=true)
end

# Perturbation constraints
# X_perturbed[t, i] equals X[t, i], plus w[t, k] when i is the k-th entry of
# perturbation_indices.
for t in 1:N-1
for i in 1:nx
perturb_idx = findfirst(==(i), perturbation_indices)
if !isnothing(perturb_idx)
@constraint(det_equivalent, X_perturbed[t, i] == X[t, i] + w[t, perturb_idx])
else
@constraint(det_equivalent, X_perturbed[t, i] == X[t, i])
end
end
end

# Objective
@variable(det_equivalent, cost >= 0)

# Create norm constraints based on penalty arguments
# deviation_expr flattens target[t, i] - X[t+1, i] over all stages and state
# components into one vector of length (N-1) * nx.
use_l1 = !isnothing(penalty_l1)
use_l2 = !isnothing(penalty_l2)
deviation_expr = vec([target[t,i] - X[t+1,i] for t in 1:N-1, i in 1:nx])
deviation_dim = (N-1) * nx

if use_l1 && use_l2
# Both L1 and L2 squared norms
# The penalties are applied inside the norm_deficit constraint, so the
# objective coefficient on norm_deficit is 1.
@variable(det_equivalent, norm_l1 >= 0)
@variable(det_equivalent, norm_l2_sq >= 0) # L2 squared (sum of squares)
@constraint(det_equivalent, [norm_l1; deviation_expr] in MOI.NormOneCone(1 + deviation_dim))
@constraint(det_equivalent, norm_l2_sq >= sum(deviation_expr[i]^2 for i in 1:deviation_dim))
@constraint(det_equivalent, norm_deficit >= penalty_l1 * norm_l1 + penalty_l2 * norm_l2_sq)
deficit_coef = 1.0
elseif use_l1
# L1 norm only
# norm_deficit bounds the L1 norm directly; the penalty enters the objective.
@constraint(det_equivalent, [norm_deficit; deviation_expr] in MOI.NormOneCone(1 + deviation_dim))
deficit_coef = penalty_l1
elseif use_l2
# L2 squared norm only (sum of squares)
@constraint(det_equivalent, norm_deficit >= sum(deviation_expr[i]^2 for i in 1:deviation_dim))
deficit_coef = penalty_l2
else
# Not reachable from the defaulting logic at the top of this function, which
# sets penalty_l1 whenever both penalties are nothing; kept as a guard.
error("At least one of penalty_l1 or penalty_l2 must be specified")
end

# cost bounds the summed squared distance of X from x_ref over stages 2:N.
@constraint(det_equivalent, cost >= sum((X[t,i] - x_ref[i])^2 for t in 2:N, i in 1:nx))
@objective(det_equivalent, Min,
cost +
deficit_coef * norm_deficit
)

# Build VectorNonlinearOracle (same as subproblems)
# Oracle input z has length 2*nx + nu and is laid out as
# [next state (nx); perturbed state (nx); control (nu)], matching the vcat in the
# dynamics-constraint loop further down.
VNO_dim = 2*nx + nu

# Jacobian sparsity: first the nx diagonal entries (i, i) for the next-state part,
# then a dense nx x (nx + nu) block over the remaining inputs, listed row by row
# (i is the outer loop, j the inner loop).
jacobian_structure = Tuple{Int,Int}[]
append!(jacobian_structure, map(i -> (i, i), 1:nx))
for i in 1:nx, j in 1:(nx + nu)
push!(jacobian_structure, (i, j + nx))
end

# Hessian sparsity: dense block over the last nx + nu inputs only; the first nx
# inputs enter the oracle linearly.
# NOTE(review): both (i, j) and (j, i) are listed; confirm the solver interface
# expects the full matrix here rather than a single triangle.
hessian_lagrangian_structure = [
(i, j)
for i in nx+1:VNO_dim
for j in nx+1:VNO_dim
]

# Local aliases used by the oracle closures below.
local_atlas = atlas
local_h = h
local_nx = nx
local_nu = nu

# l = u = 0 makes the oracle an equality: z[1:nx] == atlas_dynamics(...).
VNO = MOI.VectorNonlinearOracle(;
dimension = VNO_dim,
l = zeros(nx),
u = zeros(nx),
eval_f = (ret, z) -> begin
# Residual: next state minus dynamics evaluated on the remaining nx + nu
# entries of z, which are splatted as positional arguments.
ret[1:local_nx] .= z[1:local_nx] - atlas_dynamics(local_atlas, local_h, z[local_nx+1:VNO_dim]...)
return
end,
jacobian_structure = jacobian_structure,
eval_jacobian = (ret, z) -> begin
dyn_jac = ForwardDiff.jacobian(
xu -> atlas_dynamics(local_atlas, local_h, xu...),
z[local_nx+1:VNO_dim]
)
jnnz = length(jacobian_structure)
# Diagonal entries are 1; the dense block is minus the dynamics Jacobian.
# Transposing before reshape yields row-by-row order, matching the order
# in which jacobian_structure was filled.
ret[1:local_nx] .= ones(local_nx)
ret[local_nx+1:jnnz] .= -reshape(dyn_jac', local_nx * (local_nx + local_nu))
return
end,
hessian_lagrangian_structure = hessian_lagrangian_structure,
eval_hessian_lagrangian = (ret, z, λ) -> begin
# Hessian of dot(λ, dynamics); negated below because the residual subtracts
# the dynamics.
hess = ForwardDiff.hessian(
xu -> dot(λ, atlas_dynamics(local_atlas, local_h, xu...)),
z[local_nx+1:VNO_dim]
)
hnnz = length(hessian_lagrangian_structure)
ret[1:hnnz] .= -reshape(hess, hnnz)
return
end
)

# Dynamics constraints
# One oracle constraint per transition t -> t+1.
for t in 1:N-1
vars = vcat(X[t+1, :], X_perturbed[t, :], U[t, :])
@constraint(det_equivalent, vars in VNO)
end

# Generate uncertainty samples
# For each stage, pair every perturbation parameter with num_scenarios draws of
# perturbation_scale * randn (no explicit RNG is passed).
uncertainty_samples = Vector{Vector{Tuple{VariableRef, Vector{Float64}}}}(undef, N-1)
for t in 1:N-1
uncertainty_samples[t] = [(w[t, i], perturbation_scale * randn(num_scenarios)) for i in 1:n_perturb]
end

# State parameters for DecisionRules.jl
# Incoming state of stage t is X[t, :]; outgoing pairs each target parameter with
# the realised next-state variable X[t+1, i].
state_params_in = [collect(X[t, :]) for t in 1:N-1]
state_params_out = [[(target[t,i], X[t+1,i]) for i in 1:nx] for t in 1:N-1]

initial_state = copy(x_ref)

return det_equivalent, state_params_in, state_params_out, initial_state, uncertainty_samples,
X, U, x_ref, u_ref, atlas
end
18 changes: 13 additions & 5 deletions examples/Atlas/train_dr_atlas_det_eq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ include(joinpath(Atlas_dir, "build_atlas_problem.jl"))
# ============================================================================

# Problem parameters
N = 50 # Number of time steps
N = 10 # Number of time steps
h = 0.01 # Time step
perturbation_scale = 0.05 # Scale of random perturbations
perturbation_scale = 1.5 # Scale of random perturbations
num_scenarios = 10 # Number of uncertainty samples per stage
penalty = 1e3 # Penalty for state deviation
penalty = 10.0 # Penalty for state deviation
perturbation_frequency = 5 # Frequency of perturbations (every k stages)

# Training parameters
num_epochs = 1
Expand Down Expand Up @@ -55,13 +56,15 @@ println("Building Atlas deterministic equivalent problem...")
perturbation_scale = perturbation_scale,
num_scenarios = num_scenarios,
penalty = penalty,
perturbation_frequency = perturbation_frequency,
)

# Build deterministic equivalent
det_equivalent = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer,
det_equivalent = DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer,
"print_level" => 0,
"hsllib" => HSL_jll.libhsl_path,
"linear_solver" => "ma27"
"linear_solver" => "ma97",
# "mu_target" => 1e-8,
))

# Convert subproblems to deterministic equivalent using DecisionRules
Expand Down Expand Up @@ -138,6 +141,11 @@ objective_values = [simulate_multistage(
best_obj = mean(objective_values)
println("Initial objective: $best_obj")

# for testing visualization. fill x with visited states
# X[2:end] = [value.([var[2] for var in stage]) for stage in state_params_out_sub]
# calculate distance from reference
# dist = sum((X[t][i] - x_ref[i])^2 for i in 1:length(x_ref) for t in 1:length(X))

model_path = joinpath(model_dir, save_file * ".jld2")
save_control = SaveBest(best_obj, model_path)

Expand Down
Loading