Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
68f45f4
Refactor DeepSSM: add constants module and reproducible seeding
akenmorris Jan 6, 2026
6ad0111
Add type hints to DeepSSM public API functions
akenmorris Jan 7, 2026
4008a1e
Add config schema validation for DeepSSM
akenmorris Jan 7, 2026
6fa9933
Improve error handling in DeepSSM data loaders
akenmorris Jan 7, 2026
bce52eb
Add error handling to loaders and --exact_check option
akenmorris Jan 7, 2026
1290941
Add --tl_net flag and fix TL-DeepSSM bugs
akenmorris Jan 7, 2026
1c65c2d
Validate --exact_check and --tl_net are only used with deep_ssm
akenmorris Jan 7, 2026
0efe52c
Use separate exact_check files for standard and tl_net modes
akenmorris Jan 7, 2026
881bc34
Add GTest-based tests for DeepSSM that use shapeworks project files.
akenmorris Jan 8, 2026
a16e305
Add result verification to DeepSSM tests
akenmorris Jan 8, 2026
45f4e25
Add documentation and extended test infrastructure for DeepSSM
akenmorris Jan 9, 2026
d200fbe
Fix DeepSSM command arg parsing after return value fix
akenmorris Jan 9, 2026
8be54c0
Fix toMesh pipeline and add empty mesh validation
akenmorris Jan 14, 2026
f2114c8
Return identity transform for empty meshes in ICP
akenmorris Jan 14, 2026
c50c21b
Add streaming data loaders to reduce DeepSSM memory usage
akenmorris Jan 14, 2026
ab35dfd
Fix bounding box calculation and add error handling in run_utils
akenmorris Jan 14, 2026
36a6810
Fail with clear errors instead of silently skipping missing files
akenmorris Feb 4, 2026
ca9c7a3
Reduce DeepSSM tests from 4 to 2 configurations
akenmorris Feb 5, 2026
fb5e2d7
Resolve #2487 - Auto subset size in grooming should pick a smart auto
akenmorris Feb 5, 2026
9a650bd
Merge pull request #2488 from SCIInstitute/amorris/2487-auto-referenc…
akenmorris Feb 5, 2026
c77de4c
Refactor DeepSSM: add constants module and reproducible seeding
akenmorris Jan 6, 2026
3554d87
Add type hints to DeepSSM public API functions
akenmorris Jan 7, 2026
5b8c663
Add config schema validation for DeepSSM
akenmorris Jan 7, 2026
5c81566
Improve error handling in DeepSSM data loaders
akenmorris Jan 7, 2026
fc58e52
Add error handling to loaders and --exact_check option
akenmorris Jan 7, 2026
3b2e4ae
Add --tl_net flag and fix TL-DeepSSM bugs
akenmorris Jan 7, 2026
6566ecf
Validate --exact_check and --tl_net are only used with deep_ssm
akenmorris Jan 7, 2026
83d7789
Use separate exact_check files for standard and tl_net modes
akenmorris Jan 7, 2026
36a499c
Add GTest-based tests for DeepSSM that use shapeworks project files.
akenmorris Jan 8, 2026
0728968
Add result verification to DeepSSM tests
akenmorris Jan 8, 2026
ba4a193
Add documentation and extended test infrastructure for DeepSSM
akenmorris Jan 9, 2026
22cab67
Fix DeepSSM command arg parsing after return value fix
akenmorris Jan 9, 2026
3e05885
Fix toMesh pipeline and add empty mesh validation
akenmorris Jan 14, 2026
3951ffe
Return identity transform for empty meshes in ICP
akenmorris Jan 14, 2026
d8dba88
Add streaming data loaders to reduce DeepSSM memory usage
akenmorris Jan 14, 2026
29f4165
Fix bounding box calculation and add error handling in run_utils
akenmorris Jan 14, 2026
22436d1
Fail with clear errors instead of silently skipping missing files
akenmorris Feb 4, 2026
3660dd5
Reduce DeepSSM tests from 4 to 2 configurations
akenmorris Feb 5, 2026
1225995
Update baselines
akenmorris Feb 6, 2026
f40d62c
Merge branch 'deepssm_refactor2' of github.com:SCIInstitute/ShapeWork…
akenmorris Feb 6, 2026
70cd306
Fix tests
akenmorris Feb 6, 2026
ff77d86
Fix SW_MAJOR_VERSION
akenmorris Feb 6, 2026
53205b7
Fixes for deepssm tests
akenmorris Feb 6, 2026
16add77
Fix CI tests
akenmorris Feb 6, 2026
a18ee3a
Set OMP_NUM_THREADS=1 for windows CI deepssm
akenmorris Feb 7, 2026
5334d33
Add --aug_processes option to avoid multiprocessing hang on Windows CI
akenmorris Feb 7, 2026
56bb670
Test run to debug windows CI
akenmorris Feb 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Applications/shapeworks/Command.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Command {
const std::string desc() const { return parser.description(); }

/// parses the arguments for this command, saving them in the parser and returning the leftovers
std::vector<std::string> parse_args(const std::vector<std::string> &arguments);
virtual std::vector<std::string> parse_args(const std::vector<std::string> &arguments);

/// calls execute for this command using the parsed args, returning system exit value
int run(SharedCommandData &sharedData);
Expand Down Expand Up @@ -108,6 +108,12 @@ class DeepSSMCommandGroup : public Command
public:
const std::string type() override { return "DeepSSM"; }

// DeepSSM is a terminal command - don't pass remaining args to other commands
std::vector<std::string> parse_args(const std::vector<std::string> &arguments) override {
Command::parse_args(arguments);
return {}; // return empty - DeepSSM consumes all args
}

private:
};

Expand Down
37 changes: 23 additions & 14 deletions Applications/shapeworks/Commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,12 @@ void DeepSSMCommand::buildParser() {
.set_default(0)
.help("Number of data loader workers (default: 0)");

parser.add_option("--aug_processes")
.action("store")
.type("int")
.set_default(0)
.help("Number of augmentation processes (default: 0 = use all cores)");

Command::buildParser();
}

Expand Down Expand Up @@ -413,12 +419,14 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
bool do_test = options.is_set("test") || options.is_set("all");

int num_workers = static_cast<int>(options.get("num_workers"));
int aug_processes = static_cast<int>(options.get("aug_processes"));

std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n";
std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n";
std::cout << "Train step: " << (do_train ? "on" : "off") << "\n";
std::cout << "Test step: " << (do_test ? "on" : "off") << "\n";
std::cout << "Num dataloader workers: " << num_workers << "\n";
std::cout << "Prep step: " << (do_prep ? "on" : "off") << std::endl;
std::cout << "Augment step: " << (do_augment ? "on" : "off") << std::endl;
std::cout << "Train step: " << (do_train ? "on" : "off") << std::endl;
std::cout << "Test step: " << (do_test ? "on" : "off") << std::endl;
std::cout << "Num dataloader workers: " << num_workers << std::endl;
std::cout << "Augmentation processes: " << (aug_processes == 0 ? QThread::idealThreadCount() : aug_processes) << std::endl;

if (!do_prep && !do_augment && !do_train && !do_test) {
do_prep = true;
Expand Down Expand Up @@ -462,44 +470,45 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
SW_ERROR("Unknown prep step: {}", prep_step);
return false;
}
std::cout << "Running DeepSSM preparation step...\n";
std::cerr << "Running DeepSSM preparation step..." << std::endl;
python_worker.run_job(job);
if (!wait_for_job(job)) {
return false;
}
std::cout << "DeepSSM preparation step completed.\n";
std::cerr << "DeepSSM preparation step completed." << std::endl;
}
if (do_augment) {
std::cout << "Running DeepSSM data augmentation...\n";
std::cerr << "Running DeepSSM data augmentation..." << std::endl;
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType);
job->set_aug_processes(aug_processes);
python_worker.run_job(job);
if (!wait_for_job(job)) {
return false;
}
std::cout << "DeepSSM data augmentation completed.\n";
std::cerr << "DeepSSM data augmentation completed." << std::endl;
}
if (do_train) {
std::cout << "Running DeepSSM training...\n";
std::cerr << "Running DeepSSM training..." << std::endl;
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_TrainingType);
python_worker.run_job(job);
if (!wait_for_job(job)) {
return false;
}
std::cout << "DeepSSM training completed.\n";
std::cerr << "DeepSSM training completed." << std::endl;
}
if (do_test) {
std::cout << "Running DeepSSM testing...\n";
std::cerr << "Running DeepSSM testing..." << std::endl;
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_TestingType);
python_worker.run_job(job);
if (!wait_for_job(job)) {
return false;
}
std::cout << "DeepSSM testing completed.\n";
std::cerr << "DeepSSM testing completed." << std::endl;
}

project->save();

return false;
return true;
}

} // namespace shapeworks
6 changes: 6 additions & 0 deletions Examples/Python/RunUseCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,14 @@
parser.add_argument("--tiny_test", help="Run as a short test", action="store_true")
parser.add_argument("--verify", help="Run as a full test", action="store_true")
parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true")
parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification (platform-specific)",
choices=["save", "verify"])
args = parser.parse_args()

# Validate deep_ssm-specific arguments
if args.exact_check and args.use_case != "deep_ssm":
parser.error("--exact_check is only supported for the deep_ssm use case")

type = ""
if args.tiny_test:
type = "tiny_test_"
Expand Down
62 changes: 45 additions & 17 deletions Examples/Python/deep_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def Run_Pipeline(args):
This data is comprised of femur meshes and corresponding hip CT scans.
"""

if platform.system() == "Darwin":
# On MacOS, CPU PyTorch is hanging with parallel
if platform.system() != "Linux":
# CPU PyTorch hangs with OpenMP parallelism on macOS and Windows
os.environ['OMP_NUM_THREADS'] = "1"
# If running a tiny_test, then download subset of the data
if args.tiny_test:
Expand Down Expand Up @@ -396,6 +396,7 @@ def Run_Pipeline(args):
"c_lat": 6.3
}
}

if args.tiny_test:
model_parameters["trainer"]["epochs"] = 1
# Save config file
Expand Down Expand Up @@ -436,17 +437,17 @@ def Run_Pipeline(args):
val_world_particles.append(project_path + subjects[index].get_world_particle_filenames()[0])
val_mesh_files.append(project_path + subjects[index].get_groomed_filenames()[0])

val_out_dir = output_directory + model_name + '/validation_predictions/'
predicted_val_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='validation')
print("Validation world predictions saved.")
# Generate local predictions
local_val_prediction_dir = val_out_dir + 'local_predictions/'
# Generate local predictions - create directory next to world_predictions
world_pred_dir = os.path.dirname(predicted_val_world_particles[0])
local_val_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions")
if not os.path.exists(local_val_prediction_dir):
os.makedirs(local_val_prediction_dir)
predicted_val_local_particles = []
for particle_file, transform in zip(predicted_val_world_particles, val_transforms):
particles = np.loadtxt(particle_file)
local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
local_particle_file = particle_file.replace("world_predictions", "local_predictions")
local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
np.savetxt(local_particle_file, local_particles)
predicted_val_local_particles.append(local_particle_file)
Expand All @@ -462,6 +463,8 @@ def Run_Pipeline(args):
template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]
# Get distance between clipped true and predicted meshes
# Get the validation output directory from the predictions path
val_out_dir = os.path.dirname(local_val_prediction_dir.rstrip('/')) + '/'
mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_val_local_particles, val_mesh_files,
template_particles, template_mesh, val_out_dir,
planes=val_planes)
Expand Down Expand Up @@ -500,17 +503,17 @@ def Run_Pipeline(args):
with open(plane_file) as json_file:
test_planes.append(json.load(json_file)['planes'][0]['points'])

test_out_dir = output_directory + model_name + '/test_predictions/'
predicted_test_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='test')
print("Test world predictions saved.")
# Generate local predictions
local_test_prediction_dir = test_out_dir + 'local_predictions/'
# Generate local predictions - create directory next to world_predictions
world_pred_dir = os.path.dirname(predicted_test_world_particles[0])
local_test_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions")
if not os.path.exists(local_test_prediction_dir):
os.makedirs(local_test_prediction_dir)
predicted_test_local_particles = []
for particle_file, transform in zip(predicted_test_world_particles, test_transforms):
particles = np.loadtxt(particle_file)
local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
local_particle_file = particle_file.replace("world_predictions", "local_predictions")
local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
np.savetxt(local_particle_file, local_particles)
predicted_test_local_particles.append(local_particle_file)
Expand All @@ -524,28 +527,53 @@ def Run_Pipeline(args):
template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]

# Get the test output directory from the predictions path
test_out_dir = os.path.dirname(local_test_prediction_dir.rstrip('/')) + '/'
mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_test_local_particles, test_mesh_files,
template_particles, template_mesh, test_out_dir,
planes=test_planes)
print("Test mean mesh surface-to-surface distance: " + str(mean_dist))

DeepSSMUtils.process_test_predictions(project, config_file)
final_mean_dist = DeepSSMUtils.process_test_predictions(project, config_file)

# If tiny test or verify, check results and exit
check_results(args, mean_dist)
check_results(args, final_mean_dist, output_directory)

open(status_dir + "step_12.txt", 'w').close()

print("All steps complete")


# Verification
def check_results(args, mean_dist):
def check_results(args, mean_dist, output_directory):
if args.tiny_test:
print("\nVerifying use case results.")
if not math.isclose(mean_dist, 10, rel_tol=1):
print("Test failed.")
exit(-1)

exact_check_file = output_directory + "exact_check_value.txt"

# Exact check for refactoring verification (platform-specific)
if args.exact_check == "save":
with open(exact_check_file, "w") as f:
f.write(str(mean_dist))
print(f"Saved exact check value to: {exact_check_file}")
print(f"Value: {mean_dist}")
elif args.exact_check == "verify":
if not os.path.exists(exact_check_file):
print(f"Error: No saved value found at {exact_check_file}")
print("Run with --exact_check save first to create baseline.")
exit(-1)
with open(exact_check_file, "r") as f:
expected_mean_dist = float(f.read().strip())
if mean_dist != expected_mean_dist:
print(f"Exact check failed: expected {expected_mean_dist}, got {mean_dist}")
exit(-1)
print(f"Exact check passed: {mean_dist}")
else:
# Relaxed check for CI/cross-platform
if not math.isclose(mean_dist, 10, rel_tol=1):
print("Test failed.")
exit(-1)

print("Done with test, verification succeeded.")
exit(0)
else:
Expand Down
11 changes: 9 additions & 2 deletions Libs/Application/DeepSSM/DeepSSMJob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,12 @@ void DeepSSMJob::run_augmentation() {
py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils");
py::object run_data_aug = py_deep_ssm_utils.attr("run_data_augmentation");

int processes = aug_processes_ > 0 ? aug_processes_ : QThread::idealThreadCount();

int aug_dims = run_data_aug(project_, params.get_aug_num_samples(),
0 /* num dims, set to zero to allow percent variability to be used */,
params.get_aug_percent_variability(), sampler_type.toStdString(), 0 /* mixture_num */,
QThread::idealThreadCount() /* processes */
)
processes)
.cast<int>();

params.set_training_num_dims(aug_dims);
Expand Down Expand Up @@ -394,6 +395,12 @@ void DeepSSMJob::set_num_dataloader_workers(int num_workers) { num_dataloader_wo
//---------------------------------------------------------------------------
int DeepSSMJob::get_num_dataloader_workers() { return num_dataloader_workers_; }

//---------------------------------------------------------------------------
void DeepSSMJob::set_aug_processes(int processes) { aug_processes_ = processes; }

//---------------------------------------------------------------------------
int DeepSSMJob::get_aug_processes() { return aug_processes_; }

//---------------------------------------------------------------------------
void DeepSSMJob::update_prep_stage(PrepStep step) {
/*
Expand Down
4 changes: 4 additions & 0 deletions Libs/Application/DeepSSM/DeepSSMJob.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class DeepSSMJob : public Job {
void set_num_dataloader_workers(int num_workers);
int get_num_dataloader_workers();

void set_aug_processes(int processes);
int get_aug_processes();

void set_prep_step(DeepSSMJob::PrepStep step) {
std::lock_guard<std::mutex> lock(mutex_);
prep_step_ = step;
Expand All @@ -72,6 +75,7 @@ class DeepSSMJob : public Job {
DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED};

int num_dataloader_workers_{0};
int aug_processes_{0};

// mutex
std::mutex mutex_;
Expand Down
5 changes: 2 additions & 3 deletions Libs/Application/Job/PythonWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,10 @@ bool PythonWorker::init() {
path = QString::fromStdString(line);
}
file.close();
qputenv("PATH", path.toUtf8());
SW_LOG("Setting PATH for Python to: " + path.toStdString());
}

qputenv("PATH", path.toUtf8());
SW_LOG("Setting PATH for Python to: " + path.toStdString());

// Python 3.8+ requires explicit DLL directory registration
// PATH environment variable is no longer used for DLL search
SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_USER_DIRS);
Expand Down
17 changes: 15 additions & 2 deletions Libs/Groom/Groom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,17 @@ bool Groom::image_pipeline(std::shared_ptr<Subject> subject, size_t domain) {
std::string groomed_name = get_output_filename(original, DomainType::Image);

if (params.get_convert_to_mesh()) {
// Use isovalue 0.0 for distance transforms (the zero level set is the surface)
Mesh mesh = image.toMesh(0.0);
if (mesh.numPoints() == 0) {
throw std::runtime_error("Empty mesh generated from segmentation - segmentation may have no valid data");
}
// Check for valid cells
auto poly_data = mesh.getVTKMesh();
if (poly_data->GetNumberOfCells() == 0) {
throw std::runtime_error("Mesh has no cells - segmentation may have no valid surface");
}
SW_DEBUG("Mesh after toMesh: {} points, {} cells", poly_data->GetNumberOfPoints(), poly_data->GetNumberOfCells());
run_mesh_pipeline(mesh, params, original);
groomed_name = get_output_filename(original, DomainType::Mesh);
// save the groomed mesh
Expand Down Expand Up @@ -239,6 +249,9 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) {
// crop
if (params.get_crop()) {
PhysicalRegion region = image.physicalBoundingBox(0.5);
if (!region.valid()) {
throw std::runtime_error("Empty segmentation - no voxels found above threshold for cropping");
}
image.crop(region);
increment_progress();
}
Expand Down Expand Up @@ -560,7 +573,7 @@ bool Groom::run_alignment() {
bool any_alignment = false;

int reference_index = -1;
int subset_size = -1;
int subset_size = base_params.get_alignment_subset_size();

// per-domain alignment
for (size_t domain = 0; domain < num_domains; domain++) {
Expand Down Expand Up @@ -1336,7 +1349,7 @@ std::vector<std::vector<double>> Groom::get_icp_transforms(const std::vector<Mes
matrix->Identity();

Mesh source = meshes[i];
if (source.getVTKMesh()->GetNumberOfPoints() != 0) {
if (source.getVTKMesh()->GetNumberOfPoints() != 0 && reference.getVTKMesh()->GetNumberOfPoints() != 0) {
// create copies for thread safety
auto poly_data1 = vtkSmartPointer<vtkPolyData>::New();
poly_data1->DeepCopy(source.getVTKMesh());
Expand Down
Loading
Loading