From d5892fbcfc04c68f659b9cedcb151e59d00cc853 Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Wed, 28 Jan 2026 22:00:54 -0500 Subject: [PATCH 1/3] atlas problem spec --- examples/Atlas/Project.toml | 1 + examples/Atlas/build_atlas_problem.jl | 206 +----------------- examples/Atlas/train_dr_atlas_det_eq.jl | 12 +- .../Atlas/train_dr_atlas_multipleshooting.jl | 202 +++++++++++++++++ examples/Atlas/visualize_atlas_policy.jl | 10 +- .../Atlas/visualize_atlas_policy_det_eq.jl | 194 +++++++++++++++++ ...visualize_atlas_policy_multipleshooting.jl | 203 +++++++++++++++++ 7 files changed, 627 insertions(+), 201 deletions(-) create mode 100644 examples/Atlas/train_dr_atlas_multipleshooting.jl create mode 100644 examples/Atlas/visualize_atlas_policy_det_eq.jl create mode 100644 examples/Atlas/visualize_atlas_policy_multipleshooting.jl diff --git a/examples/Atlas/Project.toml b/examples/Atlas/Project.toml index 51f1da5..610b8a4 100644 --- a/examples/Atlas/Project.toml +++ b/examples/Atlas/Project.toml @@ -11,6 +11,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" MathProgIncidence = "892fab00-3092-4bd0-9c46-66676a93f84e" MeshCat = "283c5d60-a78f-5afe-a0af-af636b173e11" MeshCatMechanisms = "6ad125db-dd91-5488-b820-c1df6aab299d" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" RigidBodyDynamics = "366cf18f-59d5-5db9-a4de-86a9f6786172" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" diff --git a/examples/Atlas/build_atlas_problem.jl b/examples/Atlas/build_atlas_problem.jl index 10c0d1a..961c0bf 100644 --- a/examples/Atlas/build_atlas_problem.jl +++ b/examples/Atlas/build_atlas_problem.jl @@ -30,7 +30,7 @@ end Build a multi-stage stochastic optimization problem for Atlas robot balancing. The problem minimizes deviation from a reference state while subject to: -- Discrete-time dynamics with stochastic perturbations +- Discrete-time dynamics with stochastic perturbations (applied every `perturbation_frequency` stages) - Torque limits on all joints # Penalty arguments @@ -53,6 +53,7 @@ function build_atlas_subproblems(; h::Float64 = 0.01, N::Int = 100, perturbation_scale::Float64 = 0.1, + perturbation_frequency::Int = 1, perturbation_indices::Union{Nothing, Vector{Int}} = nothing, num_scenarios::Int = 10, penalty::Float64 = 1e3, @@ -78,6 +79,10 @@ function build_atlas_subproblems(; "linear_solver" => "ma27" )) end + + if perturbation_frequency < 1 + error("perturbation_frequency must be >= 1") + end # Default perturbation indices: perturb velocity states if isnothing(perturbation_indices) @@ -240,7 +245,11 @@ function build_atlas_subproblems(; state_params_out[t] = [(target_x[i], x[i]) for i in 1:nx] # Generate uncertainty samples (random perturbations) - uncertainty_samples[t] = [(w[i], perturbation_scale * randn(num_scenarios)) for i in 1:n_perturb] + if (t - 1) % perturbation_frequency == 0 + uncertainty_samples[t] = [(w[i], perturbation_scale * randn(num_scenarios)) for i in 1:n_perturb] + else + uncertainty_samples[t] = [(w[i], zeros(num_scenarios)) for i in 1:n_perturb] + end end initial_state = copy(x_ref) @@ -248,196 +257,3 @@ function build_atlas_subproblems(; return subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, X_vars, U_vars, x_ref, u_ref, atlas end - - -""" - build_atlas_deterministic_equivalent(; kwargs...) - -Build a deterministic equivalent formulation for the Atlas balancing problem. -This creates a single large optimization problem instead of decomposed subproblems. 
- -# Penalty arguments -- `penalty`: Legacy argument for L1 norm penalty (backwards compatible) -- `penalty_l1`: Penalty for L1 norm (NormOneCone). If provided, creates L1 deviation constraint. -- `penalty_l2`: Penalty for L2 norm (SecondOrderCone). If provided, creates L2 deviation constraint. -- If both `penalty_l1` and `penalty_l2` are provided, both norms are used with separate penalties. -""" -function build_atlas_deterministic_equivalent(; - atlas::Atlas = Atlas(), - x_ref::Union{Nothing, Vector{Float64}} = nothing, - u_ref::Union{Nothing, Vector{Float64}} = nothing, - h::Float64 = 0.01, - N::Int = 100, - perturbation_scale::Float64 = 0.1, - perturbation_indices::Union{Nothing, Vector{Int}} = nothing, - num_scenarios::Int = 10, - penalty::Float64 = 1e3, - penalty_l1::Union{Nothing, Float64} = nothing, - penalty_l2::Union{Nothing, Float64} = nothing, -) - # Handle penalty arguments: if penalty_l1/penalty_l2 not specified, use legacy penalty for L1 - if isnothing(penalty_l1) && isnothing(penalty_l2) - penalty_l1 = penalty - end - - # Load reference state if not provided - if isnothing(x_ref) || isnothing(u_ref) - @load joinpath(@__DIR__, "atlas_ref.jld2") x_ref u_ref - end - - # Default perturbation indices - if isnothing(perturbation_indices) - perturbation_indices = [atlas.nq + 5] - end - - nx = atlas.nx - nu = atlas.nu - n_perturb = length(perturbation_indices) - - # Create model - det_equivalent = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, - "print_level" => 0, - "hsllib" => HSL_jll.libhsl_path, - "linear_solver" => "ma27" - )) - - # State and control variables for all stages - @variable(det_equivalent, X[t=1:N, i=1:nx], start = x_ref[i]) - @variable(det_equivalent, -atlas.torque_limits[i] <= U[t=1:N-1, i=1:nu] <= atlas.torque_limits[i], start = u_ref[i]) - - # Perturbed state variables - @variable(det_equivalent, X_perturbed[t=1:N-1, i=1:nx], start = x_ref[i]) - - # Target parameters (policy outputs) - @variable(det_equivalent, target[t=1:N-1, i=1:nx] ∈ MOI.Parameter(x_ref[i])) - - # Perturbation parameters - @variable(det_equivalent, w[t=1:N-1, i=1:n_perturb] ∈ MOI.Parameter(0.0)) - - # Deviation variable - @variable(det_equivalent, norm_deficit >= 0) - - # Fix initial condition - for i in 1:nx - fix(X[1, i], x_ref[i]; force=true) - end - - # Perturbation constraints - for t in 1:N-1 - for i in 1:nx - perturb_idx = findfirst(==(i), perturbation_indices) - if !isnothing(perturb_idx) - @constraint(det_equivalent, X_perturbed[t, i] == X[t, i] + w[t, perturb_idx]) - else - @constraint(det_equivalent, X_perturbed[t, i] == X[t, i]) - end - end - end - - # Objective - @variable(det_equivalent, cost >= 0) - - # Create norm constraints based on penalty arguments - use_l1 = !isnothing(penalty_l1) - use_l2 = !isnothing(penalty_l2) - deviation_expr = vec([target[t,i] - X[t+1,i] for t in 1:N-1, i in 1:nx]) - deviation_dim = (N-1) * nx - - if use_l1 && use_l2 - # Both L1 and L2 squared norms - @variable(det_equivalent, norm_l1 >= 0) - @variable(det_equivalent, norm_l2_sq >= 0) # L2 squared (sum of squares) - @constraint(det_equivalent, [norm_l1; deviation_expr] in MOI.NormOneCone(1 + deviation_dim)) - @constraint(det_equivalent, norm_l2_sq >= sum(deviation_expr[i]^2 for i in 1:deviation_dim)) - @constraint(det_equivalent, norm_deficit >= penalty_l1 * norm_l1 + penalty_l2 * norm_l2_sq) - deficit_coef = 1.0 - elseif use_l1 - # L1 norm only - @constraint(det_equivalent, [norm_deficit; deviation_expr] in MOI.NormOneCone(1 + deviation_dim)) - deficit_coef = penalty_l1 - elseif 
use_l2 - # L2 squared norm only (sum of squares) - @constraint(det_equivalent, norm_deficit >= sum(deviation_expr[i]^2 for i in 1:deviation_dim)) - deficit_coef = penalty_l2 - else - error("At least one of penalty_l1 or penalty_l2 must be specified") - end - - @constraint(det_equivalent, cost >= sum((X[t,i] - x_ref[i])^2 for t in 2:N, i in 1:nx)) - @objective(det_equivalent, Min, - cost + - deficit_coef * norm_deficit - ) - - # Build VectorNonlinearOracle (same as subproblems) - VNO_dim = 2*nx + nu - - jacobian_structure = Tuple{Int,Int}[] - append!(jacobian_structure, map(i -> (i, i), 1:nx)) - for i in 1:nx, j in 1:(nx + nu) - push!(jacobian_structure, (i, j + nx)) - end - - hessian_lagrangian_structure = [ - (i, j) - for i in nx+1:VNO_dim - for j in nx+1:VNO_dim - ] - - local_atlas = atlas - local_h = h - local_nx = nx - local_nu = nu - - VNO = MOI.VectorNonlinearOracle(; - dimension = VNO_dim, - l = zeros(nx), - u = zeros(nx), - eval_f = (ret, z) -> begin - ret[1:local_nx] .= z[1:local_nx] - atlas_dynamics(local_atlas, local_h, z[local_nx+1:VNO_dim]...) - return - end, - jacobian_structure = jacobian_structure, - eval_jacobian = (ret, z) -> begin - dyn_jac = ForwardDiff.jacobian( - xu -> atlas_dynamics(local_atlas, local_h, xu...), - z[local_nx+1:VNO_dim] - ) - jnnz = length(jacobian_structure) - ret[1:local_nx] .= ones(local_nx) - ret[local_nx+1:jnnz] .= -reshape(dyn_jac', local_nx * (local_nx + local_nu)) - return - end, - hessian_lagrangian_structure = hessian_lagrangian_structure, - eval_hessian_lagrangian = (ret, z, λ) -> begin - hess = ForwardDiff.hessian( - xu -> dot(λ, atlas_dynamics(local_atlas, local_h, xu...)), - z[local_nx+1:VNO_dim] - ) - hnnz = length(hessian_lagrangian_structure) - ret[1:hnnz] .= -reshape(hess, hnnz) - return - end - ) - - # Dynamics constraints - for t in 1:N-1 - vars = vcat(X[t+1, :], X_perturbed[t, :], U[t, :]) - @constraint(det_equivalent, vars in VNO) - end - - # Generate uncertainty samples - uncertainty_samples = Vector{Vector{Tuple{VariableRef, Vector{Float64}}}}(undef, N-1) - for t in 1:N-1 - uncertainty_samples[t] = [(w[t, i], perturbation_scale * randn(num_scenarios)) for i in 1:n_perturb] - end - - # State parameters for DecisionRules.jl - state_params_in = [collect(X[t, :]) for t in 1:N-1] - state_params_out = [[(target[t,i], X[t+1,i]) for i in 1:nx] for t in 1:N-1] - - initial_state = copy(x_ref) - - return det_equivalent, state_params_in, state_params_out, initial_state, uncertainty_samples, - X, U, x_ref, u_ref, atlas -end diff --git a/examples/Atlas/train_dr_atlas_det_eq.jl b/examples/Atlas/train_dr_atlas_det_eq.jl index 9252c53..c03fb4d 100644 --- a/examples/Atlas/train_dr_atlas_det_eq.jl +++ b/examples/Atlas/train_dr_atlas_det_eq.jl @@ -22,11 +22,12 @@ include(joinpath(Atlas_dir, "build_atlas_problem.jl")) # ============================================================================ # Problem parameters -N = 50 # Number of time steps +N = 10 # Number of time steps h = 0.01 # Time step perturbation_scale = 0.05 # Scale of random perturbations num_scenarios = 10 # Number of uncertainty samples per stage penalty = 1e3 # Penalty for state deviation +perturbation_frequency = 5 # Frequency of perturbations (every k stages) # Training parameters num_epochs = 1 @@ -55,13 +56,15 @@ println("Building Atlas deterministic equivalent problem...") perturbation_scale = perturbation_scale, num_scenarios = num_scenarios, penalty = penalty, + perturbation_frequency = perturbation_frequency, ) # Build deterministic equivalent det_equivalent = 
DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0, "hsllib" => HSL_jll.libhsl_path, - "linear_solver" => "ma27" + "linear_solver" => "ma97", + # "mu_target" => 1e-8, )) # Convert subproblems to deterministic equivalent using DecisionRules @@ -138,6 +141,11 @@ objective_values = [simulate_multistage( best_obj = mean(objective_values) println("Initial objective: $best_obj") +# for testing visualization. fill x with visited states +# X[2:end] = [value.([var[2] for var in stage]) for stage in state_params_out_sub] +# calculate distance from reference +# dist = sum((X[t][i] - x_ref[i])^2 for i in 1:length(x_ref) for t in 1:length(X)) + model_path = joinpath(model_dir, save_file * ".jld2") save_control = SaveBest(best_obj, model_path) diff --git a/examples/Atlas/train_dr_atlas_multipleshooting.jl b/examples/Atlas/train_dr_atlas_multipleshooting.jl new file mode 100644 index 0000000..7751b74 --- /dev/null +++ b/examples/Atlas/train_dr_atlas_multipleshooting.jl @@ -0,0 +1,202 @@ +# Train DecisionRules.jl policy for Atlas Robot Balancing +# Using multiple shooting (windowed decomposition) + +using Flux +using DecisionRules +using Random +using Statistics +using JuMP +import Ipopt, HSL_jll +using Wandb, Dates, Logging +using JLD2 +using DiffOpt + +Atlas_dir = dirname(@__FILE__) +include(joinpath(Atlas_dir, "build_atlas_problem.jl")) + +# ============================================================================ +# Parameters +# ============================================================================ + +# Problem parameters +N = 11 # Number of time steps +h = 0.01 # Time step +perturbation_scale = 0.05 # Scale of random perturbations +num_scenarios = 10 # Number of uncertainty samples per stage +penalty = 1e3 # Penalty for state deviation +perturbation_frequency = 5 # Frequency of perturbations (every k stages) +window_size = 5 # Multiple shooting window length + +# Training parameters +num_epochs = 1 +num_batches = 100 +_num_train_per_batch = 1 +layers = Int64[64, 64] +activation = sigmoid +optimizers = [Flux.Adam(0.001)] + +# Save paths +model_dir = joinpath(Atlas_dir, "models") +mkpath(model_dir) +save_file = "atlas-balancing-shooting-N$(N)-w$(window_size)-$(now())" + +# ============================================================================ +# Build Subproblems +# ============================================================================ + +println("Building Atlas subproblems...") + +diff_optimizer = () -> DiffOpt.diff_optimizer(optimizer_with_attributes(Ipopt.Optimizer, + "print_level" => 0, + "hsllib" => HSL_jll.libhsl_path, + "linear_solver" => "ma27" +)) + +@time subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, + _, _, x_ref, u_ref, atlas = build_atlas_subproblems(; + N = N, + h = h, + perturbation_scale = perturbation_scale, + num_scenarios = num_scenarios, + penalty = penalty, + perturbation_frequency = perturbation_frequency, + optimizer = diff_optimizer, +) + +nx = atlas.nx +nu = atlas.nu +n_perturb = length(uncertainty_samples[1]) # Number of perturbation parameters + +println("Atlas state dimension: $nx") +println("Atlas control dimension: $nu") +println("Number of perturbations: $n_perturb") +println("Number of stages: $(N-1)") +println("Window size: $window_size") + +# ============================================================================ +# Logging +# ============================================================================ + +lg = WandbLogger( + project = "DecisionRules-Atlas", + name = 
save_file, + config = Dict( + "N" => N, + "h" => h, + "perturbation_scale" => perturbation_scale, + "num_scenarios" => num_scenarios, + "penalty" => penalty, + "perturbation_frequency" => perturbation_frequency, + "window_size" => window_size, + "layers" => layers, + "activation" => string(activation), + "optimizer" => string(optimizers), + "nx" => nx, + "nu" => nu, + "training_method" => "multiple_shooting", + ) +) + +function record_loss(iter, model, loss, tag) + Wandb.log(lg, Dict(tag => loss)) + return false +end + +# ============================================================================ +# Define Neural Network Policy +# ============================================================================ + +# Policy architecture: LSTM processes perturbations, Dense combines with previous state +# This design is memory-efficient and allows the LSTM to focus on temporal patterns +n_uncertainties = length(uncertainty_samples[1]) +models = state_conditioned_policy(n_uncertainties, nx, nx, layers; + activation=activation, encoder_type=Flux.LSTM) + +println("Model architecture: StateConditionedPolicy") +println(" Encoder (LSTM): $n_uncertainties -> $(layers)") +println(" Combiner (Dense): $(layers[end]) + $nx -> $nx") + +# ============================================================================ +# Setup multiple shooting windows +# ============================================================================ + +windows = DecisionRules.setup_shooting_windows( + subproblems, + state_params_in, + state_params_out, + Float64.(initial_state), + uncertainty_samples; + window_size=window_size, + optimizer_factory=diff_optimizer, +) + +# ============================================================================ +# Initial Evaluation +# ============================================================================ + +println("\nEvaluating initial policy...") +Random.seed!(8788) +objective_values = [begin + uncertainty_sample = DecisionRules.sample(uncertainty_samples) + uncertainties_vec = [[Float32(u[2]) for u in stage_u] for stage_u in uncertainty_sample] + DecisionRules.simulate_multiple_shooting( + windows, + models, + Float32.(initial_state), + uncertainty_sample, + uncertainties_vec + ) +end for _ in 1:2] + +best_obj = mean(objective_values) +println("Initial objective: $best_obj") + +model_path = joinpath(model_dir, save_file * ".jld2") +save_control = SaveBest(best_obj, model_path) +convergence_criterium = StallingCriterium(200, best_obj, 0) + +# ============================================================================ +# Hyperparameter Adjustment +# ============================================================================ + +adjust_hyperparameters = (iter, opt_state, num_train_per_batch) -> begin + if iter % 2100 == 0 + num_train_per_batch = num_train_per_batch * 2 + end + return num_train_per_batch +end + +# ============================================================================ +# Training +# ============================================================================ + +println("\nStarting training with multiple shooting...") +println("Epochs: $num_epochs, Batches per epoch: $num_batches") + +for iter in 1:num_epochs + num_train_per_batch = _num_train_per_batch + train_multiple_shooting( + models, + initial_state, + windows, + () -> uncertainty_samples; + num_batches=num_batches, + num_train_per_batch=num_train_per_batch, + optimizer=optimizers[floor(Int, min(iter, length(optimizers)))], + record_loss=(iter, model, loss, tag) -> begin + if tag == "metrics/training_loss" + 
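+            # On a training-loss event: checkpoint the best model so far, log the
+            # value, and let the stalling criterion decide whether to stop early.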
save_control(iter, model, loss) + record_loss(iter, model, loss, tag) + return convergence_criterium(iter, model, loss) + end + return record_loss(iter, model, loss, tag) + end, + adjust_hyperparameters=adjust_hyperparameters + ) +end + +# Finish logging +close(lg) + +println("\nModel saved to: $model_path") +println("Training complete!") diff --git a/examples/Atlas/visualize_atlas_policy.jl b/examples/Atlas/visualize_atlas_policy.jl index c1bc473..23272d7 100644 --- a/examples/Atlas/visualize_atlas_policy.jl +++ b/examples/Atlas/visualize_atlas_policy.jl @@ -28,8 +28,9 @@ model_path = nothing # Set to path of trained model, or nothing to use latest # Problem parameters (should match training) N = 50 # Number of time steps h = 0.01 # Time step -perturbation_scale = 0.05 # Scale of random perturbations -num_scenarios = 10 # Number of scenarios to simulate +perturbation_scale = 1.5 # Scale of random perturbations +num_scenarios = 1 # Number of scenarios to simulate +perturbation_frequency = 5 # Frequency of perturbations (every k stages) # Visualization options animate_robot = true # Whether to animate in MeshCat @@ -87,6 +88,7 @@ println("Control dimension: $nu") h = h, perturbation_scale = perturbation_scale, num_scenarios = num_scenarios, + perturbation_frequency = perturbation_frequency, ) # ============================================================================ @@ -133,8 +135,8 @@ for s in 1:num_scenarios # Sample perturbations perturbation_sample = DecisionRules.sample(uncertainty_samples) - # Record perturbations for this scenario - all_perturbations[s] = [p[1] for p in perturbation_sample] # First (and only) perturbation per stage + # Record perturbations for this scenario (first perturbation per stage) + all_perturbations[s] = [stage_u[1][2] for stage_u in perturbation_sample] # Simulate using the policy try diff --git a/examples/Atlas/visualize_atlas_policy_det_eq.jl b/examples/Atlas/visualize_atlas_policy_det_eq.jl new file mode 100644 index 0000000..9edd10d --- /dev/null +++ b/examples/Atlas/visualize_atlas_policy_det_eq.jl @@ -0,0 +1,194 @@ +# Visualize Trained Atlas Balancing Policy (Deterministic Equivalent) +# +# This script loads a trained policy and animates its performance on the Atlas +# balancing task using the deterministic equivalent formulation. 
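+#
+# At its core, each scenario below is one differentiable solve of the
+# deterministic equivalent (a sketch using the names defined later in this
+# script):
+#
+#     perturbation_sample = DecisionRules.sample(uncertainty_samples_det)
+#     obj = simulate_multistage(det_equivalent, state_params_in, state_params_out,
+#                               initial_state, perturbation_sample, models)
+#
+# with the policy `models` supplying the stage targets inside the solve.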
+ +using Flux +using DecisionRules +using Random +using Statistics +using JuMP +import Ipopt, HSL_jll +using JLD2 +using DiffOpt + +Atlas_dir = dirname(@__FILE__) +include(joinpath(Atlas_dir, "build_atlas_problem.jl")) +include(joinpath(Atlas_dir, "atlas_utils.jl")) +include(joinpath(Atlas_dir, "atlas_visualization.jl")) + +# ============================================================================ +# Configuration +# ============================================================================ + +# Model to load (modify this path to your trained model) +model_path = nothing # Set to path of trained model, or nothing to use latest + +# Problem parameters (should match training) +N = 50 # Number of time steps +h = 0.01 # Time step +perturbation_scale = 0.05 # Scale of random perturbations +num_scenarios = 10 # Number of scenarios to simulate +penalty = 1e3 # Penalty for state deviation +perturbation_frequency = 5 # Frequency of perturbations (every k stages) + +# Visualization options +animate_robot = true # Whether to animate in MeshCat + +# ============================================================================ +# Load Model +# ============================================================================ + +if isnothing(model_path) + model_dir = joinpath(Atlas_dir, "models") + if isdir(model_dir) + model_files = filter( + f -> endswith(f, ".jld2") && startswith(f, "atlas-balancing-deteq"), + readdir(model_dir), + ) + if !isempty(model_files) + model_files_full = [joinpath(model_dir, f) for f in model_files] + model_path = model_files_full[argmax([mtime(f) for f in model_files_full])] + println("Using latest model: $model_path") + end + end +end + +if isnothing(model_path) || !isfile(model_path) + println("No trained model found. Creating a random policy for visualization.") + use_random_policy = true +else + use_random_policy = false + println("Loading model from: $model_path") +end + +# ============================================================================ +# Setup Problem +# ============================================================================ + +println("\nSetting up Atlas problem (deterministic equivalent)...") +atlas = Atlas() +@load joinpath(Atlas_dir, "atlas_ref.jld2") x_ref u_ref + +nx = atlas.nx +nu = atlas.nu + +@time subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, + _, _, _, _, _ = build_atlas_subproblems(; + atlas = atlas, + x_ref = x_ref, + u_ref = u_ref, + N = N, + h = h, + perturbation_scale = perturbation_scale, + num_scenarios = num_scenarios, + penalty = penalty, + perturbation_frequency = perturbation_frequency, +) + +det_equivalent = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, + "print_level" => 0, + "hsllib" => HSL_jll.libhsl_path, + "linear_solver" => "ma97", +)) + +det_equivalent, uncertainty_samples_det = DecisionRules.deterministic_equivalent!( + det_equivalent, + subproblems, + state_params_in, + state_params_out, + initial_state, + uncertainty_samples, +) + +println("Atlas state dimension: $nx") +println("Atlas control dimension: $nu") +println("Number of stages: $(N - 1)") + +# ============================================================================ +# Load or Create Policy +# ============================================================================ + +layers = [64, 64] +activation = sigmoid + +n_uncertainties = length(uncertainty_samples_det[1]) +models = state_conditioned_policy(n_uncertainties, nx, nx, layers; + activation=activation, encoder_type=Flux.LSTM) + +if 
!use_random_policy + model_data = JLD2.load(model_path) + if haskey(model_data, "model_state") + Flux.loadmodel!(models, model_data["model_state"]) + println("Loaded trained model weights") + else + println("Warning: Could not find model_state in file, using random weights") + end +end + +# ============================================================================ +# Simulate Multiple Scenarios +# ============================================================================ + +println("\nSimulating $num_scenarios scenarios...") + +all_states = Vector{Vector{Vector{Float64}}}(undef, num_scenarios) +all_objectives = fill(NaN, num_scenarios) + +for s in 1:num_scenarios + Random.seed!(s * 100 + 42) + Flux.reset!(models) + + perturbation_sample = DecisionRules.sample(uncertainty_samples_det) + + try + obj = simulate_multistage( + det_equivalent, state_params_in, state_params_out, + initial_state, perturbation_sample, + models + ) + all_objectives[s] = obj + + states = Vector{Vector{Float64}}(undef, N) + states[1] = copy(initial_state) + for t in 1:N-1 + states[t + 1] = [value(state_params_out[t][i][2]) for i in 1:nx] + end + all_states[s] = states + + println("Scenario $s: objective = $(round(obj, digits=4))") + catch e + println("Scenario $s: FAILED - $e") + all_states[s] = [copy(initial_state) for _ in 1:N] + end +end + +valid_scenarios = findall(!isnan, all_objectives) +println("\nSuccessful scenarios: $(length(valid_scenarios))/$num_scenarios") +if !isempty(valid_scenarios) + println("Mean objective: $(round(mean(all_objectives[valid_scenarios]), digits=4))") + println("Std objective: $(round(std(all_objectives[valid_scenarios]), digits=4))") +end + +# ============================================================================ +# MeshCat Animation +# ============================================================================ + +if animate_robot + println("\nSetting up MeshCat visualizer...") + vis = Visualizer() + mvis = init_visualizer(atlas, vis) + + if !isempty(valid_scenarios) + best_scenario = valid_scenarios[argmin(all_objectives[valid_scenarios])] + println("Animating best scenario (scenario $best_scenario)...") + + X_animate = all_states[best_scenario] + animate!(atlas, mvis, X_animate, Δt=h) + + println("\nAnimation ready! Open MeshCat visualizer to view.") + println("Best scenario objective: $(all_objectives[best_scenario])") + end +end + +println("\nVisualization complete!") diff --git a/examples/Atlas/visualize_atlas_policy_multipleshooting.jl b/examples/Atlas/visualize_atlas_policy_multipleshooting.jl new file mode 100644 index 0000000..f402931 --- /dev/null +++ b/examples/Atlas/visualize_atlas_policy_multipleshooting.jl @@ -0,0 +1,203 @@ +# Visualize Trained Atlas Balancing Policy (Multiple Shooting) +# +# This script loads a trained policy and animates its performance on the Atlas +# balancing task using the multiple shooting (windowed) formulation. 
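+#
+# Each scenario below is evaluated by a windowed rollout (a sketch using the
+# names defined later in this script):
+#
+#     perturbation_sample = DecisionRules.sample(uncertainty_samples)
+#     uncertainties_vec = [[Float32(u[2]) for u in stage_u] for stage_u in perturbation_sample]
+#     obj = DecisionRules.simulate_multiple_shooting(
+#         windows, models, Float32.(initial_state),
+#         perturbation_sample, uncertainties_vec)
+#
+# where `windows` come from DecisionRules.setup_shooting_windows.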
+ +using Flux +using DecisionRules +using Random +using Statistics +using JuMP +import Ipopt, HSL_jll +using JLD2 +using DiffOpt + +Atlas_dir = dirname(@__FILE__) +include(joinpath(Atlas_dir, "build_atlas_problem.jl")) +include(joinpath(Atlas_dir, "atlas_utils.jl")) +include(joinpath(Atlas_dir, "atlas_visualization.jl")) + +# ============================================================================ +# Configuration +# ============================================================================ + +# Model to load (modify this path to your trained model) +model_path = nothing # Set to path of trained model, or nothing to use latest + +# Problem parameters (should match training) +N = 50 # Number of time steps +h = 0.01 # Time step +perturbation_scale = 0.05 # Scale of random perturbations +num_scenarios = 10 # Number of scenarios to simulate +penalty = 1e3 # Penalty for state deviation +perturbation_frequency = 5 # Frequency of perturbations (every k stages) +window_size = 5 # Multiple shooting window length + +# Visualization options +animate_robot = true # Whether to animate in MeshCat + +# ============================================================================ +# Load Model +# ============================================================================ + +if isnothing(model_path) + model_dir = joinpath(Atlas_dir, "models") + if isdir(model_dir) + model_files = filter( + f -> endswith(f, ".jld2") && startswith(f, "atlas-balancing-shooting"), + readdir(model_dir), + ) + if !isempty(model_files) + model_files_full = [joinpath(model_dir, f) for f in model_files] + model_path = model_files_full[argmax([mtime(f) for f in model_files_full])] + println("Using latest model: $model_path") + end + end +end + +if isnothing(model_path) || !isfile(model_path) + println("No trained model found. 
Creating a random policy for visualization.") + use_random_policy = true +else + use_random_policy = false + println("Loading model from: $model_path") +end + +# ============================================================================ +# Setup Problem +# ============================================================================ + +println("\nSetting up Atlas problem (multiple shooting)...") +atlas = Atlas() +@load joinpath(Atlas_dir, "atlas_ref.jld2") x_ref u_ref + +nx = atlas.nx +nu = atlas.nu + +diff_optimizer = () -> DiffOpt.diff_optimizer(optimizer_with_attributes(Ipopt.Optimizer, + "print_level" => 0, + "hsllib" => HSL_jll.libhsl_path, + "linear_solver" => "ma27" +)) + +@time subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, + _, _, _, _, _ = build_atlas_subproblems(; + atlas = atlas, + x_ref = x_ref, + u_ref = u_ref, + N = N, + h = h, + perturbation_scale = perturbation_scale, + num_scenarios = num_scenarios, + penalty = penalty, + perturbation_frequency = perturbation_frequency, + optimizer = diff_optimizer, +) + +windows = DecisionRules.setup_shooting_windows( + subproblems, + state_params_in, + state_params_out, + Float64.(initial_state), + uncertainty_samples; + window_size=window_size, + optimizer_factory=diff_optimizer, +) + +println("Atlas state dimension: $nx") +println("Atlas control dimension: $nu") +println("Number of stages: $(N - 1)") +println("Window size: $window_size") + +# ============================================================================ +# Load or Create Policy +# ============================================================================ + +layers = [64, 64] +activation = sigmoid + +n_uncertainties = length(uncertainty_samples[1]) +models = state_conditioned_policy(n_uncertainties, nx, nx, layers; + activation=activation, encoder_type=Flux.LSTM) + +if !use_random_policy + model_data = JLD2.load(model_path) + if haskey(model_data, "model_state") + Flux.loadmodel!(models, model_data["model_state"]) + println("Loaded trained model weights") + else + println("Warning: Could not find model_state in file, using random weights") + end +end + +# ============================================================================ +# Simulate Multiple Scenarios +# ============================================================================ + +println("\nSimulating $num_scenarios scenarios...") + +all_states = Vector{Vector{Vector{Float64}}}(undef, num_scenarios) +all_objectives = fill(NaN, num_scenarios) + +for s in 1:num_scenarios + Random.seed!(s * 100 + 42) + Flux.reset!(models) + + perturbation_sample = DecisionRules.sample(uncertainty_samples) + uncertainties_vec = [[Float32(u[2]) for u in stage_u] for stage_u in perturbation_sample] + + try + obj = DecisionRules.simulate_multiple_shooting( + windows, + models, + Float32.(initial_state), + perturbation_sample, + uncertainties_vec + ) + all_objectives[s] = obj + + states = Vector{Vector{Float64}}(undef, N) + states[1] = copy(initial_state) + for window in windows + for (local_idx, t) in enumerate(window.stage_range) + states[t + 1] = [value(pair[2]) for pair in window.state_out_params[local_idx]] + end + end + all_states[s] = states + + println("Scenario $s: objective = $(round(obj, digits=4))") + catch e + println("Scenario $s: FAILED - $e") + all_states[s] = [copy(initial_state) for _ in 1:N] + end +end + +valid_scenarios = findall(!isnan, all_objectives) +println("\nSuccessful scenarios: $(length(valid_scenarios))/$num_scenarios") +if !isempty(valid_scenarios) + 
println("Mean objective: $(round(mean(all_objectives[valid_scenarios]), digits=4))") + println("Std objective: $(round(std(all_objectives[valid_scenarios]), digits=4))") +end + +# ============================================================================ +# MeshCat Animation +# ============================================================================ + +if animate_robot + println("\nSetting up MeshCat visualizer...") + vis = Visualizer() + mvis = init_visualizer(atlas, vis) + + if !isempty(valid_scenarios) + best_scenario = valid_scenarios[argmin(all_objectives[valid_scenarios])] + println("Animating best scenario (scenario $best_scenario)...") + + X_animate = all_states[best_scenario] + animate!(atlas, mvis, X_animate, Δt=h) + + println("\nAnimation ready! Open MeshCat visualizer to view.") + println("Best scenario objective: $(all_objectives[best_scenario])") + end +end + +println("\nVisualization complete!") From 769c1eb43f5480c8df7f477f4ac21c9b4b5a00ec Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Thu, 29 Jan 2026 14:33:45 -0500 Subject: [PATCH 2/3] update api and examples --- Project.toml | 2 +- examples/Atlas/build_atlas_problem.jl | 2 +- examples/Atlas/train_dr_atlas_det_eq.jl | 6 +- .../Atlas/train_dr_atlas_multipleshooting.jl | 12 +- examples/Atlas/visualize_atlas_policy.jl | 13 +- .../Atlas/visualize_atlas_policy_det_eq.jl | 194 ------------------ ...visualize_atlas_policy_multipleshooting.jl | 2 +- src/DecisionRules.jl | 2 +- src/utils.jl | 28 ++- 9 files changed, 48 insertions(+), 213 deletions(-) delete mode 100644 examples/Atlas/visualize_atlas_policy_det_eq.jl diff --git a/Project.toml b/Project.toml index 11ccce6..e825684 100644 --- a/Project.toml +++ b/Project.toml @@ -36,10 +36,10 @@ julia = "~1.9" [extras] CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b" +Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" [targets] test = ["Test", "SCS", "Zygote", "Ipopt"] diff --git a/examples/Atlas/build_atlas_problem.jl b/examples/Atlas/build_atlas_problem.jl index 961c0bf..82e8e8d 100644 --- a/examples/Atlas/build_atlas_problem.jl +++ b/examples/Atlas/build_atlas_problem.jl @@ -56,7 +56,7 @@ function build_atlas_subproblems(; perturbation_frequency::Int = 1, perturbation_indices::Union{Nothing, Vector{Int}} = nothing, num_scenarios::Int = 10, - penalty::Float64 = 1e3, + penalty::Float64 = 10.0, penalty_l1::Union{Nothing, Float64} = nothing, penalty_l2::Union{Nothing, Float64} = nothing, optimizer = nothing, diff --git a/examples/Atlas/train_dr_atlas_det_eq.jl b/examples/Atlas/train_dr_atlas_det_eq.jl index c03fb4d..e7d33f7 100644 --- a/examples/Atlas/train_dr_atlas_det_eq.jl +++ b/examples/Atlas/train_dr_atlas_det_eq.jl @@ -24,9 +24,9 @@ include(joinpath(Atlas_dir, "build_atlas_problem.jl")) # Problem parameters N = 10 # Number of time steps h = 0.01 # Time step -perturbation_scale = 0.05 # Scale of random perturbations +perturbation_scale = 1.5 # Scale of random perturbations num_scenarios = 10 # Number of uncertainty samples per stage -penalty = 1e3 # Penalty for state deviation +penalty = 10.0 # Penalty for state deviation perturbation_frequency = 5 # Frequency of perturbations (every k stages) # Training parameters @@ -60,7 +60,7 @@ println("Building Atlas deterministic equivalent problem...") ) # Build deterministic equivalent 
-det_equivalent = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, +det_equivalent = DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0, "hsllib" => HSL_jll.libhsl_path, "linear_solver" => "ma97", diff --git a/examples/Atlas/train_dr_atlas_multipleshooting.jl b/examples/Atlas/train_dr_atlas_multipleshooting.jl index 7751b74..285b8be 100644 --- a/examples/Atlas/train_dr_atlas_multipleshooting.jl +++ b/examples/Atlas/train_dr_atlas_multipleshooting.jl @@ -10,6 +10,7 @@ import Ipopt, HSL_jll using Wandb, Dates, Logging using JLD2 using DiffOpt +import MathOptInterface as MOI Atlas_dir = dirname(@__FILE__) include(joinpath(Atlas_dir, "build_atlas_problem.jl")) @@ -19,11 +20,11 @@ include(joinpath(Atlas_dir, "build_atlas_problem.jl")) # ============================================================================ # Problem parameters -N = 11 # Number of time steps +N = 50 # Number of time steps h = 0.01 # Time step -perturbation_scale = 0.05 # Scale of random perturbations +perturbation_scale = 1.5 # Scale of random perturbations num_scenarios = 10 # Number of uncertainty samples per stage -penalty = 1e3 # Penalty for state deviation +penalty = 10.0 # Penalty for state deviation perturbation_frequency = 5 # Frequency of perturbations (every k stages) window_size = 5 # Multiple shooting window length @@ -130,6 +131,11 @@ windows = DecisionRules.setup_shooting_windows( optimizer_factory=diff_optimizer, ) +# loop over windows and set diffopt backend +for window in windows + MOI.set(window.model, DiffOpt.ModelConstructor(), DiffOpt.NonLinearProgram.Model) +end + # ============================================================================ # Initial Evaluation # ============================================================================ diff --git a/examples/Atlas/visualize_atlas_policy.jl b/examples/Atlas/visualize_atlas_policy.jl index 23272d7..f53f94b 100644 --- a/examples/Atlas/visualize_atlas_policy.jl +++ b/examples/Atlas/visualize_atlas_policy.jl @@ -23,7 +23,7 @@ include(joinpath(Atlas_dir, "atlas_visualization.jl")) # ============================================================================ # Model to load (modify this path to your trained model) -model_path = nothing # Set to path of trained model, or nothing to use latest +model_path = "./models/atlas-balancing-deteq-N10-2026-01-28T17:53:58.216.jld2" # Set to path of trained model, or nothing to use latest # Problem parameters (should match training) N = 50 # Number of time steps @@ -99,19 +99,16 @@ println("Control dimension: $nu") layers = [64, 64] activation = sigmoid -models = Chain( - Dense(nx, layers[1], activation), - x -> reshape(x, :, 1), - Flux.LSTM(layers[1] => layers[2]), - x -> x[:, end], - Dense(layers[2], nx) +n_uncertainties = length(uncertainty_samples[1]) +models = state_conditioned_policy(n_uncertainties, nx, nx, layers; + activation=activation, encoder_type=Flux.LSTM ) if !use_random_policy # Load trained weights model_data = JLD2.load(model_path) if haskey(model_data, "model_state") - Flux.loadmodel!(models, model_data["model_state"]) + Flux.loadmodel!(models, normalize_recur_state(model_data["model_state"])) println("Loaded trained model weights") else println("Warning: Could not find model_state in file, using random weights") diff --git a/examples/Atlas/visualize_atlas_policy_det_eq.jl b/examples/Atlas/visualize_atlas_policy_det_eq.jl deleted file mode 100644 index 9edd10d..0000000 --- a/examples/Atlas/visualize_atlas_policy_det_eq.jl +++ /dev/null @@ -1,194 
+0,0 @@ -# Visualize Trained Atlas Balancing Policy (Deterministic Equivalent) -# -# This script loads a trained policy and animates its performance on the Atlas -# balancing task using the deterministic equivalent formulation. - -using Flux -using DecisionRules -using Random -using Statistics -using JuMP -import Ipopt, HSL_jll -using JLD2 -using DiffOpt - -Atlas_dir = dirname(@__FILE__) -include(joinpath(Atlas_dir, "build_atlas_problem.jl")) -include(joinpath(Atlas_dir, "atlas_utils.jl")) -include(joinpath(Atlas_dir, "atlas_visualization.jl")) - -# ============================================================================ -# Configuration -# ============================================================================ - -# Model to load (modify this path to your trained model) -model_path = nothing # Set to path of trained model, or nothing to use latest - -# Problem parameters (should match training) -N = 50 # Number of time steps -h = 0.01 # Time step -perturbation_scale = 0.05 # Scale of random perturbations -num_scenarios = 10 # Number of scenarios to simulate -penalty = 1e3 # Penalty for state deviation -perturbation_frequency = 5 # Frequency of perturbations (every k stages) - -# Visualization options -animate_robot = true # Whether to animate in MeshCat - -# ============================================================================ -# Load Model -# ============================================================================ - -if isnothing(model_path) - model_dir = joinpath(Atlas_dir, "models") - if isdir(model_dir) - model_files = filter( - f -> endswith(f, ".jld2") && startswith(f, "atlas-balancing-deteq"), - readdir(model_dir), - ) - if !isempty(model_files) - model_files_full = [joinpath(model_dir, f) for f in model_files] - model_path = model_files_full[argmax([mtime(f) for f in model_files_full])] - println("Using latest model: $model_path") - end - end -end - -if isnothing(model_path) || !isfile(model_path) - println("No trained model found. 
Creating a random policy for visualization.") - use_random_policy = true -else - use_random_policy = false - println("Loading model from: $model_path") -end - -# ============================================================================ -# Setup Problem -# ============================================================================ - -println("\nSetting up Atlas problem (deterministic equivalent)...") -atlas = Atlas() -@load joinpath(Atlas_dir, "atlas_ref.jld2") x_ref u_ref - -nx = atlas.nx -nu = atlas.nu - -@time subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, - _, _, _, _, _ = build_atlas_subproblems(; - atlas = atlas, - x_ref = x_ref, - u_ref = u_ref, - N = N, - h = h, - perturbation_scale = perturbation_scale, - num_scenarios = num_scenarios, - penalty = penalty, - perturbation_frequency = perturbation_frequency, -) - -det_equivalent = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, - "print_level" => 0, - "hsllib" => HSL_jll.libhsl_path, - "linear_solver" => "ma97", -)) - -det_equivalent, uncertainty_samples_det = DecisionRules.deterministic_equivalent!( - det_equivalent, - subproblems, - state_params_in, - state_params_out, - initial_state, - uncertainty_samples, -) - -println("Atlas state dimension: $nx") -println("Atlas control dimension: $nu") -println("Number of stages: $(N - 1)") - -# ============================================================================ -# Load or Create Policy -# ============================================================================ - -layers = [64, 64] -activation = sigmoid - -n_uncertainties = length(uncertainty_samples_det[1]) -models = state_conditioned_policy(n_uncertainties, nx, nx, layers; - activation=activation, encoder_type=Flux.LSTM) - -if !use_random_policy - model_data = JLD2.load(model_path) - if haskey(model_data, "model_state") - Flux.loadmodel!(models, model_data["model_state"]) - println("Loaded trained model weights") - else - println("Warning: Could not find model_state in file, using random weights") - end -end - -# ============================================================================ -# Simulate Multiple Scenarios -# ============================================================================ - -println("\nSimulating $num_scenarios scenarios...") - -all_states = Vector{Vector{Vector{Float64}}}(undef, num_scenarios) -all_objectives = fill(NaN, num_scenarios) - -for s in 1:num_scenarios - Random.seed!(s * 100 + 42) - Flux.reset!(models) - - perturbation_sample = DecisionRules.sample(uncertainty_samples_det) - - try - obj = simulate_multistage( - det_equivalent, state_params_in, state_params_out, - initial_state, perturbation_sample, - models - ) - all_objectives[s] = obj - - states = Vector{Vector{Float64}}(undef, N) - states[1] = copy(initial_state) - for t in 1:N-1 - states[t + 1] = [value(state_params_out[t][i][2]) for i in 1:nx] - end - all_states[s] = states - - println("Scenario $s: objective = $(round(obj, digits=4))") - catch e - println("Scenario $s: FAILED - $e") - all_states[s] = [copy(initial_state) for _ in 1:N] - end -end - -valid_scenarios = findall(!isnan, all_objectives) -println("\nSuccessful scenarios: $(length(valid_scenarios))/$num_scenarios") -if !isempty(valid_scenarios) - println("Mean objective: $(round(mean(all_objectives[valid_scenarios]), digits=4))") - println("Std objective: $(round(std(all_objectives[valid_scenarios]), digits=4))") -end - -# ============================================================================ -# MeshCat Animation -# 
============================================================================ - -if animate_robot - println("\nSetting up MeshCat visualizer...") - vis = Visualizer() - mvis = init_visualizer(atlas, vis) - - if !isempty(valid_scenarios) - best_scenario = valid_scenarios[argmin(all_objectives[valid_scenarios])] - println("Animating best scenario (scenario $best_scenario)...") - - X_animate = all_states[best_scenario] - animate!(atlas, mvis, X_animate, Δt=h) - - println("\nAnimation ready! Open MeshCat visualizer to view.") - println("Best scenario objective: $(all_objectives[best_scenario])") - end -end - -println("\nVisualization complete!") diff --git a/examples/Atlas/visualize_atlas_policy_multipleshooting.jl b/examples/Atlas/visualize_atlas_policy_multipleshooting.jl index f402931..42c0c07 100644 --- a/examples/Atlas/visualize_atlas_policy_multipleshooting.jl +++ b/examples/Atlas/visualize_atlas_policy_multipleshooting.jl @@ -123,7 +123,7 @@ models = state_conditioned_policy(n_uncertainties, nx, nx, layers; if !use_random_policy model_data = JLD2.load(model_path) if haskey(model_data, "model_state") - Flux.loadmodel!(models, model_data["model_state"]) + Flux.loadmodel!(models, normalize_recur_state(model_data["model_state"])) println("Loaded trained model weights") else println("Warning: Could not find model_state in file, using random weights") diff --git a/src/DecisionRules.jl b/src/DecisionRules.jl index 8386d33..ef40044 100644 --- a/src/DecisionRules.jl +++ b/src/DecisionRules.jl @@ -11,7 +11,7 @@ using DiffOpt using Logging export simulate_multistage, sample, train_multistage, simulate_states, simulate_stage, dense_multilayer_nn, variable_to_parameter, create_deficit!, - SaveBest, find_variables, compute_parameter_dual, StallingCriterium, policy_input_dim, + SaveBest, find_variables, compute_parameter_dual, StallingCriterium, policy_input_dim, normalize_recur_state, StateConditionedPolicy, state_conditioned_policy, materialize_tangent, # Multiple shooting exports train_multiple_shooting, setup_shooting_windows, solve_window, predict_window_targets, diff --git a/src/utils.jl b/src/utils.jl index 361f86f..8eb001b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -103,12 +103,38 @@ mutable struct SaveBest <: Function best_loss::Float64 model_path::String end + +""" + normalize_recur_state(state) + +Return a copy of a `Flux.state` object where any Recur-like nodes have their +`state` field set to `cell.state0`. This avoids `Flux.loadmodel!` tie errors +when loading into freshly constructed recurrent layers. +""" +function normalize_recur_state(state) + if state isa NamedTuple + vals = map(normalize_recur_state, values(state)) + nt = NamedTuple{keys(state)}(vals) + if (:cell in keys(nt)) && (:state in keys(nt)) && + (getproperty(nt, :cell) isa NamedTuple) && (:state0 in keys(getproperty(nt, :cell))) + ks = keys(nt) + newvals = map(k -> (k === :state ? 
getproperty(nt.cell, :state0) : getproperty(nt, k)), ks) + return NamedTuple{ks}(newvals) + end + return nt + elseif state isa Tuple + return map(normalize_recur_state, state) + else + return state + end +end + function (callback::SaveBest)(iter, model, loss) if loss < callback.best_loss m = model |> cpu @info "best model change" callback.best_loss loss callback.best_loss = loss - model_state = Flux.state(m) + model_state = normalize_recur_state(Flux.state(m)) jldsave(callback.model_path; model_state=model_state) end return false From 6c18a2d76cb9a78d2c3e233cf999722f7ba33b52 Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Thu, 29 Jan 2026 19:51:14 -0500 Subject: [PATCH 3/3] update api --- README.md | 41 +++++++------------ .../Atlas/train_dr_atlas_multipleshooting.jl | 13 +++--- examples/Atlas/visualize_atlas_policy.jl | 2 +- ...visualize_atlas_policy_multipleshooting.jl | 8 +++- .../check_consistent_state_paths.jl | 2 +- ...in_dr_hydropowermodels_multipleshooting.jl | 8 +++- src/multiple_shooting.jl | 9 ++-- test/runtests.jl | 30 +++++++------- 8 files changed, 55 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 19c3728..ee63342 100644 --- a/README.md +++ b/README.md @@ -130,10 +130,23 @@ policy = Chain( Dense(64, length(initial_state)), ) +windows = DecisionRules.setup_shooting_windows( + subproblems, + state_params_in, + state_params_out, + Float64.(initial_state), + uncertainty_samples; + window_size=24, + model_factory=() -> DiffOpt.nonlinear_diff_model(optimizer_with_attributes( + SCS.Optimizer, + "verbose" => 0, + )), +) + DecisionRules.train_multiple_shooting( policy, initial_state, - subproblems, + windows, state_params_in, state_params_out, uncertainty_sampler; @@ -141,32 +154,6 @@ DecisionRules.train_multiple_shooting( num_batches=100, num_train_per_batch=32, optimizer=Flux.Adam(1e-3), - optimizer_factory=() -> DiffOpt.diff_optimizer(SCS.Optimizer), -) -``` - -If you want lower-level control, you can pre-build the window models once: - -```julia -windows = DecisionRules.setup_shooting_windows( - subproblems, - state_params_in, - state_params_out, - Float64.(initial_state), - uncertainties_structure; - window_size=24, - optimizer_factory=() -> DiffOpt.diff_optimizer(SCS.Optimizer), -) - -uncertainty_sample = uncertainty_sampler() -uncertainties_vec = [[Float32(u[2]) for u in stage_u] for stage_u in uncertainty_sample] - -obj = DecisionRules.simulate_multiple_shooting( - windows, - policy, - Float32.(initial_state), - uncertainty_sample, - uncertainties_vec, ) ``` diff --git a/examples/Atlas/train_dr_atlas_multipleshooting.jl b/examples/Atlas/train_dr_atlas_multipleshooting.jl index 285b8be..8086bc2 100644 --- a/examples/Atlas/train_dr_atlas_multipleshooting.jl +++ b/examples/Atlas/train_dr_atlas_multipleshooting.jl @@ -53,6 +53,12 @@ diff_optimizer = () -> DiffOpt.diff_optimizer(optimizer_with_attributes(Ipopt.Op "linear_solver" => "ma27" )) +diff_model = () -> DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, + "print_level" => 0, + "hsllib" => HSL_jll.libhsl_path, + "linear_solver" => "ma27" +)) + @time subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, _, _, x_ref, u_ref, atlas = build_atlas_subproblems(; N = N, @@ -128,14 +134,9 @@ windows = DecisionRules.setup_shooting_windows( Float64.(initial_state), uncertainty_samples; window_size=window_size, - optimizer_factory=diff_optimizer, + model_factory=diff_model, ) -# loop over windows and set diffopt backend -for window in windows - 
MOI.set(window.model, DiffOpt.ModelConstructor(), DiffOpt.NonLinearProgram.Model) -end - # ============================================================================ # Initial Evaluation # ============================================================================ diff --git a/examples/Atlas/visualize_atlas_policy.jl b/examples/Atlas/visualize_atlas_policy.jl index f53f94b..9eb3734 100644 --- a/examples/Atlas/visualize_atlas_policy.jl +++ b/examples/Atlas/visualize_atlas_policy.jl @@ -23,7 +23,7 @@ include(joinpath(Atlas_dir, "atlas_visualization.jl")) # ============================================================================ # Model to load (modify this path to your trained model) -model_path = "./models/atlas-balancing-deteq-N10-2026-01-28T17:53:58.216.jld2" # Set to path of trained model, or nothing to use latest +model_path = "./models/atlas-balancing-deteq-N10-2026-01-29T12:52:04.657.jld2" # Set to path of trained model, or nothing to use latest # Problem parameters (should match training) N = 50 # Number of time steps diff --git a/examples/Atlas/visualize_atlas_policy_multipleshooting.jl b/examples/Atlas/visualize_atlas_policy_multipleshooting.jl index 42c0c07..9d5abca 100644 --- a/examples/Atlas/visualize_atlas_policy_multipleshooting.jl +++ b/examples/Atlas/visualize_atlas_policy_multipleshooting.jl @@ -80,6 +80,12 @@ diff_optimizer = () -> DiffOpt.diff_optimizer(optimizer_with_attributes(Ipopt.Op "linear_solver" => "ma27" )) +diff_model = () -> DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, + "print_level" => 0, + "hsllib" => HSL_jll.libhsl_path, + "linear_solver" => "ma27" +)) + @time subproblems, state_params_in, state_params_out, initial_state, uncertainty_samples, _, _, _, _, _ = build_atlas_subproblems(; atlas = atlas, @@ -101,7 +107,7 @@ windows = DecisionRules.setup_shooting_windows( Float64.(initial_state), uncertainty_samples; window_size=window_size, - optimizer_factory=diff_optimizer, + model_factory=diff_model, ) println("Atlas state dimension: $nx") diff --git a/examples/HydroPowerModels/check_consistent_state_paths.jl b/examples/HydroPowerModels/check_consistent_state_paths.jl index 2cb6358..34aa410 100644 --- a/examples/HydroPowerModels/check_consistent_state_paths.jl +++ b/examples/HydroPowerModels/check_consistent_state_paths.jl @@ -91,7 +91,7 @@ windows = DecisionRules.setup_shooting_windows( Float64.(initial_state_w), uncert_w; window_size=window_size, - optimizer_factory=diff_optimizer, + model_factory=diff_optimizer, ) uncertainties_w = [[(stage_u[i][1], base_values[t][i]) for i in eachindex(stage_u)] diff --git a/examples/HydroPowerModels/train_dr_hydropowermodels_multipleshooting.jl b/examples/HydroPowerModels/train_dr_hydropowermodels_multipleshooting.jl index ebd1bb4..58a59c4 100644 --- a/examples/HydroPowerModels/train_dr_hydropowermodels_multipleshooting.jl +++ b/examples/HydroPowerModels/train_dr_hydropowermodels_multipleshooting.jl @@ -49,6 +49,12 @@ diff_optimizer = () -> DiffOpt.diff_optimizer(optimizer_with_attributes(Ipopt.Op "linear_solver" => "ma27" )) +diff_model = () -> DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, + "print_level" => 0, + "hsllib" => HSL_jll.libhsl_path, + "linear_solver" => "ma27" +)) + subproblems, state_params_in, state_params_out, uncertainty_samples, initial_state, max_volume = build_hydropowermodels( joinpath(HydroPowerModels_dir, case_name), formulation_file; num_stages=num_stages, @@ -105,7 +111,7 @@ windows = DecisionRules.setup_shooting_windows( 
Float64.(initial_state), uncertainty_samples; window_size=window_size, - optimizer_factory=diff_optimizer, + model_factory=diff_model, ) objective_values = [begin diff --git a/src/multiple_shooting.jl b/src/multiple_shooting.jl index 948eef5..23d4bd4 100644 --- a/src/multiple_shooting.jl +++ b/src/multiple_shooting.jl @@ -478,7 +478,7 @@ end """ setup_shooting_windows(subproblems, state_params_in, state_params_out, initial_state, - uncertainties; window_size, optimizer_factory=nothing) + uncertainties; window_size, model_factory=() -> JuMP.Model()) Build window models for multiple shooting. @@ -492,7 +492,7 @@ function setup_shooting_windows( initial_state::Vector{Float64}, uncertainties; # typically Vector{Vector{Tuple{VariableRef,Vector{Float64}}}} or similar window_size::Int, - optimizer_factory=nothing, + model_factory=() -> JuMP.Model(), ) where {U} num_stages = length(subproblems) @@ -510,10 +510,7 @@ function setup_shooting_windows( window_state_params_out = [state_params_out[t] for t in stage_range] window_uncertainties = uncertainties[stage_range] - window_model = JuMP.Model() - if optimizer_factory !== nothing - set_optimizer(window_model, optimizer_factory) - end + window_model = model_factory() # Build window equivalent model without mutating the originals. window_model, diff --git a/test/runtests.jl b/test/runtests.jl index 6efbd94..2028261 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -426,7 +426,7 @@ end uncertainty_samples = Vector{Vector{Tuple{VariableRef, Vector{Float64}}}}(undef, num_stages) for t in 1:num_stages - subproblems[t] = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, "verbose" => 0)) + subproblems[t] = DiffOpt.diff_model(Ipopt.Optimizer) @variable(subproblems[t], x[1:3] >= 0) @variable(subproblems[t], state_in in MOI.Parameter(1.0)) @variable(subproblems[t], uncertainty in MOI.Parameter(0.5)) @@ -454,7 +454,7 @@ end initial_state, uncertainty_samples; window_size=window_size, - optimizer_factory=() -> DiffOpt.diff_optimizer(Ipopt.Optimizer) + model_factory=() -> DiffOpt.nonlinear_diff_model(Ipopt.Optimizer) ) @test length(windows) == 3 # 6 stages / 2 window_size = 3 windows @@ -479,7 +479,7 @@ end initial_state, uncertainty_samples; window_size=4, # 6 stages / 4 = 2 windows, last window has 2 stages - optimizer_factory=() -> DiffOpt.diff_optimizer(Ipopt.Optimizer) + model_factory=() -> DiffOpt.nonlinear_diff_model(Ipopt.Optimizer) ) @test length(windows_odd) == 2 @@ -535,7 +535,7 @@ end uncertainty_samples = Vector{Vector{Tuple{VariableRef, Vector{Float64}}}}(undef, num_stages) for t in 1:num_stages - subproblems[t] = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, "verbose" => 0)) + subproblems[t] = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0)) @variable(subproblems[t], x[1:4] >= 0) @variable(subproblems[t], state_in in MOI.Parameter(1.0)) @variable(subproblems[t], uncertainty in MOI.Parameter(0.5)) @@ -562,7 +562,7 @@ end initial_state, uncertainty_samples; window_size=2, - optimizer_factory=() -> DiffOpt.diff_optimizer(Ipopt.Optimizer) + model_factory=() -> DiffOpt.nonlinear_diff_model(Ipopt.Optimizer) ) @test length(windows) == 1 # Only one window for 2 stages with window_size=2 @@ -630,7 +630,7 @@ end [1.0], uncertainty_samples; window_size=1, - optimizer_factory=() -> DiffOpt.diff_optimizer(SCS.Optimizer) + model_factory=() -> DiffOpt.conic_diff_model(SCS.Optimizer) ) window = windows[1] @@ -662,7 +662,7 @@ end uncertainty_samples = Vector{Vector{Tuple{VariableRef, 
Vector{Float64}}}}(undef, num_stages) for t in 1:num_stages - subproblems[t] = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, "verbose" => 0)) + subproblems[t] = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0)) @variable(subproblems[t], x[1:4] >= 0) @variable(subproblems[t], state_in in MOI.Parameter(1.0)) @variable(subproblems[t], uncertainty in MOI.Parameter(0.5)) @@ -688,7 +688,7 @@ end initial_state, uncertainty_samples; window_size=2, - optimizer_factory=() -> DiffOpt.diff_optimizer(Ipopt.Optimizer) + model_factory=() -> DiffOpt.nonlinear_diff_model(Ipopt.Optimizer) ) window = windows[1] @@ -733,7 +733,7 @@ end uncertainty_samples = Vector{Vector{Tuple{VariableRef, Vector{Float64}}}}(undef, num_stages) for t in 1:num_stages - subproblems[t] = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, "verbose" => 0)) + subproblems[t] = DiffOpt.diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0)) @variable(subproblems[t], x[1:4] >= 0) @variable(subproblems[t], state_in in MOI.Parameter(1.0)) @variable(subproblems[t], uncertainty in MOI.Parameter(0.5)) @@ -760,7 +760,7 @@ end initial_state, uncertainty_samples; window_size=2, - optimizer_factory=() -> DiffOpt.diff_optimizer(Ipopt.Optimizer) + model_factory=() -> DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0)) ) @test length(windows) == 2 @@ -808,7 +808,7 @@ end uncertainty_samples = Vector{Vector{Tuple{VariableRef, Vector{Float64}}}}(undef, num_stages) for t in 1:num_stages - subproblems[t] = DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, "verbose" => 0)) + subproblems[t] = DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0)) @variable(subproblems[t], x[1:4] >= 0) @variable(subproblems[t], state_in in MOI.Parameter(1.0)) @variable(subproblems[t], uncertainty in MOI.Parameter(0.5)) @@ -834,7 +834,7 @@ end initial_state, uncertainty_samples; window_size=2, - optimizer_factory=() -> DiffOpt.diff_optimizer(Ipopt.Optimizer) + model_factory=() -> DiffOpt.nonlinear_diff_model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0)) ) # Set uncertainty values @@ -899,7 +899,7 @@ end Float64.(initial_state), uncertainty_samples; window_size=2, - optimizer_factory=() -> DiffOpt.diff_optimizer(SCS.Optimizer) + model_factory=() -> DiffOpt.conic_diff_model(SCS.Optimizer) ) # Policy expects flat [u1, u2, state] input @@ -952,7 +952,7 @@ end [1.0], uncertainty_samples; window_size=1, - optimizer_factory=() -> DiffOpt.diff_optimizer(SCS.Optimizer), + model_factory=() -> DiffOpt.conic_diff_model(SCS.Optimizer), ) DecisionRules.train_multiple_shooting( @@ -1068,7 +1068,7 @@ end Float64.(initial_state), uncertainty_samples_w; window_size=6, - optimizer_factory=() -> DiffOpt.diff_optimizer(SCS.Optimizer), + model_factory=() -> DiffOpt.conic_diff_model(SCS.Optimizer), ) # Variable count checks stage_var_count = sum(length.(all_variables.(subproblems_s)))
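For reference, the checkpoint round-trip that `normalize_recur_state` is designed for looks roughly like this (a minimal sketch, assuming Flux's `Recur`-based recurrent layers as used by `state_conditioned_policy`; the file name is illustrative):

```julia
using Flux, JLD2, DecisionRules

# Save: normalize so the stored Recur state matches a freshly built model's tree.
model = Chain(Flux.LSTM(4 => 8), Dense(8, 2))
jldsave("ckpt.jld2"; model_state = normalize_recur_state(Flux.state(model)))

# Load: normalize again before loadmodel! to avoid Recur state-tie errors.
fresh = Chain(Flux.LSTM(4 => 8), Dense(8, 2))
Flux.loadmodel!(fresh, normalize_recur_state(JLD2.load("ckpt.jld2")["model_state"]))
```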