From 68f45f4a67c89f39820bd87a5e6b1dab4002ba20 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 15:39:27 -0700 Subject: [PATCH 01/47] Refactor DeepSSM: add constants module and reproducible seeding * Add constants.py to centralize magic strings (file names, loader names, device strings) for improved maintainability * Add set_seed() function in net_utils.py for reproducible training by seeding Python random, NumPy, PyTorch CPU/CUDA, and cuDNN * Update loaders.py, trainer.py, model.py, eval.py to use constants * Export constants and set_seed from __init__.py Verified: test outputs are identical before and after refactoring. --- .../DeepSSMUtils/__init__.py | 4 + .../DeepSSMUtils/constants.py | 73 +++++++++++++++++++ .../DeepSSMUtilsPackage/DeepSSMUtils/eval.py | 9 ++- .../DeepSSMUtils/loaders.py | 27 +++---- .../DeepSSMUtilsPackage/DeepSSMUtils/model.py | 21 +++--- .../DeepSSMUtils/net_utils.py | 21 +++++- .../DeepSSMUtils/trainer.py | 53 +++++++------- 7 files changed, 154 insertions(+), 54 deletions(-) create mode 100644 Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index 135f4f0a05..738b914861 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -6,6 +6,10 @@ from DeepSSMUtils import train_viz from DeepSSMUtils import image_utils from DeepSSMUtils import run_utils +from DeepSSMUtils import net_utils +from DeepSSMUtils import constants + +from .net_utils import set_seed from .run_utils import create_split, groom_training_shapes, groom_training_images, \ run_data_augmentation, groom_val_test_images, prep_project_for_val_particles, groom_validation_shapes, \ diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py new file mode 100644 index 0000000000..db912adcf6 --- /dev/null +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py @@ -0,0 +1,73 @@ +""" +Constants used throughout DeepSSM. + +This module centralizes magic strings and default values to improve +maintainability and reduce errors from typos. 
+""" + +# Model file names +BEST_MODEL_FILE = "best_model.torch" +FINAL_MODEL_FILE = "final_model.torch" +BEST_MODEL_FT_FILE = "best_model_ft.torch" +FINAL_MODEL_FT_FILE = "final_model_ft.torch" +FINAL_MODEL_AE_FILE = "final_model_ae.torch" +FINAL_MODEL_TF_FILE = "final_model_tf.torch" + +# Data loader names +TRAIN_LOADER = "train" +VALIDATION_LOADER = "validation" +TEST_LOADER = "test" + +# File names for saved statistics +MEAN_PCA_FILE = "mean_PCA.npy" +STD_PCA_FILE = "std_PCA.npy" +MEAN_IMG_FILE = "mean_img.npy" +STD_IMG_FILE = "std_img.npy" + +# Names files +TRAIN_NAMES_FILE = "train_names.txt" +VALIDATION_NAMES_FILE = "validation_names.txt" +TEST_NAMES_FILE = "test_names.txt" + +# Log and plot files +TRAIN_LOG_FILE = "train_log.csv" +TRAINING_PLOT_FILE = "training_plot.png" +TRAINING_PLOT_FT_FILE = "training_plot_ft.png" +TRAINING_PLOT_AE_FILE = "training_plot_ae.png" +TRAINING_PLOT_TF_FILE = "training_plot_tf.png" +TRAINING_PLOT_JOINT_FILE = "training_plot_joint.png" + +# PCA info directory and files +PCA_INFO_DIR = "PCA_Particle_Info" +PCA_MEAN_FILE = "mean.particles" +PCA_MODE_FILE_TEMPLATE = "pcamode{}.particles" + +# Prediction directories +WORLD_PREDICTIONS_DIR = "world_predictions" +PCA_PREDICTIONS_DIR = "pca_predictions" +LOCAL_PREDICTIONS_DIR = "local_predictions" + +# Examples directory +EXAMPLES_DIR = "examples" +TRAIN_EXAMPLES_PREFIX = "train_" +VALIDATION_EXAMPLES_PREFIX = "validation_" + +# Training stage names (for logging) +class TrainingStage: + BASE = "Base_Training" + FINE_TUNING = "Fine_Tuning" + AUTOENCODER = "AE" + T_FLANK = "T-Flank" + JOINT = "Joint" + +# Default values +class Defaults: + BATCH_SIZE = 1 + DOWN_FACTOR = 1 + TRAIN_SPLIT = 0.80 + NUM_WORKERS = 0 + VAL_FREQ = 1 + +# Device strings +DEVICE_CUDA = "cuda:0" +DEVICE_CPU = "cpu" diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py index 5f7fe30e36..90b4fbd0c1 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py @@ -6,6 +6,7 @@ import torch from torch.utils.data import DataLoader from DeepSSMUtils import model, loaders +from DeepSSMUtils import constants as C from shapeworks.utils import sw_message from shapeworks.utils import sw_progress from shapeworks.utils import sw_check_abort @@ -24,9 +25,9 @@ def test(config_file, loader="test"): pred_dir = model_dir + loader + '_predictions/' loaders.make_dir(pred_dir) if parameters["use_best_model"]: - model_path = model_dir + 'best_model.torch' + model_path = model_dir + C.BEST_MODEL_FILE else: - model_path = model_dir + 'final_model.torch' + model_path = model_dir + C.FINAL_MODEL_FILE if parameters["fine_tune"]["enabled"]: model_path_ft = model_path.replace(".torch", "_ft.torch") else: @@ -67,9 +68,9 @@ def test(config_file, loader="test"): index = 0 pred_scores = [] - pred_path = pred_dir + 'world_predictions/' + pred_path = pred_dir + C.WORLD_PREDICTIONS_DIR + '/' loaders.make_dir(pred_path) - pred_path_pca = pred_dir + 'pca_predictions/' + pred_path_pca = pred_dir + C.PCA_PREDICTIONS_DIR + '/' loaders.make_dir(pred_path_pca) predicted_particle_files = [] diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py index 5573b7e9db..b3a1ca8ff7 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader import shapeworks as sw from shapeworks.utils import sw_message +from 
DeepSSMUtils import constants as C random.seed(1) ######################## Data loading functions #################################### @@ -44,7 +45,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - train_path = loader_dir + 'train' + train_path = loader_dir + C.TRAIN_LOADER torch.save(trainloader, train_path) validationloader = DataLoader( @@ -54,7 +55,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - val_path = loader_dir + 'validation' + val_path = loader_dir + C.VALIDATION_LOADER torch.save(validationloader, val_path) sw_message("Training and validation loaders complete.\n") return train_path, val_path @@ -77,7 +78,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - train_path = loader_dir + 'train' + train_path = loader_dir + C.TRAIN_LOADER torch.save(trainloader, train_path) sw_message("Training loader complete.") return train_path @@ -102,10 +103,10 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 mdl = get_particles(val_particles[index]) models.append(mdl) # Write test names to file so they are saved somewhere - name_file = open(loader_dir + 'validation_names.txt', 'w+') + name_file = open(loader_dir + C.VALIDATION_NAMES_FILE, 'w+') name_file.write(str(names)) name_file.close() - sw_message("Validation names saved to: " + loader_dir + "validation_names.txt") + sw_message("Validation names saved to: " + loader_dir + C.VALIDATION_NAMES_FILE) images = get_images(loader_dir, image_paths, down_factor, down_dir) val_data = DeepSSMdataset(images, scores, models, names) # Make loader @@ -116,7 +117,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - val_path = loader_dir + 'validation' + val_path = loader_dir + C.VALIDATION_LOADER torch.save(val_loader, val_path) sw_message("Validation loader complete.") return val_path @@ -143,10 +144,10 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num images = get_images(loader_dir, image_paths, down_factor, down_dir) test_data = DeepSSMdataset(images, scores, models, test_names) # Write test names to file so they are saved somewhere - name_file = open(loader_dir + 'test_names.txt', 'w+') + name_file = open(loader_dir + C.TEST_NAMES_FILE, 'w+') name_file.write(str(test_names)) name_file.close() - sw_message("Test names saved to: " + loader_dir + "test_names.txt") + sw_message("Test names saved to: " + loader_dir + C.TEST_NAMES_FILE) # Make loader testloader = DataLoader( test_data, @@ -155,7 +156,7 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - test_path = loader_dir + 'test' + test_path = loader_dir + C.TEST_LOADER torch.save(testloader, test_path) sw_message("Test loader complete.") return test_path, test_names @@ -268,8 +269,8 @@ def get_images(loader_dir, image_list, down_factor, down_dir): all_images = np.array(all_images) # get mean and std - mean_path = loader_dir + 'mean_img.npy' - std_path = loader_dir + 'std_img.npy' + mean_path = loader_dir + C.MEAN_IMG_FILE + std_path = loader_dir + C.STD_IMG_FILE mean_image = np.mean(all_images) std_image = np.std(all_images) 
np.save(mean_path, mean_image) @@ -305,8 +306,8 @@ def whiten_PCA_scores(scores, loader_dir): scores = np.array(scores) mean_score = np.mean(scores, 0) std_score = np.std(scores, 0) - np.save(loader_dir + 'mean_PCA.npy', mean_score) - np.save(loader_dir + 'std_PCA.npy', std_score) + np.save(loader_dir + C.MEAN_PCA_FILE, mean_score) + np.save(loader_dir + C.STD_PCA_FILE, std_score) norm_scores = [] for score in scores: norm_scores.append((score-mean_score)/std_score) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py index f512f2e244..a0ffb81d4b 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py @@ -5,6 +5,7 @@ import numpy as np from collections import OrderedDict from DeepSSMUtils import net_utils +from DeepSSMUtils import constants as C class ConvolutionalBackbone(nn.Module): @@ -61,9 +62,9 @@ class DeterministicEncoder(nn.Module): def __init__(self, num_latent, img_dims, loader_dir): super(DeterministicEncoder, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device self.num_latent = num_latent self.img_dims = img_dims @@ -97,15 +98,15 @@ class DeepSSMNet(nn.Module): def __init__(self, config_file): super(DeepSSMNet, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device - with open(config_file) as json_file: + with open(config_file) as json_file: parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + "validation", weights_only=False) + loader = torch.load(self.loader_dir + C.VALIDATION_LOADER, weights_only=False) self.num_corr = loader.dataset.mdl_target[0].shape[0] img_dims = loader.dataset.img[0].shape self.img_dims = img_dims[1:] @@ -169,15 +170,15 @@ class DeepSSMNet_TLNet(nn.Module): def __init__(self, conflict_file): super(DeepSSMNet_TLNet, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device - with open(conflict_file) as json_file: + with open(conflict_file) as json_file: parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + "validation") + loader = torch.load(self.loader_dir + C.VALIDATION_LOADER) self.num_corr = loader.dataset.mdl_target[0].shape[0] img_dims = loader.dataset.img[0].shape self.img_dims = img_dims[1:] diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py index 3ffa0a9014..792bc6ab85 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py @@ -1,6 +1,23 @@ +import random import torch from torch import nn import numpy as np +from DeepSSMUtils import constants as C + + +def set_seed(seed=42): + """ + Set random seeds for reproducibility across all random number generators. 
+ """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + class Flatten(nn.Module): def forward(self, x): @@ -14,8 +31,8 @@ def poolOutDim(inDim, kernel_size, padding=0, stride=0, dilation=1): return outDim def unwhiten_PCA_scores(torch_loading, loader_dir, device): - mean_score = torch.from_numpy(np.load(loader_dir + '/mean_PCA.npy')).to(device).float() - std_score = torch.from_numpy(np.load(loader_dir + '/std_PCA.npy')).to(device).float() + mean_score = torch.from_numpy(np.load(loader_dir + '/' + C.MEAN_PCA_FILE)).to(device).float() + std_score = torch.from_numpy(np.load(loader_dir + '/' + C.STD_PCA_FILE)).to(device).float() mean_score = mean_score.unsqueeze(0).repeat(torch_loading.shape[0], 1) std_score = std_score.unsqueeze(0).repeat(torch_loading.shape[0], 1) pca_new = torch_loading*(std_score) + mean_score diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index f73e26fb34..c80a776244 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -13,6 +13,8 @@ from DeepSSMUtils import losses from DeepSSMUtils import train_viz from DeepSSMUtils import loaders +from DeepSSMUtils import net_utils +from DeepSSMUtils import constants as C import DeepSSMUtils from shapeworks.utils import * @@ -68,6 +70,7 @@ def set_scheduler(opt, sched_params): def train(project, config_file): + net_utils.set_seed(42) sw.utils.initialize_project_mesh_warper(project) with open(config_file) as json_file: @@ -101,8 +104,8 @@ def supervised_train(config_file): fine_tune = parameters['fine_tune']['enabled'] loss_func = method_to_call = getattr(losses, parameters["loss"]["function"]) # load the loaders - train_loader_path = loader_dir + "train" - validation_loader_path = loader_dir + "validation" + train_loader_path = loader_dir + C.TRAIN_LOADER + validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") train_loader = torch.load(train_loader_path, weights_only=False) val_loader = torch.load(validation_loader_path, weights_only=False) @@ -119,8 +122,8 @@ def supervised_train(config_file): net.apply(weight_init(module=nn.Linear, initf=nn.init.xavier_normal_)) # these lines are for the fine tuning layer initialization - whiten_mean = np.load(loader_dir + '/mean_PCA.npy') - whiten_std = np.load(loader_dir + '/std_PCA.npy') + whiten_mean = np.load(loader_dir + '/' + C.MEAN_PCA_FILE) + whiten_std = np.load(loader_dir + '/' + C.STD_PCA_FILE) orig_mean = np.loadtxt(aug_dir + '/PCA_Particle_Info/mean.particles') orig_pc = np.zeros([num_pca, num_corr * 3]) for i in range(num_pca): @@ -146,7 +149,7 @@ def supervised_train(config_file): # train print("Beginning training on device = " + device + '\n') # Initialize logger - logger = open(model_dir + "train_log.csv", "w+", buffering=1) + logger = open(model_dir + C.TRAIN_LOG_FILE, "w+", buffering=1) log_print(logger, ["Training_Stage", "Epoch", "LR", "Train_Err", "Train_Rel_Err", "Val_Err", "Val_Rel_Err", "Sec"]) # Initialize training plot train_plot = plt.figure() @@ -158,7 +161,7 @@ def supervised_train(config_file): axe.set_xlim(0, num_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_FILE, dpi=300) # 
initialize epochs = [] plot_train_losses = [] @@ -241,17 +244,17 @@ def supervised_train(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_FILE) # save if val_rel_err < best_val_rel_error: best_val_rel_error = val_rel_err best_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FILE)) t0 = time.time() if decay_lr: scheduler.step() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_FILE)) parameters['best_model_epochs'] = best_epoch with open(config_file, "w") as json_file: json.dump(parameters, json_file, indent=2) @@ -290,7 +293,7 @@ def supervised_train(config_file): axe.set_xlim(0, ft_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_ft.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_FT_FILE, dpi=300) epochs = [] plot_train_losses = [] plot_val_losses = [] @@ -355,7 +358,7 @@ def supervised_train(config_file): if val_rel_loss < best_ft_val_rel_error: best_ft_val_rel_error = val_rel_loss best_ft_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model_ft.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FT_FILE)) pred_particles.extend(pred_mdl.detach().cpu().numpy()) true_particles.extend(mdl.detach().cpu().numpy()) train_viz.write_examples(np.array(pred_particles), np.array(true_particles), val_names, @@ -376,12 +379,12 @@ def supervised_train(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_ft.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_FT_FILE) t0 = time.time() logger.close() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_ft.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_FT_FILE)) parameters['best_ft_model_epochs'] = best_ft_epoch with open(config_file, "w") as json_file: @@ -411,8 +414,8 @@ def supervised_train_tl(config_file): a_lat = parameters["tl_net"]["a_lat"] c_lat = parameters["tl_net"]["c_lat"] # load the loaders - train_loader_path = loader_dir + "train" - validation_loader_path = loader_dir + "validation" + train_loader_path = loader_dir + C.TRAIN_LOADER + validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") train_loader = torch.load(train_loader_path) val_loader = torch.load(validation_loader_path) @@ -447,7 +450,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, ae_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_ae.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_AE_FILE, dpi=300) # initialize epochs = [] plot_train_losses = [] @@ -540,10 +543,10 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_ae.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_AE_FILE) t0 = time.time() # save - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_ae.torch')) + 
torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_AE_FILE)) # fix the autoencoder and train the TL-net for param in net.CorrespondenceDecoder.parameters(): param.requires_grad = False @@ -563,7 +566,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, tf_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_tf.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_TF_FILE, dpi=300) # initialize t0 = time.time() epochs = [] @@ -650,10 +653,10 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_tf.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_TF_FILE) t0 = time.time() # save - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_tf.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_TF_FILE)) # jointly train the model joint_epochs = parameters['tl_net']['joint_epochs'] alpha = parameters['tl_net']['alpha'] @@ -673,7 +676,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, joint_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_joint.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_JOINT_FILE, dpi=300) # initialize epochs = [] plot_train_losses = [] @@ -771,19 +774,19 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_joint.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_JOINT_FILE) # save val_rel_err = val_rel_ae_err + alpha * val_rel_tf_err if val_rel_err < best_val_rel_error: best_val_rel_error = val_rel_err best_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FILE)) t0 = time.time() if decay_lr: scheduler.step() logger.close() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_FILE)) parameters['best_model_epochs'] = best_epoch with open(config_file, "w") as json_file: json.dump(parameters, json_file, indent=2) From 6ad0111f2a8815089fd792f8d0ff1965971ab104 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 18:18:54 -0700 Subject: [PATCH 02/47] Add type hints to DeepSSM public API functions --- .../DeepSSMUtils/__init__.py | 134 ++++++++++++++++-- .../DeepSSMUtils/net_utils.py | 49 ++++++- 2 files changed, 166 insertions(+), 17 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index 738b914861..561b798162 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -1,3 +1,5 @@ +from typing import List, Optional, Tuple, Any + from DeepSSMUtils import trainer from DeepSSMUtils import loaders from DeepSSMUtils import eval @@ -20,65 +22,171 @@ import torch -def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): +def getTrainValLoaders( + loader_dir: str, + aug_data_csv: str, + batch_size: int = 1, + down_factor: float = 1, + down_dir: Optional[str] = None, + train_split: float = 0.80, + num_workers: int = 0 +) -> 
None: + """Create training and validation data loaders from augmented data CSV.""" testPytorch() loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers) -def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): +def getTrainLoader( + loader_dir: str, + aug_data_csv: str, + batch_size: int = 1, + down_factor: float = 1, + down_dir: Optional[str] = None, + train_split: float = 0.80, + num_workers: int = 0 +) -> None: + """Create training data loader from augmented data CSV.""" testPytorch() loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers) -def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0): +def getValidationLoader( + loader_dir: str, + val_img_list: List[str], + val_particles: List[str], + down_factor: float = 1, + down_dir: Optional[str] = None, + num_workers: int = 0 +) -> None: + """Create validation data loader from image and particle lists.""" loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir, num_workers) -def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0): +def getTestLoader( + loader_dir: str, + test_img_list: List[str], + down_factor: float = 1, + down_dir: Optional[str] = None, + num_workers: int = 0 +) -> None: + """Create test data loader from image list.""" loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir, num_workers) -def prepareConfigFile(config_filename, model_name, embedded_dim, out_dir, loader_dir, aug_dir, epochs, learning_rate, - decay_lr, fine_tune, fine_tune_epochs, fine_tune_learning_rate): +def prepareConfigFile( + config_filename: str, + model_name: str, + embedded_dim: int, + out_dir: str, + loader_dir: str, + aug_dir: str, + epochs: int, + learning_rate: float, + decay_lr: bool, + fine_tune: bool, + fine_tune_epochs: int, + fine_tune_learning_rate: float +) -> None: + """Prepare a DeepSSM configuration file with the specified parameters.""" config_file.prepare_config_file(config_filename, model_name, embedded_dim, out_dir, loader_dir, aug_dir, epochs, learning_rate, decay_lr, fine_tune, fine_tune_epochs, fine_tune_learning_rate) -def trainDeepSSM(project, config_file): +def trainDeepSSM(project: Any, config_file: str) -> None: + """Train a DeepSSM model using the given project and configuration file.""" testPytorch() trainer.train(project, config_file) return -def testDeepSSM(config_file, loader="test"): +def testDeepSSM(config_file: str, loader: str = "test") -> List[str]: + """ + Test a trained DeepSSM model and return predicted particle files. + + Args: + config_file: Path to the configuration JSON file + loader: Which loader to use ("test" or "validation") + + Returns: + List of paths to predicted particle files + """ predicted_particle_files = eval.test(config_file, loader) return predicted_particle_files -def analyzeMSE(predicted_particles, true_particles): +def analyzeMSE( + predicted_particles: List[str], + true_particles: List[str] +) -> Tuple[float, float]: + """ + Analyze mean squared error between predicted and true particles. 
+ + Returns: + Tuple of (mean_MSE, std_MSE) + """ mean_MSE, STD_MSE = eval_utils.get_MSE(predicted_particles, true_particles) return mean_MSE, STD_MSE -def analyzeMeshDistance(predicted_particles, mesh_files, template_particles, template_mesh, out_dir, planes=None): +def analyzeMeshDistance( + predicted_particles: List[str], + mesh_files: List[str], + template_particles: str, + template_mesh: str, + out_dir: str, + planes: Optional[Any] = None +) -> float: + """ + Analyze mesh distance between predicted particles and ground truth meshes. + + Returns: + Mean surface-to-surface distance + """ mean_distance = eval_utils.get_mesh_distance(predicted_particles, mesh_files, template_particles, template_mesh, out_dir, planes) return mean_distance -def analyzeResults(out_dir, DT_dir, prediction_dir, mean_prefix): +def analyzeResults( + out_dir: str, + DT_dir: str, + prediction_dir: str, + mean_prefix: str +) -> float: + """ + Analyze results by computing distance between predicted and ground truth meshes. + + Returns: + Average surface distance + """ avg_distance = eval_utils.get_distance_meshes(out_dir, DT_dir, prediction_dir, mean_prefix) return avg_distance -def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid'): +def get_image_registration_transform( + fixed_image_file: str, + moving_image_file: str, + transform_type: str = 'rigid' +) -> Any: + """ + Compute image registration transform between two images. + + Args: + fixed_image_file: Path to the fixed/reference image + moving_image_file: Path to the moving image to be registered + transform_type: Type of transform ('rigid', 'affine', etc.) + + Returns: + ITK transform object + """ itk_transform = image_utils.get_image_registration_transform(fixed_image_file, moving_image_file, transform_type=transform_type) return itk_transform -def testPytorch(): +def testPytorch() -> None: + """Check if PyTorch is using GPU and print a warning if not.""" if torch.cuda.is_available(): print("Running on GPU.") else: diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py index 792bc6ab85..447d60bdb9 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py @@ -5,9 +5,12 @@ from DeepSSMUtils import constants as C -def set_seed(seed=42): +def set_seed(seed: int = 42) -> None: """ Set random seeds for reproducibility across all random number generators. + + Args: + seed: Integer seed value for random number generators """ random.seed(seed) np.random.seed(seed) @@ -20,17 +23,55 @@ def set_seed(seed=42): class Flatten(nn.Module): - def forward(self, x): + """Flatten layer to reshape tensor for fully connected layers.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.view(x.size(0), -1) -def poolOutDim(inDim, kernel_size, padding=0, stride=0, dilation=1): + +def poolOutDim( + inDim: int, + kernel_size: int, + padding: int = 0, + stride: int = 0, + dilation: int = 1 +) -> int: + """ + Calculate output dimension after pooling operation. 
+ + Args: + inDim: Input dimension size + kernel_size: Size of the pooling kernel + padding: Padding applied to input + stride: Stride of pooling (defaults to kernel_size if 0) + dilation: Dilation factor + + Returns: + Output dimension size after pooling + """ if stride == 0: stride = kernel_size num = inDim + 2*padding - dilation*(kernel_size - 1) - 1 outDim = int(np.floor(num/stride + 1)) return outDim -def unwhiten_PCA_scores(torch_loading, loader_dir, device): + +def unwhiten_PCA_scores( + torch_loading: torch.Tensor, + loader_dir: str, + device: str +) -> torch.Tensor: + """ + Unwhiten (denormalize) PCA scores using saved mean and std. + + Args: + torch_loading: Whitened PCA scores tensor + loader_dir: Directory containing mean_PCA.npy and std_PCA.npy + device: Device to load tensors to ('cuda:0' or 'cpu') + + Returns: + Unwhitened PCA scores tensor + """ mean_score = torch.from_numpy(np.load(loader_dir + '/' + C.MEAN_PCA_FILE)).to(device).float() std_score = torch.from_numpy(np.load(loader_dir + '/' + C.STD_PCA_FILE)).to(device).float() mean_score = mean_score.unsqueeze(0).repeat(torch_loading.shape[0], 1) From 4008a1eaa2baa107b90d2eac59e47065b20e670a Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 18:29:54 -0700 Subject: [PATCH 03/47] Add config schema validation for DeepSSM --- .../DeepSSMUtils/__init__.py | 1 + .../DeepSSMUtils/config_validation.py | 205 ++++++++++++++++++ .../DeepSSMUtils/trainer.py | 5 +- 3 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index 561b798162..eb51ab3904 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -10,6 +10,7 @@ from DeepSSMUtils import run_utils from DeepSSMUtils import net_utils from DeepSSMUtils import constants +from DeepSSMUtils import config_validation from .net_utils import set_seed diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py new file mode 100644 index 0000000000..0725711558 --- /dev/null +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py @@ -0,0 +1,205 @@ +""" +Configuration file validation for DeepSSM. + +This module provides validation for DeepSSM config files to catch +errors early with clear error messages. 
+""" +import os +import json +from typing import Any, Dict, List, Optional + + +class ConfigValidationError(Exception): + """Raised when config validation fails.""" + pass + + +# Schema definition for DeepSSM config +CONFIG_SCHEMA = { + "model_name": {"type": str, "required": True}, + "num_latent_dim": {"type": int, "required": True, "min": 1}, + "paths": { + "type": dict, + "required": True, + "children": { + "out_dir": {"type": str, "required": True}, + "loader_dir": {"type": str, "required": True}, + "aug_dir": {"type": str, "required": True}, + } + }, + "encoder": { + "type": dict, + "required": True, + "children": { + "deterministic": {"type": bool, "required": True}, + } + }, + "decoder": { + "type": dict, + "required": True, + "children": { + "deterministic": {"type": bool, "required": True}, + "linear": {"type": bool, "required": True}, + } + }, + "loss": { + "type": dict, + "required": True, + "children": { + "function": {"type": str, "required": True, "choices": ["MSE", "Focal"]}, + "supervised_latent": {"type": bool, "required": True}, + } + }, + "trainer": { + "type": dict, + "required": True, + "children": { + "epochs": {"type": int, "required": True, "min": 1}, + "learning_rate": {"type": (int, float), "required": True, "min": 0}, + "val_freq": {"type": int, "required": True, "min": 1}, + "decay_lr": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "type": {"type": str, "required": False, "choices": ["Step", "CosineAnnealing"]}, + "parameters": {"type": dict, "required": False}, + } + }, + } + }, + "fine_tune": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "loss": {"type": str, "required": False, "choices": ["MSE", "Focal"]}, + "epochs": {"type": int, "required": False, "min": 1}, + "learning_rate": {"type": (int, float), "required": False, "min": 0}, + "val_freq": {"type": int, "required": False, "min": 1}, + } + }, + "use_best_model": {"type": bool, "required": True}, + "tl_net": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "ae_epochs": {"type": int, "required": False, "min": 1}, + "tf_epochs": {"type": int, "required": False, "min": 1}, + "joint_epochs": {"type": int, "required": False, "min": 1}, + "alpha": {"type": (int, float), "required": False}, + "a_ae": {"type": (int, float), "required": False}, + "c_ae": {"type": (int, float), "required": False}, + "a_lat": {"type": (int, float), "required": False}, + "c_lat": {"type": (int, float), "required": False}, + } + }, +} + + +def validate_config(config_path: str) -> Dict[str, Any]: + """ + Validate a DeepSSM configuration file. 
+ + Args: + config_path: Path to the JSON configuration file + + Returns: + Validated configuration dictionary + + Raises: + ConfigValidationError: If validation fails + FileNotFoundError: If config file doesn't exist + json.JSONDecodeError: If config file is not valid JSON + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + + with open(config_path) as f: + try: + config = json.load(f) + except json.JSONDecodeError as e: + raise ConfigValidationError(f"Invalid JSON in config file: {e}") + + errors = _validate_dict(config, CONFIG_SCHEMA, "config") + + if errors: + error_msg = "Config validation failed:\n" + "\n".join(f" - {e}" for e in errors) + raise ConfigValidationError(error_msg) + + return config + + +def _validate_dict( + data: Dict[str, Any], + schema: Dict[str, Any], + path: str +) -> List[str]: + """ + Recursively validate a dictionary against a schema. + + Returns: + List of error messages (empty if valid) + """ + errors = [] + + for key, rules in schema.items(): + full_path = f"{path}.{key}" + value = data.get(key) + + # Check required fields + if rules.get("required", False) and key not in data: + errors.append(f"Missing required field: {full_path}") + continue + + if key not in data: + continue + + # Check type + expected_type = rules.get("type") + if expected_type and not isinstance(value, expected_type): + type_name = expected_type.__name__ if isinstance(expected_type, type) else str(expected_type) + errors.append(f"Invalid type for {full_path}: expected {type_name}, got {type(value).__name__}") + continue + + # Check min value + if "min" in rules and isinstance(value, (int, float)): + if value < rules["min"]: + errors.append(f"Value too small for {full_path}: {value} < {rules['min']}") + + # Check choices + if "choices" in rules and value not in rules["choices"]: + errors.append(f"Invalid value for {full_path}: '{value}' not in {rules['choices']}") + + # Recurse into nested dicts + if expected_type == dict and "children" in rules: + errors.extend(_validate_dict(value, rules["children"], full_path)) + + return errors + + +def validate_paths_exist(config: Dict[str, Any], check_loader_dir: bool = True) -> List[str]: + """ + Validate that required paths in config exist. 
+ + Args: + config: Configuration dictionary + check_loader_dir: Whether to check if loader_dir exists + + Returns: + List of warning messages for missing paths + """ + warnings = [] + paths = config.get("paths", {}) + + if check_loader_dir: + loader_dir = paths.get("loader_dir", "") + if loader_dir and not os.path.exists(loader_dir): + warnings.append(f"Loader directory does not exist: {loader_dir}") + + aug_dir = paths.get("aug_dir", "") + if aug_dir and not os.path.exists(aug_dir): + warnings.append(f"Augmentation directory does not exist: {aug_dir}") + + return warnings diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index c80a776244..869e37b8a8 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -15,6 +15,7 @@ from DeepSSMUtils import loaders from DeepSSMUtils import net_utils from DeepSSMUtils import constants as C +from DeepSSMUtils import config_validation import DeepSSMUtils from shapeworks.utils import * @@ -73,8 +74,8 @@ def train(project, config_file): net_utils.set_seed(42) sw.utils.initialize_project_mesh_warper(project) - with open(config_file) as json_file: - parameters = json.load(json_file) + # Validate config file before training + parameters = config_validation.validate_config(config_file) if parameters["tl_net"]["enabled"]: supervised_train_tl(config_file) else: From 6fa993359f34c71563ca83accd85e4776d12ef5e Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 18:41:12 -0700 Subject: [PATCH 04/47] Improve error handling in DeepSSM data loaders --- .../DeepSSMUtils/loaders.py | 114 ++++++++++++------ 1 file changed, 79 insertions(+), 35 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py index b3a1ca8ff7..7a6661e064 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py @@ -12,6 +12,11 @@ from DeepSSMUtils import constants as C random.seed(1) + +class DataLoadingError(Exception): + """Raised when data loading fails.""" + pass + ######################## Data loading functions #################################### ''' @@ -88,6 +93,14 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir ''' def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0): sw_message("Creating validation torch loader:") + if not val_img_list: + raise DataLoadingError("Validation image list is empty") + if not val_particles: + raise DataLoadingError("Validation particle list is empty") + if len(val_img_list) != len(val_particles): + raise DataLoadingError( + f"Mismatched validation data: {len(val_img_list)} images but {len(val_particles)} particle files" + ) # Get data image_paths = [] scores = [] @@ -127,6 +140,8 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 ''' def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0): sw_message("Creating test torch loader...") + if not test_img_list: + raise DataLoadingError("Test image list is empty") # get data image_paths = [] scores = [] @@ -167,37 +182,47 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num returns images, scores, models, prefixes from CSV ''' def get_all_train_data(loader_dir, data_csv, down_factor, down_dir): + if not os.path.exists(data_csv): + raise 
DataLoadingError(f"CSV file not found: {data_csv}") # get all data and targets image_paths = [] scores = [] models = [] prefixes = [] - with open(data_csv, newline='') as csvfile: - datareader = csv.reader(csvfile) - index = 0 - for row in datareader: - image_path = row[0] - model_path = row[1] - pca_scores = row[2:] - # add name - prefix = get_prefix(image_path) - # data error check - # if prefix not in get_prefix(model_path): - # print("Error: Images and particles are mismatched in csv.") - # print(f"index: {index}") - # print(f"prefix: {prefix}") - # print(f"get_prefix(model_path): {get_prefix(model_path)}}") - # exit() - prefixes.append(prefix) - # add image path - image_paths.append(image_path) - # add score (un-normalized) - pca_scores = [float(i) for i in pca_scores] - scores.append(pca_scores) - # add model - mdl = get_particles(model_path) - models.append(mdl) - index += 1 + try: + with open(data_csv, newline='') as csvfile: + datareader = csv.reader(csvfile) + for row_num, row in enumerate(datareader, 1): + if len(row) < 3: + raise DataLoadingError( + f"Invalid row {row_num} in {data_csv}: expected at least 3 columns " + f"(image_path, model_path, pca_scores), got {len(row)}" + ) + image_path = row[0] + model_path = row[1] + pca_scores = row[2:] + # add name + prefix = get_prefix(image_path) + prefixes.append(prefix) + # add image path + image_paths.append(image_path) + # add score (un-normalized) + try: + pca_scores = [float(i) for i in pca_scores] + except ValueError as e: + raise DataLoadingError( + f"Invalid PCA scores in {data_csv} at row {row_num}: {e}" + ) + scores.append(pca_scores) + # add model + mdl = get_particles(model_path) + models.append(mdl) + except csv.Error as e: + raise DataLoadingError(f"Error parsing CSV file {data_csv}: {e}") + + if not image_paths: + raise DataLoadingError(f"CSV file is empty: {data_csv}") + images = get_images(loader_dir, image_paths, down_factor, down_dir) scores = whiten_PCA_scores(scores, loader_dir) return images, scores, models, prefixes @@ -241,18 +266,32 @@ def get_prefix(path): get list from .particles format ''' def get_particles(model_path): - f = open(model_path, "r") - data = [] - for line in f.readlines(): - points = line.split() - points = [float(i) for i in points] - data.append(points) - return(data) + if not os.path.exists(model_path): + raise DataLoadingError(f"Particle file not found: {model_path}") + try: + with open(model_path, "r") as f: + data = [] + for line_num, line in enumerate(f.readlines(), 1): + points = line.split() + try: + points = [float(i) for i in points] + except ValueError as e: + raise DataLoadingError( + f"Invalid particle data in {model_path} at line {line_num}: {e}" + ) + data.append(points) + if not data: + raise DataLoadingError(f"Particle file is empty: {model_path}") + return data + except IOError as e: + raise DataLoadingError(f"Error reading particle file {model_path}: {e}") ''' reads .nrrd files and returns whitened data ''' def get_images(loader_dir, image_list, down_factor, down_dir): + if not image_list: + raise DataLoadingError("Image list is empty") # get all images all_images = [] for image_path in image_list: @@ -263,8 +302,13 @@ def get_images(loader_dir, image_list, down_factor, down_dir): if not os.path.exists(res_img): apply_down_sample(image_path, res_img, down_factor) image_path = res_img - # for_viewing returns 'F' order, i.e., transpose, needed for this array - img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + if not os.path.exists(image_path): + raise 
DataLoadingError(f"Image file not found: {image_path}") + try: + # for_viewing returns 'F' order, i.e., transpose, needed for this array + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") all_images.append(img) all_images = np.array(all_images) From bce52ebe9dee2258ca08d77f742b16d2cd942789 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 23:07:20 -0700 Subject: [PATCH 05/47] Add error handling to loaders and --exact_check option * Add DataLoadingError exception with descriptive messages including file paths and line numbers for debugging * Validate inputs in get_particles, get_images, get_all_train_data, get_validation_loader, and get_test_loader * Add --exact_check flag with save/verify modes for platform-specific refactoring verification * Return mean_distance from process_test_predictions for exact checking --- Examples/Python/RunUseCase.py | 2 + Examples/Python/deep_ssm.py | 37 +++++++++++++++---- .../DeepSSMUtils/run_utils.py | 5 ++- 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 9394235a8f..24965ea137 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -69,6 +69,8 @@ parser.add_argument("--tiny_test", help="Run as a short test", action="store_true") parser.add_argument("--verify", help="Run as a full test", action="store_true") parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true") + parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification (platform-specific)", + choices=["save", "verify"]) args = parser.parse_args() type = "" diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index a1fa04b330..eeaa74f2c3 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -529,10 +529,10 @@ def Run_Pipeline(args): planes=test_planes) print("Test mean mesh surface-to-surface distance: " + str(mean_dist)) - DeepSSMUtils.process_test_predictions(project, config_file) - + final_mean_dist = DeepSSMUtils.process_test_predictions(project, config_file) + # If tiny test or verify, check results and exit - check_results(args, mean_dist) + check_results(args, final_mean_dist, output_directory) open(status_dir + "step_12.txt", 'w').close() @@ -540,12 +540,35 @@ def Run_Pipeline(args): # Verification -def check_results(args, mean_dist): +def check_results(args, mean_dist, output_directory): if args.tiny_test: print("\nVerifying use case results.") - if not math.isclose(mean_dist, 10, rel_tol=1): - print("Test failed.") - exit(-1) + + exact_check_file = output_directory + "exact_check_value.txt" + + # Exact check for refactoring verification (platform-specific) + if args.exact_check == "save": + with open(exact_check_file, "w") as f: + f.write(str(mean_dist)) + print(f"Saved exact check value to: {exact_check_file}") + print(f"Value: {mean_dist}") + elif args.exact_check == "verify": + if not os.path.exists(exact_check_file): + print(f"Error: No saved value found at {exact_check_file}") + print("Run with --exact_check save first to create baseline.") + exit(-1) + with open(exact_check_file, "r") as f: + expected_mean_dist = float(f.read().strip()) + if mean_dist != expected_mean_dist: + print(f"Exact check failed: expected {expected_mean_dist}, got {mean_dist}") + exit(-1) + print(f"Exact check passed: {mean_dist}") + else: + # Relaxed check 
for CI/cross-platform + if not math.isclose(mean_dist, 10, rel_tol=1): + print("Test failed.") + exit(-1) + print("Done with test, verification succeeded.") exit(0) else: diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index 7de1bd1e2c..723795d882 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -538,7 +538,8 @@ def process_test_predictions(project, config_file): template_particles, template_mesh, pred_dir) print("Distances: ", distances) - print("Mean distance: ", np.mean(distances)) + mean_distance = np.mean(distances) + print("Mean distance: ", mean_distance) # write to csv file in deepssm_dir csv_file = f"{deepssm_dir}/test_distances.csv" @@ -561,3 +562,5 @@ def process_test_predictions(project, config_file): mesh = sw.Mesh(local_mesh_file) mesh.applyTransform(transform) mesh.write(world_mesh_file) + + return mean_distance From 1290941264fe3e798198b7bd4f6b02676d803255 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 7 Jan 2026 01:16:27 -0700 Subject: [PATCH 06/47] Add --tl_net flag and fix TL-DeepSSM bugs - Add --tl_net flag to enable TL-DeepSSM network testing - Fix PyTorch 2.6 compatibility: add weights_only=False to torch.load calls in trainer.py and model.py for DataLoader loading - Fix eval.py returning wrong file path for tl_net mode - Fix deep_ssm.py path handling for local predictions directory --- Examples/Python/RunUseCase.py | 1 + Examples/Python/deep_ssm.py | 26 ++++++++++++------- .../DeepSSMUtilsPackage/DeepSSMUtils/eval.py | 10 +++---- .../DeepSSMUtilsPackage/DeepSSMUtils/model.py | 2 +- .../DeepSSMUtils/trainer.py | 4 +-- 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 24965ea137..98a30edfab 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -71,6 +71,7 @@ parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true") parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification (platform-specific)", choices=["save", "verify"]) + parser.add_argument("--tl_net", help="Enable TL-DeepSSM network (deep_ssm use case only)", action="store_true") args = parser.parse_args() type = "" diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index eeaa74f2c3..9b618057c8 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -385,7 +385,7 @@ def Run_Pipeline(args): }, "use_best_model": True, "tl_net": { - "enabled": False, + "enabled": args.tl_net, "ae_epochs": 100, "tf_epochs": 100, "joint_epochs": 25, @@ -398,6 +398,10 @@ def Run_Pipeline(args): } if args.tiny_test: model_parameters["trainer"]["epochs"] = 1 + if args.tl_net: + model_parameters["tl_net"]["ae_epochs"] = 1 + model_parameters["tl_net"]["tf_epochs"] = 1 + model_parameters["tl_net"]["joint_epochs"] = 1 # Save config file with open(config_file, "w") as outfile: json.dump(model_parameters, outfile, indent=2) @@ -436,17 +440,17 @@ def Run_Pipeline(args): val_world_particles.append(project_path + subjects[index].get_world_particle_filenames()[0]) val_mesh_files.append(project_path + subjects[index].get_groomed_filenames()[0]) - val_out_dir = output_directory + model_name + '/validation_predictions/' predicted_val_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='validation') print("Validation world predictions saved.") - # 
Generate local predictions - local_val_prediction_dir = val_out_dir + 'local_predictions/' + # Generate local predictions - create directory next to world_predictions + world_pred_dir = os.path.dirname(predicted_val_world_particles[0]) + local_val_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions") if not os.path.exists(local_val_prediction_dir): os.makedirs(local_val_prediction_dir) predicted_val_local_particles = [] for particle_file, transform in zip(predicted_val_world_particles, val_transforms): particles = np.loadtxt(particle_file) - local_particle_file = particle_file.replace("world_predictions/", "local_predictions/") + local_particle_file = particle_file.replace("world_predictions", "local_predictions") local_particles = sw.utils.transformParticles(particles, transform, inverse=True) np.savetxt(local_particle_file, local_particles) predicted_val_local_particles.append(local_particle_file) @@ -462,6 +466,8 @@ def Run_Pipeline(args): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] # Get distance between clipped true and predicted meshes + # Get the validation output directory from the predictions path + val_out_dir = os.path.dirname(local_val_prediction_dir.rstrip('/')) + '/' mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_val_local_particles, val_mesh_files, template_particles, template_mesh, val_out_dir, planes=val_planes) @@ -500,17 +506,17 @@ def Run_Pipeline(args): with open(plane_file) as json_file: test_planes.append(json.load(json_file)['planes'][0]['points']) - test_out_dir = output_directory + model_name + '/test_predictions/' predicted_test_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='test') print("Test world predictions saved.") - # Generate local predictions - local_test_prediction_dir = test_out_dir + 'local_predictions/' + # Generate local predictions - create directory next to world_predictions + world_pred_dir = os.path.dirname(predicted_test_world_particles[0]) + local_test_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions") if not os.path.exists(local_test_prediction_dir): os.makedirs(local_test_prediction_dir) predicted_test_local_particles = [] for particle_file, transform in zip(predicted_test_world_particles, test_transforms): particles = np.loadtxt(particle_file) - local_particle_file = particle_file.replace("world_predictions/", "local_predictions/") + local_particle_file = particle_file.replace("world_predictions", "local_predictions") local_particles = sw.utils.transformParticles(particles, transform, inverse=True) np.savetxt(local_particle_file, local_particles) predicted_test_local_particles.append(local_particle_file) @@ -524,6 +530,8 @@ def Run_Pipeline(args): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] + # Get the test output directory from the predictions path + test_out_dir = os.path.dirname(local_test_prediction_dir.rstrip('/')) + '/' mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_test_local_particles, test_mesh_files, template_particles, template_mesh, test_out_dir, planes=test_planes) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py index 90b4fbd0c1..a850d10bd1 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py +++ 
b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py @@ -87,18 +87,18 @@ def test(config_file, loader="test"): [pred_tf, pred_mdl_tl] = model_tl(mdl, img) pred_scores.append(pred_tf.cpu().data.numpy()) # save the AE latent space as shape descriptors - filename = pred_path + test_names[index] + '.npy' - np.save(filename, pred_tf.squeeze().detach().cpu().numpy()) + latent_filename = pred_path + test_names[index] + '.npy' + np.save(latent_filename, pred_tf.squeeze().detach().cpu().numpy()) np.savetxt(particle_filename, pred_mdl_tl.squeeze().detach().cpu().numpy()) else: [pred, pred_mdl_pca] = model_pca(img) [pred, pred_mdl_ft] = model_ft(img) pred_scores.append(pred.cpu().data.numpy()[0]) - filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles' - np.savetxt(filename, pred_mdl_pca.squeeze().detach().cpu().numpy()) + pca_filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles' + np.savetxt(pca_filename, pred_mdl_pca.squeeze().detach().cpu().numpy()) np.savetxt(particle_filename, pred_mdl_ft.squeeze().detach().cpu().numpy()) print("Predicted particle file: ", particle_filename) - predicted_particle_files.append(filename) + predicted_particle_files.append(particle_filename) index += 1 sw_message("Test completed.") return predicted_particle_files diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py index a0ffb81d4b..7d684ee62d 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py @@ -178,7 +178,7 @@ def __init__(self, conflict_file): parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + C.VALIDATION_LOADER) + loader = torch.load(self.loader_dir + C.VALIDATION_LOADER, weights_only=False) self.num_corr = loader.dataset.mdl_target[0].shape[0] img_dims = loader.dataset.img[0].shape self.img_dims = img_dims[1:] diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index 869e37b8a8..1dd9fcc575 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -418,8 +418,8 @@ def supervised_train_tl(config_file): train_loader_path = loader_dir + C.TRAIN_LOADER validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") - train_loader = torch.load(train_loader_path) - val_loader = torch.load(validation_loader_path) + train_loader = torch.load(train_loader_path, weights_only=False) + val_loader = torch.load(validation_loader_path, weights_only=False) print("Done.") print("Defining model...") net = model.DeepSSMNet_TLNet(config_file) From 1c65c2d25dd290652cebf5220cd0d66153e45393 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 7 Jan 2026 01:27:09 -0700 Subject: [PATCH 07/47] Validate --exact_check and --tl_net are only used with deep_ssm --- Examples/Python/RunUseCase.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 98a30edfab..1648a6d8bb 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -74,6 +74,12 @@ parser.add_argument("--tl_net", help="Enable TL-DeepSSM network (deep_ssm use case only)", action="store_true") args = parser.parse_args() + # Validate deep_ssm-specific arguments + if args.exact_check and args.use_case != "deep_ssm": + parser.error("--exact_check 
is only supported for the deep_ssm use case") + if args.tl_net and args.use_case != "deep_ssm": + parser.error("--tl_net is only supported for the deep_ssm use case") + type = "" if args.tiny_test: type = "tiny_test_" From 0efe52cceebd16f66ebd11558076436fc706c76f Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 7 Jan 2026 01:33:12 -0700 Subject: [PATCH 08/47] Use separate exact_check files for standard and tl_net modes --- Examples/Python/deep_ssm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index 9b618057c8..d6234d1b54 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -552,7 +552,8 @@ def check_results(args, mean_dist, output_directory): if args.tiny_test: print("\nVerifying use case results.") - exact_check_file = output_directory + "exact_check_value.txt" + suffix = "_tl_net" if args.tl_net else "" + exact_check_file = output_directory + f"exact_check_value{suffix}.txt" # Exact check for refactoring verification (platform-specific) if args.exact_check == "save": From 881bc34da09fddccccaacb5a083a559fb1434b1b Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Thu, 8 Jan 2026 12:27:40 -0700 Subject: [PATCH 09/47] Add GTest-based tests for DeepSSM that use shapeworks project files. - Add Testing/DeepSSMTests/ with C++ test harness and shell scripts - Add deepssm_test_data.zip (6MB) containing femur meshes, images, constraints, and pre-configured project files - Fix bug in Commands.cpp where DeepSSM command returned false (exit code 1) on success instead of true (exit code 0) - Remove --tl_net argument from Python use case since testing different DeepSSM configurations is now done via project files --- Applications/shapeworks/Commands.cpp | 2 +- Examples/Python/RunUseCase.py | 3 --- Examples/Python/deep_ssm.py | 10 +++------- Testing/CMakeLists.txt | 1 + Testing/DeepSSMTests/CMakeLists.txt | 13 ++++++++++++ Testing/DeepSSMTests/DeepSSMTests.cpp | 20 +++++++++++++++++++ Testing/DeepSSMTests/deepssm_default.sh | 12 +++++++++++ Testing/DeepSSMTests/deepssm_fine_tune.sh | 12 +++++++++++ Testing/DeepSSMTests/deepssm_tl_net.sh | 12 +++++++++++ .../DeepSSMTests/deepssm_tl_net_fine_tune.sh | 12 +++++++++++ Testing/data/deepssm_test_data.zip | 3 +++ 11 files changed, 89 insertions(+), 11 deletions(-) create mode 100644 Testing/DeepSSMTests/CMakeLists.txt create mode 100644 Testing/DeepSSMTests/DeepSSMTests.cpp create mode 100755 Testing/DeepSSMTests/deepssm_default.sh create mode 100755 Testing/DeepSSMTests/deepssm_fine_tune.sh create mode 100755 Testing/DeepSSMTests/deepssm_tl_net.sh create mode 100755 Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh create mode 100644 Testing/data/deepssm_test_data.zip diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp index da59fe0c67..e161d07256 100644 --- a/Applications/shapeworks/Commands.cpp +++ b/Applications/shapeworks/Commands.cpp @@ -499,7 +499,7 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& project->save(); - return false; + return true; } } // namespace shapeworks diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 1648a6d8bb..aa1a175f50 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -71,14 +71,11 @@ parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true") parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification 
(platform-specific)", choices=["save", "verify"]) - parser.add_argument("--tl_net", help="Enable TL-DeepSSM network (deep_ssm use case only)", action="store_true") args = parser.parse_args() # Validate deep_ssm-specific arguments if args.exact_check and args.use_case != "deep_ssm": parser.error("--exact_check is only supported for the deep_ssm use case") - if args.tl_net and args.use_case != "deep_ssm": - parser.error("--tl_net is only supported for the deep_ssm use case") type = "" if args.tiny_test: diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index d6234d1b54..3e696997b4 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -385,7 +385,7 @@ def Run_Pipeline(args): }, "use_best_model": True, "tl_net": { - "enabled": args.tl_net, + "enabled": False, "ae_epochs": 100, "tf_epochs": 100, "joint_epochs": 25, @@ -396,12 +396,9 @@ def Run_Pipeline(args): "c_lat": 6.3 } } + if args.tiny_test: model_parameters["trainer"]["epochs"] = 1 - if args.tl_net: - model_parameters["tl_net"]["ae_epochs"] = 1 - model_parameters["tl_net"]["tf_epochs"] = 1 - model_parameters["tl_net"]["joint_epochs"] = 1 # Save config file with open(config_file, "w") as outfile: json.dump(model_parameters, outfile, indent=2) @@ -552,8 +549,7 @@ def check_results(args, mean_dist, output_directory): if args.tiny_test: print("\nVerifying use case results.") - suffix = "_tl_net" if args.tl_net else "" - exact_check_file = output_directory + f"exact_check_value{suffix}.txt" + exact_check_file = output_directory + "exact_check_value.txt" # Exact check for refactoring verification (platform-specific) if args.exact_check == "save": diff --git a/Testing/CMakeLists.txt b/Testing/CMakeLists.txt index c03ca89ef2..58dfb2fe16 100644 --- a/Testing/CMakeLists.txt +++ b/Testing/CMakeLists.txt @@ -77,3 +77,4 @@ add_subdirectory(ProjectTests) add_subdirectory(UseCaseTests) add_subdirectory(shapeworksTests) add_subdirectory(UtilsTests) +add_subdirectory(DeepSSMTests) diff --git a/Testing/DeepSSMTests/CMakeLists.txt b/Testing/DeepSSMTests/CMakeLists.txt new file mode 100644 index 0000000000..7119af3cef --- /dev/null +++ b/Testing/DeepSSMTests/CMakeLists.txt @@ -0,0 +1,13 @@ +set(TEST_SRCS + DeepSSMTests.cpp + ) + +add_executable(DeepSSMTests + ${TEST_SRCS} + ) + +target_link_libraries(DeepSSMTests + Testing + ) + +add_test(NAME DeepSSMTests COMMAND DeepSSMTests) diff --git a/Testing/DeepSSMTests/DeepSSMTests.cpp b/Testing/DeepSSMTests/DeepSSMTests.cpp new file mode 100644 index 0000000000..05f12e8299 --- /dev/null +++ b/Testing/DeepSSMTests/DeepSSMTests.cpp @@ -0,0 +1,20 @@ +#include "Testing.h" + +using namespace shapeworks; + +//--------------------------------------------------------------------------- +void run_deepssm_test(const std::string& name) { + setupenv(std::string(TEST_DATA_DIR) + "/../DeepSSMTests"); + + std::string command = "bash " + name; + ASSERT_FALSE(system(command.c_str())); +} + +//--------------------------------------------------------------------------- +TEST(DeepSSMTests, defaultTest) { run_deepssm_test("deepssm_default.sh"); } + +TEST(DeepSSMTests, tlNetTest) { run_deepssm_test("deepssm_tl_net.sh"); } + +TEST(DeepSSMTests, fineTuneTest) { run_deepssm_test("deepssm_fine_tune.sh"); } + +TEST(DeepSSMTests, tlNetFineTuneTest) { run_deepssm_test("deepssm_tl_net_fine_tune.sh"); } diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh new file mode 100755 index 0000000000..93bb64296e --- /dev/null +++ 
b/Testing/DeepSSMTests/deepssm_default.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Test DeepSSM with default settings (no tl_net, no fine_tune) +set -e + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name default.swproj --all diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh new file mode 100755 index 0000000000..cc2b9095a6 --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Test DeepSSM with fine tuning enabled +set -e + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name fine_tune.swproj --all diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh new file mode 100755 index 0000000000..42450340fa --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_tl_net.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Test DeepSSM with TL-DeepSSM network enabled +set -e + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name tl_net.swproj --all diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh new file mode 100755 index 0000000000..36083e3d88 --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Test DeepSSM with both TL-DeepSSM and fine tuning enabled +set -e + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name tl_net_fine_tune.swproj --all diff --git a/Testing/data/deepssm_test_data.zip b/Testing/data/deepssm_test_data.zip new file mode 100644 index 0000000000..621d3a1556 --- /dev/null +++ b/Testing/data/deepssm_test_data.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99c6a0a3f6bfa91cc00095db64cf9155fe037a9a56afd918aee25b9c3f4770d5 +size 6196905 From a16e305149bf64e2c05b54ab06d7daac9000b172 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Thu, 8 Jan 2026 12:51:37 -0700 Subject: [PATCH 10/47] Add result verification to DeepSSM tests Add verify_deepssm_results.py script that validates test output by checking mean surface-to-surface distance from test_distances.csv. Uses loose tolerance (0-300) for quick 1-epoch tests to catch catastrophic failures while keeping tests fast. Supports --exact_check save/verify for platform-specific refactoring verification with tighter tolerances. 
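For reference, the check performed by the script boils down to the following sketch (illustrative only, not the script added below; it assumes the `Distance` column of `deepssm/test_distances.csv` produced by the test run):

```python
import csv

def mean_distance(csv_path):
    # Average the per-sample surface-to-surface distances from test_distances.csv
    with open(csv_path, newline="") as f:
        distances = [float(row["Distance"]) for row in csv.DictReader(f)]
    return sum(distances) / len(distances)

def exact_check(mean_dist, baseline_file, mode):
    # save: record a platform-specific baseline; verify: require an exact match
    if mode == "save":
        with open(baseline_file, "w") as f:
            f.write(str(mean_dist))
        return True
    with open(baseline_file) as f:
        return mean_dist == float(f.read().strip())
```

Exact checks are only meaningful on the same machine, since floating-point results differ across platforms.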
--- Testing/DeepSSMTests/deepssm_default.sh | 5 ++ Testing/DeepSSMTests/deepssm_fine_tune.sh | 5 ++ Testing/DeepSSMTests/deepssm_tl_net.sh | 5 ++ .../DeepSSMTests/deepssm_tl_net_fine_tune.sh | 5 ++ .../DeepSSMTests/verify_deepssm_results.py | 86 +++++++++++++++++++ 5 files changed, 106 insertions(+) create mode 100644 Testing/DeepSSMTests/verify_deepssm_results.py diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh index 93bb64296e..c8a7305829 100755 --- a/Testing/DeepSSMTests/deepssm_default.sh +++ b/Testing/DeepSSMTests/deepssm_default.sh @@ -2,6 +2,8 @@ # Test DeepSSM with default settings (no tl_net, no fine_tune) set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + # Unzip test data if not already extracted if [ ! -d "${DATA}/deepssm" ]; then unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" @@ -10,3 +12,6 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles shapeworks deepssm --name default.swproj --all + +# Verify results +python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh index cc2b9095a6..5b991a3f84 100755 --- a/Testing/DeepSSMTests/deepssm_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh @@ -2,6 +2,8 @@ # Test DeepSSM with fine tuning enabled set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + # Unzip test data if not already extracted if [ ! -d "${DATA}/deepssm" ]; then unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" @@ -10,3 +12,6 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles shapeworks deepssm --name fine_tune.swproj --all + +# Verify results +python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh index 42450340fa..f246158782 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net.sh @@ -2,6 +2,8 @@ # Test DeepSSM with TL-DeepSSM network enabled set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + # Unzip test data if not already extracted if [ ! -d "${DATA}/deepssm" ]; then unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" @@ -10,3 +12,6 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles shapeworks deepssm --name tl_net.swproj --all + +# Verify results +python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh index 36083e3d88..70ea18f1f8 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh @@ -2,6 +2,8 @@ # Test DeepSSM with both TL-DeepSSM and fine tuning enabled set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + # Unzip test data if not already extracted if [ ! 
-d "${DATA}/deepssm" ]; then unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" @@ -10,3 +12,6 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles shapeworks deepssm --name tl_net_fine_tune.swproj --all + +# Verify results +python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/verify_deepssm_results.py b/Testing/DeepSSMTests/verify_deepssm_results.py new file mode 100644 index 0000000000..4152f2f407 --- /dev/null +++ b/Testing/DeepSSMTests/verify_deepssm_results.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Verify DeepSSM test results by checking the mean distance from test_distances.csv. + +Usage: + python verify_deepssm_results.py [--exact_check save|verify] [--expected ] + +The script checks that the mean surface-to-surface distance is reasonable (roughly 10, within tolerance). +For exact refactoring verification, use --exact_check save/verify to save or compare exact values. +""" + +import argparse +import csv +import math +import os +import sys + + +def get_mean_distance(project_dir: str) -> float: + """Read mean distance from test_distances.csv.""" + csv_path = os.path.join(project_dir, "deepssm", "test_distances.csv") + if not os.path.exists(csv_path): + raise FileNotFoundError(f"Results file not found: {csv_path}") + + distances = [] + with open(csv_path, 'r') as f: + reader = csv.DictReader(f) + for row in reader: + distances.append(float(row['Distance'])) + + if not distances: + raise ValueError(f"No distances found in {csv_path}") + + return sum(distances) / len(distances) + + +def main(): + parser = argparse.ArgumentParser(description="Verify DeepSSM test results") + parser.add_argument("project_dir", help="Path to the project directory containing deepssm/ output") + parser.add_argument("--exact_check", choices=["save", "verify"], + help="Save or verify exact values for refactoring verification") + parser.add_argument("--expected", type=float, default=150.0, + help="Expected mean distance for relaxed check (default: 150.0)") + parser.add_argument("--tolerance", type=float, default=1.0, + help="Relative tolerance for relaxed check (default: 1.0 = 100%%)") + args = parser.parse_args() + + try: + mean_dist = get_mean_distance(args.project_dir) + print(f"Mean distance: {mean_dist}") + except (FileNotFoundError, ValueError) as e: + print(f"Error: {e}") + sys.exit(1) + + exact_check_file = os.path.join(args.project_dir, "exact_check_value.txt") + + if args.exact_check == "save": + with open(exact_check_file, "w") as f: + f.write(str(mean_dist)) + print(f"Saved exact check value to: {exact_check_file}") + sys.exit(0) + + elif args.exact_check == "verify": + if not os.path.exists(exact_check_file): + print(f"Error: No saved value found at {exact_check_file}") + print("Run with --exact_check save first to create baseline.") + sys.exit(1) + with open(exact_check_file, "r") as f: + expected = float(f.read().strip()) + if mean_dist != expected: + print(f"Exact check FAILED: expected {expected}, got {mean_dist}") + sys.exit(1) + print(f"Exact check PASSED: {mean_dist}") + sys.exit(0) + + else: + # Relaxed check for CI/cross-platform + if not math.isclose(mean_dist, args.expected, rel_tol=args.tolerance): + print(f"FAILED: mean distance {mean_dist} not close to {args.expected} (tolerance {args.tolerance})") + sys.exit(1) + print(f"PASSED: mean distance {mean_dist} is within tolerance of {args.expected}") + sys.exit(0) + + +if __name__ == "__main__": + main() From 45f4e25274325033b346a9fe7b942208f3f30b7c 
Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Fri, 9 Jan 2026 01:34:14 -0700 Subject: [PATCH 11/47] Add documentation and extended test infrastructure for DeepSSM - Add README.md with instructions for running tests and exact check mode - Add run_exact_check.sh to verify all quick test configurations - Add run_extended_tests.sh to run tests on a directory of projects - Add --baseline_file option to verify script for per-project baselines --- Testing/DeepSSMTests/README.md | 107 +++++++++++++ Testing/DeepSSMTests/run_exact_check.sh | 45 ++++++ Testing/DeepSSMTests/run_extended_tests.sh | 145 ++++++++++++++++++ .../DeepSSMTests/verify_deepssm_results.py | 4 +- 4 files changed, 300 insertions(+), 1 deletion(-) create mode 100644 Testing/DeepSSMTests/README.md create mode 100755 Testing/DeepSSMTests/run_exact_check.sh create mode 100755 Testing/DeepSSMTests/run_extended_tests.sh diff --git a/Testing/DeepSSMTests/README.md b/Testing/DeepSSMTests/README.md new file mode 100644 index 0000000000..e326fa8525 --- /dev/null +++ b/Testing/DeepSSMTests/README.md @@ -0,0 +1,107 @@ +# DeepSSM Tests + +Automated tests for DeepSSM using ShapeWorks project files (.swproj). + +## Test Configurations + +| Test | Description | +|------|-------------| +| `deepssm_default` | Standard DeepSSM (no TL-Net, no fine-tuning) | +| `deepssm_tl_net` | TL-DeepSSM network enabled | +| `deepssm_fine_tune` | Fine-tuning enabled | +| `deepssm_tl_net_fine_tune` | Both TL-DeepSSM and fine-tuning enabled | + +## Running Tests + +### Run all DeepSSM tests: +```bash +cd /path/to/build +ctest -R DeepSSMTests -V +``` + +### Run a specific test: +```bash +./bin/DeepSSMTests --gtest_filter="*default*" +./bin/DeepSSMTests --gtest_filter="*tl_net*" +``` + +### Run tests directly via shell scripts: +```bash +export DATA=/path/to/Testing/data +bash Testing/DeepSSMTests/deepssm_default.sh +``` + +## Test Data + +Test data is stored in `Testing/data/deepssm_test_data.zip` and automatically extracted on first run. Contains: +- 5 femur meshes, CT images, and constraint files +- Pre-configured project files for each test configuration + +## Result Verification + +Tests verify that the mean surface-to-surface distance is within tolerance. The default tolerance is loose (0-300) for quick 1-epoch tests. + +### Exact Check Mode (for refactoring verification) + +When refactoring DeepSSM code, you can verify results are identical before and after changes. + +**Run all configurations:** +```bash +# Save baselines (before refactoring) +bash Testing/DeepSSMTests/run_exact_check.sh save + +# Verify after refactoring +bash Testing/DeepSSMTests/run_exact_check.sh verify +``` + +**Run a single configuration:** +```bash +cd Testing/data/deepssm/projects +rm -rf deepssm groomed *_particles +shapeworks deepssm --name default.swproj --all + +# Save or verify +python Testing/DeepSSMTests/verify_deepssm_results.py . --exact_check save +python Testing/DeepSSMTests/verify_deepssm_results.py . --exact_check verify +``` + +Baseline values are saved to `exact_check_*.txt` in the project directory. + +**Note:** Exact check is platform-specific due to floating-point differences. Only compare results from the same machine. + +## Extended Tests (Manual) + +Extended tests run on a directory of projects for meaningful accuracy checks. These are not part of automated CI. + +### Directory Structure + +``` +/path/to/projects/ + project1/ + project1.swproj + femur/... + project2/ + project2.swproj + data/... 
+``` + +Each subdirectory should contain a `.swproj` file and its associated data. + +### Running Extended Tests + +```bash +# Run all projects with relaxed tolerance +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects + +# Save baselines for exact check +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects save + +# Verify against baselines +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects verify + +# Run specific project only +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects save femur +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects verify femur +``` + +Baseline values are saved to `exact_check_.txt` in each project directory. diff --git a/Testing/DeepSSMTests/run_exact_check.sh b/Testing/DeepSSMTests/run_exact_check.sh new file mode 100755 index 0000000000..d4bd7ad54c --- /dev/null +++ b/Testing/DeepSSMTests/run_exact_check.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Run exact check for all DeepSSM test configurations +# Usage: ./run_exact_check.sh save|verify + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${DATA:-$(dirname "$SCRIPT_DIR")/data}" + +if [ "$1" != "save" ] && [ "$1" != "verify" ]; then + echo "Usage: $0 save|verify" + echo " save - Save baseline values (run before refactoring)" + echo " verify - Verify against saved values (run after refactoring)" + exit 1 +fi + +MODE="$1" +CONFIGS="default tl_net fine_tune tl_net_fine_tune" + +# Unzip test data if not already extracted +if [ ! -d "${DATA_DIR}/deepssm" ]; then + unzip -q "${DATA_DIR}/deepssm_test_data.zip" -d "${DATA_DIR}/deepssm" +fi + +cd "${DATA_DIR}/deepssm/projects" + +for config in $CONFIGS; do + echo "========================================" + echo "Running $config..." + echo "========================================" + + rm -rf deepssm groomed *_particles + shapeworks deepssm --name ${config}.swproj --all + + # Run exact check with config-specific file + python "${SCRIPT_DIR}/verify_deepssm_results.py" . \ + --exact_check "$MODE" \ + --baseline_file "exact_check_${config}.txt" + + echo "" +done + +echo "========================================" +echo "All configurations: $MODE complete!" +echo "========================================" diff --git a/Testing/DeepSSMTests/run_extended_tests.sh b/Testing/DeepSSMTests/run_extended_tests.sh new file mode 100755 index 0000000000..1fa3f3afa7 --- /dev/null +++ b/Testing/DeepSSMTests/run_extended_tests.sh @@ -0,0 +1,145 @@ +#!/bin/bash +# Run extended DeepSSM tests on a directory of projects +# +# Usage: ./run_extended_tests.sh [save|verify|relaxed] [project] +# +# Arguments: +# base_dir - Directory containing project subdirectories +# mode - save: save baseline values +# verify: verify against saved baselines +# relaxed: run with loose tolerance (default) +# project - Optional: run only this project (default: all) +# +# Examples: +# ./run_extended_tests.sh /path/to/projects # Run all with relaxed check +# ./run_extended_tests.sh /path/to/projects save # Save baselines for all +# ./run_extended_tests.sh /path/to/projects verify # Verify all against baselines +# ./run_extended_tests.sh /path/to/projects save femur # Save baseline for femur only +# +# Directory structure: +# base_dir/ +# project1/ +# *.swproj +# femur/ (or other data) +# project2/ +# *.swproj +# ... 
+ +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +usage() { + echo "Usage: $0 [save|verify|relaxed] [project]" + echo "" + echo "Arguments:" + echo " base_dir - Directory containing project subdirectories" + echo " mode - save|verify|relaxed (default: relaxed)" + echo " project - Run only this project (default: all)" + echo "" + echo "Examples:" + echo " $0 /path/to/projects" + echo " $0 /path/to/projects save" + echo " $0 /path/to/projects verify" + echo " $0 /path/to/projects save femur" +} + +if [ $# -lt 1 ] || [ "$1" = "-h" ] || [ "$1" = "--help" ]; then + usage + exit 0 +fi + +BASE_DIR="$1" +MODE="${2:-relaxed}" +PROJECT="${3:-all}" + +if [ ! -d "$BASE_DIR" ]; then + echo "Error: Directory not found: $BASE_DIR" + exit 1 +fi + +if [ "$MODE" != "save" ] && [ "$MODE" != "verify" ] && [ "$MODE" != "relaxed" ]; then + echo "Error: Unknown mode: $MODE" + usage + exit 1 +fi + +run_project() { + local project_dir="$1" + local project_name="$(basename "$project_dir")" + + echo "========================================" + echo "Project: $project_name" + echo "========================================" + + # Find .swproj file + local swproj=$(find "$project_dir" -maxdepth 1 -name "*.swproj" | head -1) + if [ -z "$swproj" ]; then + echo "Warning: No .swproj file found in $project_dir, skipping" + return 0 + fi + + local swproj_name="$(basename "$swproj")" + echo "Using project file: $swproj_name" + + cd "$project_dir" + rm -rf deepssm groomed *_particles + + shapeworks deepssm --name "$swproj_name" --all + + # Verify results + local baseline_file="exact_check_${project_name}.txt" + local verify_args="" + + if [ "$MODE" = "save" ]; then + verify_args="--exact_check save --baseline_file $baseline_file" + elif [ "$MODE" = "verify" ]; then + verify_args="--exact_check verify --baseline_file $baseline_file" + else + verify_args="--expected 10 --tolerance 1.0" + fi + + python "${SCRIPT_DIR}/verify_deepssm_results.py" . $verify_args + + echo "" +} + +echo "Extended DeepSSM Tests" +echo "Base directory: ${BASE_DIR}" +echo "Mode: ${MODE}" +echo "" + +# Find all project directories (directories containing .swproj files) +ran_any=false +for project_dir in "$BASE_DIR"/*/; do + if [ ! -d "$project_dir" ]; then + continue + fi + + project_name="$(basename "$project_dir")" + + # Skip if specific project requested and this isn't it + if [ "$PROJECT" != "all" ] && [ "$PROJECT" != "$project_name" ]; then + continue + fi + + # Check if this directory has a .swproj file + if ls "$project_dir"/*.swproj 1>/dev/null 2>&1; then + run_project "$project_dir" + ran_any=true + fi +done + +if [ "$ran_any" = false ]; then + if [ "$PROJECT" = "all" ]; then + echo "Error: No projects found in $BASE_DIR" + echo "Each project should be a subdirectory containing a .swproj file." + else + echo "Error: Project not found: $PROJECT" + fi + exit 1 +fi + +echo "========================================" +echo "All projects complete!" 
+echo "========================================" diff --git a/Testing/DeepSSMTests/verify_deepssm_results.py b/Testing/DeepSSMTests/verify_deepssm_results.py index 4152f2f407..6375b4df27 100644 --- a/Testing/DeepSSMTests/verify_deepssm_results.py +++ b/Testing/DeepSSMTests/verify_deepssm_results.py @@ -43,6 +43,8 @@ def main(): help="Expected mean distance for relaxed check (default: 150.0)") parser.add_argument("--tolerance", type=float, default=1.0, help="Relative tolerance for relaxed check (default: 1.0 = 100%%)") + parser.add_argument("--baseline_file", type=str, default="exact_check_value.txt", + help="Filename for exact check baseline (default: exact_check_value.txt)") args = parser.parse_args() try: @@ -52,7 +54,7 @@ def main(): print(f"Error: {e}") sys.exit(1) - exact_check_file = os.path.join(args.project_dir, "exact_check_value.txt") + exact_check_file = os.path.join(args.project_dir, args.baseline_file) if args.exact_check == "save": with open(exact_check_file, "w") as f: From d200fbe9a653ce293f5ace86e4884e4f59ce4e4d Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Fri, 9 Jan 2026 12:01:09 -0700 Subject: [PATCH 12/47] Fix DeepSSM command arg parsing after return value fix --- Applications/shapeworks/Command.h | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Applications/shapeworks/Command.h b/Applications/shapeworks/Command.h index 8c6db366ef..a6e8b3a6a1 100644 --- a/Applications/shapeworks/Command.h +++ b/Applications/shapeworks/Command.h @@ -33,7 +33,7 @@ class Command { const std::string desc() const { return parser.description(); } /// parses the arguments for this command, saving them in the parser and returning the leftovers - std::vector parse_args(const std::vector &arguments); + virtual std::vector parse_args(const std::vector &arguments); /// calls execute for this command using the parsed args, returning system exit value int run(SharedCommandData &sharedData); @@ -108,6 +108,12 @@ class DeepSSMCommandGroup : public Command public: const std::string type() override { return "DeepSSM"; } + // DeepSSM is a terminal command - don't pass remaining args to other commands + std::vector parse_args(const std::vector &arguments) override { + Command::parse_args(arguments); + return {}; // return empty - DeepSSM consumes all args + } + private: }; From 8be54c0faac3c170b022c234d24d8c1c5c26253b Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 14 Jan 2026 13:17:02 -0700 Subject: [PATCH 13/47] Fix toMesh pipeline and add empty mesh validation - Improve toMesh() pipeline in Image.cpp: add TriangleFilter to handle degenerate cells from vtkContourFilter, CleanPolyData to remove duplicates, and ConnectivityFilter to extract largest region - Add empty mesh validation in Groom after toMesh() - Add empty segmentation check before crop operation - Check both source and reference mesh in ICP transforms - Add validation in Mesh::extractLargestComponent() for empty/degenerate cells --- Libs/Groom/Groom.cpp | 15 +++++++++++++- Libs/Image/Image.cpp | 49 +++++++++++++++++++++++++++++++++++++++++--- Libs/Mesh/Mesh.cpp | 26 +++++++++++++++++++++++ 3 files changed, 86 insertions(+), 4 deletions(-) diff --git a/Libs/Groom/Groom.cpp b/Libs/Groom/Groom.cpp index e5b2a4ac55..3c83f16322 100644 --- a/Libs/Groom/Groom.cpp +++ b/Libs/Groom/Groom.cpp @@ -186,7 +186,17 @@ bool Groom::image_pipeline(std::shared_ptr subject, size_t domain) { std::string groomed_name = get_output_filename(original, DomainType::Image); if (params.get_convert_to_mesh()) { + // Use isovalue 0.0 for 
distance transforms (the zero level set is the surface) Mesh mesh = image.toMesh(0.0); + if (mesh.numPoints() == 0) { + throw std::runtime_error("Empty mesh generated from segmentation - segmentation may have no valid data"); + } + // Check for valid cells + auto poly_data = mesh.getVTKMesh(); + if (poly_data->GetNumberOfCells() == 0) { + throw std::runtime_error("Mesh has no cells - segmentation may have no valid surface"); + } + SW_DEBUG("Mesh after toMesh: {} points, {} cells", poly_data->GetNumberOfPoints(), poly_data->GetNumberOfCells()); run_mesh_pipeline(mesh, params, original); groomed_name = get_output_filename(original, DomainType::Mesh); // save the groomed mesh @@ -239,6 +249,9 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) { // crop if (params.get_crop()) { PhysicalRegion region = image.physicalBoundingBox(0.5); + if (!region.valid()) { + throw std::runtime_error("Empty segmentation - no voxels found above threshold for cropping"); + } image.crop(region); increment_progress(); } @@ -1336,7 +1349,7 @@ std::vector> Groom::get_icp_transforms(const std::vectorIdentity(); Mesh source = meshes[i]; - if (source.getVTKMesh()->GetNumberOfPoints() != 0) { + if (source.getVTKMesh()->GetNumberOfPoints() != 0 && reference.getVTKMesh()->GetNumberOfPoints() != 0) { // create copies for thread safety auto poly_data1 = vtkSmartPointer::New(); poly_data1->DeepCopy(source.getVTKMesh()); diff --git a/Libs/Image/Image.cpp b/Libs/Image/Image.cpp index fc4788daa7..26937abdf5 100644 --- a/Libs/Image/Image.cpp +++ b/Libs/Image/Image.cpp @@ -32,10 +32,13 @@ #include #include #include +#include #include #include #include #include +#include +#include #include #include @@ -1019,7 +1022,40 @@ Mesh Image::toMesh(PixelType isoValue) const { targetContour->SetValue(0, isoValue); targetContour->Update(); - return Mesh(targetContour->GetOutput()); + auto contourOutput = targetContour->GetOutput(); + + // Use vtkTriangleFilter FIRST to convert all polygons to proper triangles + // This removes degenerate cells that can crash downstream filters + auto triangleFilter = vtkSmartPointer::New(); + triangleFilter->SetInputData(contourOutput); + triangleFilter->PassVertsOff(); + triangleFilter->PassLinesOff(); + triangleFilter->Update(); + + // Clean the mesh to remove degenerate points and merge duplicates + auto clean = vtkSmartPointer::New(); + clean->SetInputData(triangleFilter->GetOutput()); + clean->ConvertPolysToLinesOff(); + clean->ConvertLinesToPointsOff(); + clean->ConvertStripsToPolysOff(); + clean->PointMergingOn(); + clean->SetTolerance(0.0); + clean->Update(); + + // Check if we have any data to process + auto cleanOutput = clean->GetOutput(); + if (cleanOutput->GetNumberOfPoints() == 0 || cleanOutput->GetNumberOfCells() == 0) { + // Return empty mesh + return Mesh(cleanOutput); + } + + // Use connectivity filter to extract only connected surface regions + auto connectivity = vtkSmartPointer::New(); + connectivity->SetInputData(cleanOutput); + connectivity->SetExtractionModeToLargestRegion(); + connectivity->Update(); + + return Mesh(connectivity->GetOutput()); } Image::PixelType Image::evaluate(Point p) { @@ -1170,11 +1206,18 @@ TransformPtr Image::createRigidRegistrationTransform(const Image& target_dt, flo Mesh sourceContour = toMesh(isoValue); Mesh targetContour = target_dt.toMesh(isoValue); + // Check for empty meshes before attempting ICP + if (sourceContour.numPoints() == 0 || targetContour.numPoints() == 0) { + SW_WARN("Cannot create ICP transform: source has {} points, 
target has {} points", + sourceContour.numPoints(), targetContour.numPoints()); + return AffineTransform::New(); + } + try { auto mat = MeshUtils::createICPTransform(sourceContour, targetContour, Mesh::Rigid, iterations); return shapeworks::createTransform(ShapeWorksUtils::convert_matrix(mat), ShapeWorksUtils::get_offset(mat)); - } catch (std::invalid_argument) { - std::cerr << "failed to create ICP transform.\n"; + } catch (std::invalid_argument& e) { + std::cerr << "failed to create ICP transform: " << e.what() << "\n"; if (sourceContour.numPoints() == 0) { std::cerr << "\tspecified isoValue (" << isoValue << ") results in an empty mesh for source\n"; } diff --git a/Libs/Mesh/Mesh.cpp b/Libs/Mesh/Mesh.cpp index 42df05e6fb..6023bbab74 100644 --- a/Libs/Mesh/Mesh.cpp +++ b/Libs/Mesh/Mesh.cpp @@ -606,6 +606,24 @@ Mesh& Mesh::fixNonManifold() { } Mesh& Mesh::extractLargestComponent() { + // Check for valid cells before attempting connectivity filter + if (poly_data_->GetNumberOfCells() == 0) { + SW_WARN("extractLargestComponent: mesh has no cells"); + return *this; + } + + // Verify mesh has at least some valid cells + bool hasValidCells = false; + for (vtkIdType i = 0; i < poly_data_->GetNumberOfCells() && !hasValidCells; i++) { + if (poly_data_->GetCellType(i) != 0) { // VTK_EMPTY_CELL = 0 + hasValidCells = true; + } + } + if (!hasValidCells) { + SW_WARN("extractLargestComponent: mesh has no valid cells (all cells are type 0)"); + return *this; + } + auto connectivityFilter = vtkSmartPointer::New(); connectivityFilter->SetExtractionModeToLargestRegion(); connectivityFilter->SetInputData(poly_data_); @@ -1603,6 +1621,14 @@ bool Mesh::compare(const Mesh& other, const double eps) const { MeshTransform Mesh::createRegistrationTransform(const Mesh& target, Mesh::AlignmentType align, unsigned iterations) const { + // Check for empty meshes before attempting ICP + if (numPoints() == 0 || target.numPoints() == 0) { + SW_WARN("Cannot create registration transform: source has {} points, target has {} points", + numPoints(), target.numPoints()); + vtkSmartPointer identity = vtkSmartPointer::New(); + identity->Identity(); + return createMeshTransform(identity); + } const vtkSmartPointer mat( MeshUtils::createICPTransform(this->poly_data_, target.getVTKMesh(), align, iterations, true)); return createMeshTransform(mat); From f2114c8c0791773332cec6883e09a4b467d6c8a5 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 14 Jan 2026 13:02:30 -0700 Subject: [PATCH 14/47] Return identity transform for empty meshes in ICP When createICPTransform receives empty source or target meshes, return an identity transform with a warning instead of throwing an exception. This allows batch processing to continue gracefully when some shapes fail to generate valid meshes. 
--- Libs/Mesh/MeshUtils.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Libs/Mesh/MeshUtils.cpp b/Libs/Mesh/MeshUtils.cpp index 2468230ecb..bb9bbce1d1 100644 --- a/Libs/Mesh/MeshUtils.cpp +++ b/Libs/Mesh/MeshUtils.cpp @@ -71,7 +71,11 @@ const vtkSmartPointer MeshUtils::createICPTransform(const Mesh sou Mesh::AlignmentType align, const unsigned iterations, bool meshTransform) { if (source.numPoints() == 0 || target.numPoints() == 0) { - throw std::invalid_argument("empty mesh passed to MeshUtils::createICPTransform"); + SW_WARN("Empty mesh in createICPTransform: source has {} points, target has {} points - returning identity", + source.numPoints(), target.numPoints()); + vtkSmartPointer identity = vtkSmartPointer::New(); + identity->Identity(); + return identity; } vtkSmartPointer icp = vtkSmartPointer::New(); From c50c21bd5252b908908faf2d22401ba6293f2024 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 14 Jan 2026 13:09:10 -0700 Subject: [PATCH 15/47] Add streaming data loaders to reduce DeepSSM memory usage Instead of loading all images into memory when creating DataLoaders, use streaming datasets that load images on-demand during training. This significantly reduces memory usage for large datasets. Key changes: - DeepSSMdatasetStreaming class loads images lazily from disk - Training/validation/test loaders save metadata instead of full data - load_data_loader() reconstructs loaders from metadata - get_loader_info() extracts dimensions without loading full dataset - Backward compatible with legacy pre-loaded loaders --- .../DeepSSMUtilsPackage/DeepSSMUtils/eval.py | 2 +- .../DeepSSMUtils/loaders.py | 446 ++++++++++++++++-- .../DeepSSMUtilsPackage/DeepSSMUtils/model.py | 15 +- .../DeepSSMUtils/trainer.py | 8 +- 4 files changed, 420 insertions(+), 51 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py index a850d10bd1..ee64b568d8 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py @@ -36,7 +36,7 @@ def test(config_file, loader="test"): # load the loaders sw_message("Loading " + loader + " data loader...") - test_loader = torch.load(loader_dir + loader, weights_only=False) + test_loader = loaders.load_data_loader(loader_dir + loader, loader_type='test') # initialization sw_message("Loading trained model...") diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py index 7a6661e064..48391df834 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py @@ -6,12 +6,15 @@ import subprocess import torch from torch import nn -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset import shapeworks as sw from shapeworks.utils import sw_message from DeepSSMUtils import constants as C random.seed(1) +# Use streaming data loading to avoid loading all images into memory +USE_STREAMING = True + class DataLoadingError(Exception): """Raised when data loading fails.""" @@ -26,6 +29,83 @@ def make_dir(dirPath): if not os.path.exists(dirPath): os.makedirs(dirPath) + +''' +Load a DataLoader from a saved file. Handles both streaming (metadata) and legacy (full loader) formats. 
+''' +def load_data_loader(loader_path, loader_type='train'): + data = torch.load(loader_path, weights_only=False) + + # Check if it's streaming metadata or a full DataLoader + if isinstance(data, dict) and data.get('streaming', False): + # Reconstruct streaming DataLoader from metadata + if loader_type == 'train': + dataset = DeepSSMdatasetStreaming( + data['image_paths'], + data['scores'], + data['models'], + data['prefixes'], + data['mean_img'], + data['std_img'] + ) + return DataLoader( + dataset, + batch_size=data.get('batch_size', 1), + shuffle=True, + num_workers=data.get('num_workers', 0), + pin_memory=torch.cuda.is_available() + ) + else: + # Validation or test + dataset = DeepSSMdatasetStreaming( + data['image_paths'], + data['scores'], + data['models'], + data['names'], + data['mean_img'], + data['std_img'] + ) + return DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=data.get('num_workers', 0), + pin_memory=torch.cuda.is_available() + ) + else: + # Legacy format - data is already a DataLoader + return data + + +''' +Get dataset info (image dimensions, num_corr) from a loader file. +Works with both streaming and legacy formats. +''' +def get_loader_info(loader_path): + data = torch.load(loader_path, weights_only=False) + + if isinstance(data, dict) and data.get('streaming', False): + # Streaming format - load one image to get dimensions + image_path = data['image_paths'][0] + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + img_dims = img.shape + num_corr = len(data['models'][0]) + num_pca = len(data['scores'][0]) if data['scores'][0] != [1] else data.get('num_pca', 0) + return { + 'img_dims': img_dims, + 'num_corr': num_corr, + 'num_pca': num_pca, + 'streaming': True + } + else: + # Legacy format + return { + 'img_dims': data.dataset.img[0].shape[1:], + 'num_corr': data.dataset.mdl_target[0].shape[0], + 'num_pca': data.dataset.pca_target[0].shape[0], + 'streaming': False + } + ''' Reads csv and makes both train and validation data loaders from it ''' @@ -70,23 +150,66 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow ''' def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): sw_message("Creating training torch loader...") - # Get data make_dir(loader_dir) - images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir) - images, scores, models, prefixes = shuffle_data(images, scores, models, prefixes) - train_data = DeepSSMdataset(images, scores, models, prefixes) - # Save - trainloader = DataLoader( + + if USE_STREAMING: + # Streaming approach - don't load all images into memory + image_paths, scores, models, prefixes = get_all_train_data_streaming( + loader_dir, data_csv, down_factor, down_dir + ) + image_paths, scores, models, prefixes = shuffle_data(image_paths, scores, models, prefixes) + + # Load saved mean/std + mean_img = np.load(loader_dir + C.MEAN_IMG_FILE) + std_img = np.load(loader_dir + C.STD_IMG_FILE) + + train_data = DeepSSMdatasetStreaming( + list(image_paths), list(scores), list(models), list(prefixes), + float(mean_img), float(std_img) + ) + + # For streaming, we don't save the full DataLoader (it would try to pickle the dataset) + # Instead, save metadata that can be used to reconstruct the loader + trainloader = DataLoader( train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - train_path = loader_dir + C.TRAIN_LOADER - 
torch.save(trainloader, train_path) - sw_message("Training loader complete.") - return train_path + + # Save metadata for reconstruction + train_meta = { + 'image_paths': list(image_paths), + 'scores': list(scores), + 'models': list(models), + 'prefixes': list(prefixes), + 'mean_img': float(mean_img), + 'std_img': float(std_img), + 'batch_size': batch_size, + 'num_workers': num_workers, + 'streaming': True + } + train_path = loader_dir + C.TRAIN_LOADER + torch.save(train_meta, train_path) + sw_message("Training loader complete.") + return train_path + else: + # Legacy approach - load all into memory + images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir) + images, scores, models, prefixes = shuffle_data(images, scores, models, prefixes) + train_data = DeepSSMdataset(images, scores, models, prefixes) + trainloader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + train_path = loader_dir + C.TRAIN_LOADER + torch.save(trainloader, train_path) + sw_message("Training loader complete.") + return train_path ''' Makes validation data loader @@ -101,6 +224,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 raise DataLoadingError( f"Mismatched validation data: {len(val_img_list)} images but {len(val_particles)} particle files" ) + # Get data image_paths = [] scores = [] @@ -108,32 +232,67 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 names = [] for index in range(len(val_img_list)): image_path = val_img_list[index] - # add name prefix = get_prefix(image_path) names.append(prefix) image_paths.append(image_path) - scores.append([1]) # placeholder + scores.append([1]) # placeholder mdl = get_particles(val_particles[index]) models.append(mdl) - # Write test names to file so they are saved somewhere + + # Write validation names to file name_file = open(loader_dir + C.VALIDATION_NAMES_FILE, 'w+') name_file.write(str(names)) name_file.close() sw_message("Validation names saved to: " + loader_dir + C.VALIDATION_NAMES_FILE) - images = get_images(loader_dir, image_paths, down_factor, down_dir) - val_data = DeepSSMdataset(images, scores, models, names) - # Make loader - val_loader = DataLoader( + + if USE_STREAMING: + # Prepare image paths + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # Load mean/std from training (should already exist) + mean_img = float(np.load(loader_dir + C.MEAN_IMG_FILE)) + std_img = float(np.load(loader_dir + C.STD_IMG_FILE)) + + val_data = DeepSSMdatasetStreaming(image_paths, scores, models, names, mean_img, std_img) + + val_loader = DataLoader( val_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - val_path = loader_dir + C.VALIDATION_LOADER - torch.save(val_loader, val_path) - sw_message("Validation loader complete.") - return val_path + + # Save metadata + val_meta = { + 'image_paths': image_paths, + 'scores': scores, + 'models': models, + 'names': names, + 'mean_img': mean_img, + 'std_img': std_img, + 'num_workers': num_workers, + 'streaming': True + } + val_path = loader_dir + C.VALIDATION_LOADER + torch.save(val_meta, val_path) + sw_message("Validation loader complete.") + return val_path + else: + # Legacy approach + images = get_images(loader_dir, image_paths, down_factor, down_dir) + val_data = DeepSSMdataset(images, scores, models, names) + val_loader = DataLoader( + val_data, + batch_size=1, 
+ shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + val_path = loader_dir + C.VALIDATION_LOADER + torch.save(val_loader, val_path) + sw_message("Validation loader complete.") + return val_path ''' Makes test data loader @@ -142,44 +301,141 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num sw_message("Creating test torch loader...") if not test_img_list: raise DataLoadingError("Test image list is empty") - # get data + + # Get data image_paths = [] scores = [] models = [] test_names = [] for index in range(len(test_img_list)): image_path = test_img_list[index] - # add name prefix = get_prefix(image_path) test_names.append(prefix) image_paths.append(image_path) - # add label placeholders - scores.append([1]) - models.append([1]) - images = get_images(loader_dir, image_paths, down_factor, down_dir) - test_data = DeepSSMdataset(images, scores, models, test_names) - # Write test names to file so they are saved somewhere + scores.append([1]) # placeholder + models.append([1]) # placeholder + + # Write test names to file name_file = open(loader_dir + C.TEST_NAMES_FILE, 'w+') name_file.write(str(test_names)) name_file.close() sw_message("Test names saved to: " + loader_dir + C.TEST_NAMES_FILE) - # Make loader - testloader = DataLoader( + + if USE_STREAMING: + # Prepare image paths + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # Load mean/std from training + mean_img = float(np.load(loader_dir + C.MEAN_IMG_FILE)) + std_img = float(np.load(loader_dir + C.STD_IMG_FILE)) + + test_data = DeepSSMdatasetStreaming(image_paths, scores, models, test_names, mean_img, std_img) + + testloader = DataLoader( test_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - test_path = loader_dir + C.TEST_LOADER - torch.save(testloader, test_path) - sw_message("Test loader complete.") - return test_path, test_names + + # Save metadata + test_meta = { + 'image_paths': image_paths, + 'scores': scores, + 'models': models, + 'names': test_names, + 'mean_img': mean_img, + 'std_img': std_img, + 'num_workers': num_workers, + 'streaming': True + } + test_path = loader_dir + C.TEST_LOADER + torch.save(test_meta, test_path) + sw_message("Test loader complete.") + return test_path, test_names + else: + # Legacy approach + images = get_images(loader_dir, image_paths, down_factor, down_dir) + test_data = DeepSSMdataset(images, scores, models, test_names) + testloader = DataLoader( + test_data, + batch_size=1, + shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + test_path = loader_dir + C.TEST_LOADER + torch.save(testloader, test_path) + sw_message("Test loader complete.") + return test_path, test_names ################################ Helper functions ###################################### ''' -returns images, scores, models, prefixes from CSV +Returns image_paths, scores, models, prefixes from CSV for streaming. +Computes mean/std incrementally without loading all images. 
+''' +def get_all_train_data_streaming(loader_dir, data_csv, down_factor, down_dir): + if not os.path.exists(data_csv): + raise DataLoadingError(f"CSV file not found: {data_csv}") + + image_paths = [] + scores = [] + models = [] + prefixes = [] + + try: + with open(data_csv, newline='') as csvfile: + datareader = csv.reader(csvfile) + for row_num, row in enumerate(datareader, 1): + if len(row) < 3: + raise DataLoadingError( + f"Invalid row {row_num} in {data_csv}: expected at least 3 columns " + f"(image_path, model_path, pca_scores), got {len(row)}" + ) + image_path = row[0] + model_path = row[1] + pca_scores = row[2:] + + prefix = get_prefix(image_path) + prefixes.append(prefix) + image_paths.append(image_path) + + try: + pca_scores = [float(i) for i in pca_scores] + except ValueError as e: + raise DataLoadingError( + f"Invalid PCA scores in {data_csv} at row {row_num}: {e}" + ) + scores.append(pca_scores) + + mdl = get_particles(model_path) + models.append(mdl) + except csv.Error as e: + raise DataLoadingError(f"Error parsing CSV file {data_csv}: {e}") + + if not image_paths: + raise DataLoadingError(f"CSV file is empty: {data_csv}") + + # Prepare image paths (apply downsampling if needed) + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # Compute mean/std incrementally + sw_message("Computing image statistics incrementally...") + mean_img, std_img = compute_image_stats_incremental(image_paths, down_factor, down_dir) + np.save(loader_dir + C.MEAN_IMG_FILE, mean_img) + np.save(loader_dir + C.STD_IMG_FILE, std_img) + sw_message(f"Image stats: mean={mean_img:.4f}, std={std_img:.4f}") + + # Whiten PCA scores + scores = whiten_PCA_scores(scores, loader_dir) + + return image_paths, scores, models, prefixes + + +''' +returns images, scores, models, prefixes from CSV (legacy - loads all into memory) ''' def get_all_train_data(loader_dir, data_csv, down_factor, down_dir): if not os.path.exists(data_csv): @@ -238,6 +494,7 @@ def shuffle_data(images, scores, models, prefixes): ''' Class for DeepSSM datasets that works with Pytorch DataLoader +Loads all images into memory upfront (legacy approach). ''' class DeepSSMdataset(): def __init__(self, img, pca_target, mdl_target, names): @@ -254,6 +511,40 @@ def __getitem__(self, index): def __len__(self): return len(self.img) + +''' +Streaming dataset that loads images on-demand to minimize memory usage. +Only keeps file paths in memory, loads each image when accessed. 
+''' +class DeepSSMdatasetStreaming(Dataset): + def __init__(self, image_paths, pca_target, mdl_target, names, mean_img, std_img): + self.image_paths = image_paths + self.pca_target = torch.FloatTensor(np.array(pca_target)) + self.mdl_target = torch.FloatTensor(np.array(mdl_target)) + self.names = names + self.mean_img = mean_img + self.std_img = std_img + + def __getitem__(self, index): + # Load image on-demand + image_path = self.image_paths[index] + try: + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") + + # Normalize + img = (img - self.mean_img) / self.std_img + x = torch.FloatTensor(img).unsqueeze(0) # Add channel dimension + + y1 = self.pca_target[index] + y2 = self.mdl_target[index] + name = self.names[index] + return x, y1, y2, name + + def __len__(self): + return len(self.image_paths) + ''' returns sample prefix from path string ''' @@ -287,7 +578,86 @@ def get_particles(model_path): raise DataLoadingError(f"Error reading particle file {model_path}: {e}") ''' -reads .nrrd files and returns whitened data +Compute image mean and std incrementally without loading all images into memory. +Uses Welford's online algorithm for numerical stability. +''' +def compute_image_stats_incremental(image_list, down_factor=1, down_dir=None): + if not image_list: + raise DataLoadingError("Image list is empty") + + n = 0 + mean = 0.0 + M2 = 0.0 # Sum of squared differences from mean + + for i, image_path in enumerate(image_list): + # Handle downsampling + if down_dir is not None: + make_dir(down_dir) + img_name = os.path.basename(image_path) + res_img = os.path.join(down_dir, img_name) + if not os.path.exists(res_img): + apply_down_sample(image_path, res_img, down_factor) + image_path = res_img + + if not os.path.exists(image_path): + raise DataLoadingError(f"Image file not found: {image_path}") + + try: + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") + + # Welford's online algorithm for each pixel value + for val in img.flat: + n += 1 + delta = val - mean + mean += delta / n + delta2 = val - mean + M2 += delta * delta2 + + # Free memory + del img + + if (i + 1) % 10 == 0: + sw_message(f" Computing stats: {i + 1}/{len(image_list)} images processed") + + if n < 2: + raise DataLoadingError("Need at least 2 pixel values to compute statistics") + + variance = M2 / n + std = np.sqrt(variance) + + return mean, std + + +''' +Prepare image paths, applying downsampling if needed. +Returns list of paths to use (either original or downsampled). 
+''' +def prepare_image_paths(image_list, down_factor=1, down_dir=None): + if not image_list: + raise DataLoadingError("Image list is empty") + + prepared_paths = [] + for image_path in image_list: + if down_dir is not None: + make_dir(down_dir) + img_name = os.path.basename(image_path) + res_img = os.path.join(down_dir, img_name) + if not os.path.exists(res_img): + apply_down_sample(image_path, res_img, down_factor) + image_path = res_img + + if not os.path.exists(image_path): + raise DataLoadingError(f"Image file not found: {image_path}") + + prepared_paths.append(image_path) + + return prepared_paths + + +''' +reads .nrrd files and returns whitened data (legacy - loads all into memory) ''' def get_images(loader_dir, image_list, down_factor, down_dir): if not image_list: diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py index 7d684ee62d..51d9514368 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py @@ -6,6 +6,7 @@ from collections import OrderedDict from DeepSSMUtils import net_utils from DeepSSMUtils import constants as C +from DeepSSMUtils import loaders class ConvolutionalBackbone(nn.Module): @@ -106,10 +107,9 @@ def __init__(self, config_file): parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + C.VALIDATION_LOADER, weights_only=False) - self.num_corr = loader.dataset.mdl_target[0].shape[0] - img_dims = loader.dataset.img[0].shape - self.img_dims = img_dims[1:] + loader_info = loaders.get_loader_info(self.loader_dir + C.VALIDATION_LOADER) + self.num_corr = loader_info['num_corr'] + self.img_dims = loader_info['img_dims'] # encoder if parameters['encoder']['deterministic']: self.encoder = DeterministicEncoder(self.num_latent, self.img_dims, self.loader_dir ) @@ -178,10 +178,9 @@ def __init__(self, conflict_file): parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + C.VALIDATION_LOADER, weights_only=False) - self.num_corr = loader.dataset.mdl_target[0].shape[0] - img_dims = loader.dataset.img[0].shape - self.img_dims = img_dims[1:] + loader_info = loaders.get_loader_info(self.loader_dir + C.VALIDATION_LOADER) + self.num_corr = loader_info['num_corr'] + self.img_dims = loader_info['img_dims'] self.CorrespondenceEncoder = CorrespondenceEncoder(self.num_latent, self.num_corr) self.CorrespondenceDecoder = CorrespondenceDecoder(self.num_latent, self.num_corr) self.ImageEncoder = DeterministicEncoder(self.num_latent, self.img_dims, self.loader_dir) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index 1dd9fcc575..0151710b81 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -108,8 +108,8 @@ def supervised_train(config_file): train_loader_path = loader_dir + C.TRAIN_LOADER validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") - train_loader = torch.load(train_loader_path, weights_only=False) - val_loader = torch.load(validation_loader_path, weights_only=False) + train_loader = loaders.load_data_loader(train_loader_path, loader_type='train') + val_loader = loaders.load_data_loader(validation_loader_path, loader_type='validation') print("Done.") # 
initializations num_pca = train_loader.dataset.pca_target[0].shape[0] @@ -418,8 +418,8 @@ def supervised_train_tl(config_file): train_loader_path = loader_dir + C.TRAIN_LOADER validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") - train_loader = torch.load(train_loader_path, weights_only=False) - val_loader = torch.load(validation_loader_path, weights_only=False) + train_loader = loaders.load_data_loader(train_loader_path, loader_type='train') + val_loader = loaders.load_data_loader(validation_loader_path, loader_type='validation') print("Done.") print("Defining model...") net = model.DeepSSMNet_TLNet(config_file) From ab35dfd85814b5f167329b7819434a92990defcf Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 14 Jan 2026 13:11:09 -0700 Subject: [PATCH 16/47] Fix bounding box calculation and add error handling in run_utils - Use world particle positions for bounding box calculation instead of transformed groomed meshes. World particles reflect actual aligned positions including optimization transforms. - Add periodic garbage collection during training image grooming - Add try/except around validation/test image registration to continue processing even if individual subjects fail - Skip missing validation/test images gracefully with warnings - Skip test subjects without predictions during post-processing --- .../DeepSSMUtils/run_utils.py | 242 +++++++++++------- 1 file changed, 155 insertions(+), 87 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index 723795d882..b969fe2d99 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -1,6 +1,7 @@ import random import math import os +import gc import numpy as np import json @@ -155,21 +156,33 @@ def get_training_indices(project): def get_training_bounding_box(project): - """ Get the bounding box of the training subjects. """ + """ Get the bounding box of the training subjects. + + Uses world particle positions to compute the bounding box. This ensures + consistency with the actual aligned particle positions used during training, + which may include additional transforms applied during optimization that + aren't captured by get_groomed_transforms() alone. 
+ """ subjects = project.get_subjects() training_indices = get_training_indices(project) - training_bounding_box = None - train_mesh_list = [] + + # Compute bounding box from world particles + min_pt = np.array([np.inf, np.inf, np.inf]) + max_pt = np.array([-np.inf, -np.inf, -np.inf]) + for i in training_indices: subject = subjects[i] - mesh = subject.get_groomed_clipped_mesh() - # apply transform - alignment = convert_transform_to_numpy(subject.get_groomed_transforms()[0]) - mesh.applyTransform(alignment) - train_mesh_list.append(mesh) + world_particle_files = subject.get_world_particle_filenames() + if world_particle_files: + particles = np.loadtxt(world_particle_files[0]) + min_pt = np.minimum(min_pt, particles.min(axis=0)) + max_pt = np.maximum(max_pt, particles.max(axis=0)) + + # Create bounding box from particle extents + # PhysicalRegion takes two sequences: min point and max point + bounding_box = sw.PhysicalRegion(min_pt.tolist(), max_pt.tolist()) - bounding_box = sw.MeshUtils.boundingBox(train_mesh_list).pad(10) - return bounding_box + return bounding_box.pad(10) def convert_transform_to_numpy(transform): @@ -229,14 +242,15 @@ def groom_training_images(project): f.write(bounding_box_string) sw_message("Grooming training images") - for i in get_training_indices(project): + training_indices = get_training_indices(project) + for count, i in enumerate(training_indices): if sw_check_abort(): sw_message("Aborted") return image_name = sw.utils.get_image_filename(subjects[i]) - sw_progress(i / (len(subjects) + 1), f"Grooming Training Image: {image_name}") + sw_progress(count / (len(training_indices) + 1), f"Grooming Training Image: {image_name}") image = sw.Image(image_name) subject = subjects[i] # get alignment transform @@ -257,6 +271,15 @@ def groom_training_images(project): # write image using the index of the subject image.write(deepssm_dir + f"/train_images/{i}.nrrd") + # Explicitly delete the image and run garbage collection periodically + # to prevent memory accumulation + del image + if count % 50 == 0: + gc.collect() + + # Final cleanup after processing all training images + gc.collect() + def run_data_augmentation(project, num_samples, num_dim, percent_variability, sampler, mixture_num=0, processes=1): """ Run data augmentation on the training images. """ @@ -362,86 +385,105 @@ def groom_val_test_images(project, indices): val_test_transforms = [] val_test_image_files = [] + failed_indices = [] - count = 1 - for i in val_test_indices: + for count, i in enumerate(val_test_indices): if sw_check_abort(): sw_message("Aborted") return image_name = sw.utils.get_image_filename(subjects[i]) sw_progress(count / (len(val_test_indices) + 1), - f"Grooming val/test image {image_name} ({count}/{len(val_test_indices)})") - count = count + 1 - image = sw.Image(image_name) + f"Grooming val/test image {image_name} ({count + 1}/{len(val_test_indices)})") + + try: + image = sw.Image(image_name) + + image_file = val_test_images_dir + f"{i}.nrrd" + + # check if this subject needs reflection + needs_reflection, axis = does_subject_need_reflection(project, subjects[i]) + + # 1. Apply reflection + reflection = np.eye(4) + if needs_reflection: + reflection[axis, axis] = -1 + # account for offset + reflection[-1][0] = 2 * image.center()[0] + + image.applyTransform(reflection) + transform = sw.utils.getVTKtransform(reflection) + + # 2. 
Translate to have ref center to make rigid registration easier + translation = ref_center - image.center() + image.setOrigin(image.origin() + translation).write(image_file) + transform[:3, -1] += translation + + # 3. Translate with respect to slightly cropped ref + image = sw.Image(image_file).fitRegion(large_bb).write(image_file) + itk_translation_transform = DeepSSMUtils.get_image_registration_transform(large_cropped_ref_image_file, + image_file, + transform_type='translation') + # 4. Apply transform + image.applyTransform(itk_translation_transform, + large_cropped_ref_image.origin(), large_cropped_ref_image.dims(), + large_cropped_ref_image.spacing(), large_cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + vtk_translation_transform = sw.utils.getVTKtransform(itk_translation_transform) + transform = np.matmul(vtk_translation_transform, transform) + + # 5. Crop with medium bounding box and find rigid transform + image.fitRegion(medium_bb).write(image_file) + itk_rigid_transform = DeepSSMUtils.get_image_registration_transform(medium_cropped_ref_image_file, + image_file, transform_type='rigid') + + # 6. Apply transform + image.applyTransform(itk_rigid_transform, + medium_cropped_ref_image.origin(), medium_cropped_ref_image.dims(), + medium_cropped_ref_image.spacing(), medium_cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + vtk_rigid_transform = sw.utils.getVTKtransform(itk_rigid_transform) + transform = np.matmul(vtk_rigid_transform, transform) + + # 7. Get similarity transform from image registration and apply + image.fitRegion(bounding_box).write(image_file) + itk_similarity_transform = DeepSSMUtils.get_image_registration_transform(cropped_ref_image_file, + image_file, + transform_type='similarity') + image.applyTransform(itk_similarity_transform, + cropped_ref_image.origin(), cropped_ref_image.dims(), + cropped_ref_image.spacing(), cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + image.write(image_file) + vtk_similarity_transform = sw.utils.getVTKtransform(itk_similarity_transform) + transform = np.matmul(vtk_similarity_transform, transform) + + # 8. Save transform + val_test_transforms.append(transform) + extra_values = subjects[i].get_extra_values() + extra_values["registration_transform"] = transform_to_string(transform) - image_file = val_test_images_dir + f"{i}.nrrd" - - # check if this subject needs reflection - needs_reflection, axis = does_subject_need_reflection(project, subjects[i]) - - # 1. Apply reflection - reflection = np.eye(4) - if needs_reflection: - reflection[axis, axis] = -1 - # account for offset - reflection[-1][0] = 2 * image.center()[0] - - image.applyTransform(reflection) - transform = sw.utils.getVTKtransform(reflection) - - # 2. Translate to have ref center to make rigid registration easier - translation = ref_center - image.center() - image.setOrigin(image.origin() + translation).write(image_file) - transform[:3, -1] += translation - - # 3. Translate with respect to slightly cropped ref - image = sw.Image(image_file).fitRegion(large_bb).write(image_file) - itk_translation_transform = DeepSSMUtils.get_image_registration_transform(large_cropped_ref_image_file, - image_file, - transform_type='translation') - # 4. 
Apply transform - image.applyTransform(itk_translation_transform, - large_cropped_ref_image.origin(), large_cropped_ref_image.dims(), - large_cropped_ref_image.spacing(), large_cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - vtk_translation_transform = sw.utils.getVTKtransform(itk_translation_transform) - transform = np.matmul(vtk_translation_transform, transform) - - # 5. Crop with medium bounding box and find rigid transform - image.fitRegion(medium_bb).write(image_file) - itk_rigid_transform = DeepSSMUtils.get_image_registration_transform(medium_cropped_ref_image_file, - image_file, transform_type='rigid') - - # 6. Apply transform - image.applyTransform(itk_rigid_transform, - medium_cropped_ref_image.origin(), medium_cropped_ref_image.dims(), - medium_cropped_ref_image.spacing(), medium_cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - vtk_rigid_transform = sw.utils.getVTKtransform(itk_rigid_transform) - transform = np.matmul(vtk_rigid_transform, transform) - - # 7. Get similarity transform from image registration and apply - image.fitRegion(bounding_box).write(image_file) - itk_similarity_transform = DeepSSMUtils.get_image_registration_transform(cropped_ref_image_file, - image_file, - transform_type='similarity') - image.applyTransform(itk_similarity_transform, - cropped_ref_image.origin(), cropped_ref_image.dims(), - cropped_ref_image.spacing(), cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - image.write(image_file) - vtk_similarity_transform = sw.utils.getVTKtransform(itk_similarity_transform) - transform = np.matmul(vtk_similarity_transform, transform) - - # 8. Save transform - val_test_transforms.append(transform) - extra_values = subjects[i].get_extra_values() - extra_values["registration_transform"] = transform_to_string(transform) + subjects[i].set_extra_values(extra_values) - subjects[i].set_extra_values(extra_values) + # Explicitly delete image and run garbage collection periodically + del image + except Exception as e: + sw_message(f"Warning: Failed to process val/test image for subject {i}: {e}") + failed_indices.append(i) + # Clean up partial file if it exists + if os.path.exists(val_test_images_dir + f"{i}.nrrd"): + os.remove(val_test_images_dir + f"{i}.nrrd") + + if count % 20 == 0: + gc.collect() + + # Final cleanup + gc.collect() project.set_subjects(subjects) + if failed_indices: + sw_message(f"Warning: {len(failed_indices)} val/test images failed to process: {failed_indices}") + def prepare_data_loaders(project, batch_size, split="all", num_workers=0): """ Prepare PyTorch laoders """ @@ -454,10 +496,17 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): val_image_files = [] val_world_particles = [] val_indices = get_split_indices(project, "val") + skipped_val = [] for i in val_indices: - val_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd") - particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] - val_world_particles.append(particle_file) + image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" + if os.path.exists(image_file): + val_image_files.append(image_file) + particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] + val_world_particles.append(particle_file) + else: + skipped_val.append(i) + if skipped_val: + sw_message(f"Warning: Skipping {len(skipped_val)} missing validation images: {skipped_val}") DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, 
val_world_particles, num_workers=num_workers) if split == "all" or split == "train": @@ -468,8 +517,15 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): if split == "all" or split == "test": test_image_files = [] test_indices = get_split_indices(project, "test") + skipped_test = [] for i in test_indices: - test_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd") + image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" + if os.path.exists(image_file): + test_image_files.append(image_file) + else: + skipped_test.append(i) + if skipped_test: + sw_message(f"Warning: Skipping {len(skipped_test)} missing test images: {skipped_test}") DeepSSMUtils.getTestLoader(loader_dir, test_image_files, num_workers=num_workers) @@ -508,16 +564,25 @@ def process_test_predictions(project, config_file): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] - test_indices = get_split_indices(project, "test") + all_test_indices = get_split_indices(project, "test") predicted_test_local_particles = [] predicted_test_world_particles = [] test_transforms = [] test_mesh_files = [] + test_indices = [] # Only indices with valid predictions + skipped_indices = [] - for index in test_indices: + for index in all_test_indices: world_particle_file = f"{world_predictions_dir}/{index}.particles" + + # Skip subjects that don't have predictions (e.g., failed during image grooming) + if not os.path.exists(world_particle_file): + skipped_indices.append(index) + continue + print(f"world_particle_file: {world_particle_file}") + test_indices.append(index) predicted_test_world_particles.append(world_particle_file) transform = get_test_alignment_transform(project, index) @@ -534,6 +599,9 @@ def process_test_predictions(project, config_file): np.savetxt(local_particle_file, local_particles) predicted_test_local_particles.append(local_particle_file) + if skipped_indices: + sw_message(f"Warning: Skipping {len(skipped_indices)} test subjects without predictions: {skipped_indices}") + distances = eval_utils.get_mesh_distances(predicted_test_local_particles, test_mesh_files, template_particles, template_mesh, pred_dir) From 36a681011e67dd1c09c34291c1fd02ce7ab8296a Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 4 Feb 2026 10:18:16 -0700 Subject: [PATCH 17/47] Fail with clear errors instead of silently skipping missing files --- .../DeepSSMUtils/run_utils.py | 185 ++++++++---------- 1 file changed, 79 insertions(+), 106 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index b969fe2d99..77d1834fa7 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -385,7 +385,6 @@ def groom_val_test_images(project, indices): val_test_transforms = [] val_test_image_files = [] - failed_indices = [] for count, i in enumerate(val_test_indices): if sw_check_abort(): @@ -396,83 +395,76 @@ def groom_val_test_images(project, indices): sw_progress(count / (len(val_test_indices) + 1), f"Grooming val/test image {image_name} ({count + 1}/{len(val_test_indices)})") - try: - image = sw.Image(image_name) - - image_file = val_test_images_dir + f"{i}.nrrd" - - # check if this subject needs reflection - needs_reflection, axis = does_subject_need_reflection(project, subjects[i]) - - # 1. 
Apply reflection - reflection = np.eye(4) - if needs_reflection: - reflection[axis, axis] = -1 - # account for offset - reflection[-1][0] = 2 * image.center()[0] - - image.applyTransform(reflection) - transform = sw.utils.getVTKtransform(reflection) - - # 2. Translate to have ref center to make rigid registration easier - translation = ref_center - image.center() - image.setOrigin(image.origin() + translation).write(image_file) - transform[:3, -1] += translation - - # 3. Translate with respect to slightly cropped ref - image = sw.Image(image_file).fitRegion(large_bb).write(image_file) - itk_translation_transform = DeepSSMUtils.get_image_registration_transform(large_cropped_ref_image_file, - image_file, - transform_type='translation') - # 4. Apply transform - image.applyTransform(itk_translation_transform, - large_cropped_ref_image.origin(), large_cropped_ref_image.dims(), - large_cropped_ref_image.spacing(), large_cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - vtk_translation_transform = sw.utils.getVTKtransform(itk_translation_transform) - transform = np.matmul(vtk_translation_transform, transform) - - # 5. Crop with medium bounding box and find rigid transform - image.fitRegion(medium_bb).write(image_file) - itk_rigid_transform = DeepSSMUtils.get_image_registration_transform(medium_cropped_ref_image_file, - image_file, transform_type='rigid') - - # 6. Apply transform - image.applyTransform(itk_rigid_transform, - medium_cropped_ref_image.origin(), medium_cropped_ref_image.dims(), - medium_cropped_ref_image.spacing(), medium_cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - vtk_rigid_transform = sw.utils.getVTKtransform(itk_rigid_transform) - transform = np.matmul(vtk_rigid_transform, transform) - - # 7. Get similarity transform from image registration and apply - image.fitRegion(bounding_box).write(image_file) - itk_similarity_transform = DeepSSMUtils.get_image_registration_transform(cropped_ref_image_file, - image_file, - transform_type='similarity') - image.applyTransform(itk_similarity_transform, - cropped_ref_image.origin(), cropped_ref_image.dims(), - cropped_ref_image.spacing(), cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - image.write(image_file) - vtk_similarity_transform = sw.utils.getVTKtransform(itk_similarity_transform) - transform = np.matmul(vtk_similarity_transform, transform) - - # 8. Save transform - val_test_transforms.append(transform) - extra_values = subjects[i].get_extra_values() - extra_values["registration_transform"] = transform_to_string(transform) + image = sw.Image(image_name) - subjects[i].set_extra_values(extra_values) + image_file = val_test_images_dir + f"{i}.nrrd" + + # check if this subject needs reflection + needs_reflection, axis = does_subject_need_reflection(project, subjects[i]) + + # 1. Apply reflection + reflection = np.eye(4) + if needs_reflection: + reflection[axis, axis] = -1 + # account for offset + reflection[-1][0] = 2 * image.center()[0] + + image.applyTransform(reflection) + transform = sw.utils.getVTKtransform(reflection) + + # 2. Translate to have ref center to make rigid registration easier + translation = ref_center - image.center() + image.setOrigin(image.origin() + translation).write(image_file) + transform[:3, -1] += translation + + # 3. 
Translate with respect to slightly cropped ref + image = sw.Image(image_file).fitRegion(large_bb).write(image_file) + itk_translation_transform = DeepSSMUtils.get_image_registration_transform(large_cropped_ref_image_file, + image_file, + transform_type='translation') + # 4. Apply transform + image.applyTransform(itk_translation_transform, + large_cropped_ref_image.origin(), large_cropped_ref_image.dims(), + large_cropped_ref_image.spacing(), large_cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + vtk_translation_transform = sw.utils.getVTKtransform(itk_translation_transform) + transform = np.matmul(vtk_translation_transform, transform) + + # 5. Crop with medium bounding box and find rigid transform + image.fitRegion(medium_bb).write(image_file) + itk_rigid_transform = DeepSSMUtils.get_image_registration_transform(medium_cropped_ref_image_file, + image_file, transform_type='rigid') + + # 6. Apply transform + image.applyTransform(itk_rigid_transform, + medium_cropped_ref_image.origin(), medium_cropped_ref_image.dims(), + medium_cropped_ref_image.spacing(), medium_cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + vtk_rigid_transform = sw.utils.getVTKtransform(itk_rigid_transform) + transform = np.matmul(vtk_rigid_transform, transform) + + # 7. Get similarity transform from image registration and apply + image.fitRegion(bounding_box).write(image_file) + itk_similarity_transform = DeepSSMUtils.get_image_registration_transform(cropped_ref_image_file, + image_file, + transform_type='similarity') + image.applyTransform(itk_similarity_transform, + cropped_ref_image.origin(), cropped_ref_image.dims(), + cropped_ref_image.spacing(), cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + image.write(image_file) + vtk_similarity_transform = sw.utils.getVTKtransform(itk_similarity_transform) + transform = np.matmul(vtk_similarity_transform, transform) + + # 8. 
Save transform + val_test_transforms.append(transform) + extra_values = subjects[i].get_extra_values() + extra_values["registration_transform"] = transform_to_string(transform) + + subjects[i].set_extra_values(extra_values) - # Explicitly delete image and run garbage collection periodically - del image - except Exception as e: - sw_message(f"Warning: Failed to process val/test image for subject {i}: {e}") - failed_indices.append(i) - # Clean up partial file if it exists - if os.path.exists(val_test_images_dir + f"{i}.nrrd"): - os.remove(val_test_images_dir + f"{i}.nrrd") + # Explicitly delete image and run garbage collection periodically + del image if count % 20 == 0: gc.collect() @@ -481,9 +473,6 @@ def groom_val_test_images(project, indices): gc.collect() project.set_subjects(subjects) - if failed_indices: - sw_message(f"Warning: {len(failed_indices)} val/test images failed to process: {failed_indices}") - def prepare_data_loaders(project, batch_size, split="all", num_workers=0): """ Prepare PyTorch laoders """ @@ -496,17 +485,13 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): val_image_files = [] val_world_particles = [] val_indices = get_split_indices(project, "val") - skipped_val = [] for i in val_indices: image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" - if os.path.exists(image_file): - val_image_files.append(image_file) - particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] - val_world_particles.append(particle_file) - else: - skipped_val.append(i) - if skipped_val: - sw_message(f"Warning: Skipping {len(skipped_val)} missing validation images: {skipped_val}") + if not os.path.exists(image_file): + raise FileNotFoundError(f"Missing validation image for subject {i}: {image_file}") + val_image_files.append(image_file) + particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] + val_world_particles.append(particle_file) DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles, num_workers=num_workers) if split == "all" or split == "train": @@ -517,15 +502,11 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): if split == "all" or split == "test": test_image_files = [] test_indices = get_split_indices(project, "test") - skipped_test = [] for i in test_indices: image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" - if os.path.exists(image_file): - test_image_files.append(image_file) - else: - skipped_test.append(i) - if skipped_test: - sw_message(f"Warning: Skipping {len(skipped_test)} missing test images: {skipped_test}") + if not os.path.exists(image_file): + raise FileNotFoundError(f"Missing test image for subject {i}: {image_file}") + test_image_files.append(image_file) DeepSSMUtils.getTestLoader(loader_dir, test_image_files, num_workers=num_workers) @@ -564,25 +545,20 @@ def process_test_predictions(project, config_file): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] - all_test_indices = get_split_indices(project, "test") + test_indices = get_split_indices(project, "test") predicted_test_local_particles = [] predicted_test_world_particles = [] test_transforms = [] test_mesh_files = [] - test_indices = [] # Only indices with valid predictions - skipped_indices = [] - for index in all_test_indices: + for index in test_indices: world_particle_file = f"{world_predictions_dir}/{index}.particles" - # Skip subjects 
that don't have predictions (e.g., failed during image grooming) if not os.path.exists(world_particle_file): - skipped_indices.append(index) - continue + raise FileNotFoundError(f"Missing prediction for test subject {index}: {world_particle_file}") print(f"world_particle_file: {world_particle_file}") - test_indices.append(index) predicted_test_world_particles.append(world_particle_file) transform = get_test_alignment_transform(project, index) @@ -599,9 +575,6 @@ def process_test_predictions(project, config_file): np.savetxt(local_particle_file, local_particles) predicted_test_local_particles.append(local_particle_file) - if skipped_indices: - sw_message(f"Warning: Skipping {len(skipped_indices)} test subjects without predictions: {skipped_indices}") - distances = eval_utils.get_mesh_distances(predicted_test_local_particles, test_mesh_files, template_particles, template_mesh, pred_dir) From ca9c7a3d01203b95f12a77a8879862d779d3965d Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Thu, 5 Feb 2026 12:16:08 -0700 Subject: [PATCH 18/47] Reduce DeepSSM tests from 4 to 2 configurations Run only default and tl_net_fine_tune tests, which together cover all code paths (standard DeepSSM, TL-DeepSSM, and fine tuning). Cuts test time from ~3 minutes to ~90 seconds. --- Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py | 6 +++--- Testing/DeepSSMTests/DeepSSMTests.cpp | 7 +++---- Testing/DeepSSMTests/deepssm_default.sh | 2 +- Testing/DeepSSMTests/deepssm_fine_tune.sh | 2 +- Testing/DeepSSMTests/deepssm_tl_net.sh | 2 +- Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh | 2 +- Testing/DeepSSMTests/run_exact_check.sh | 2 +- Testing/DeepSSMTests/run_extended_tests.sh | 2 +- 8 files changed, 12 insertions(+), 13 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py index 86e12fc03f..638158a577 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py @@ -2,15 +2,15 @@ import SimpleITK import numpy as np -def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid'): - # Prepare parameter map +def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid', max_iterations=1024): + # Prepare parameter map parameter_object = itk.ParameterObject.New() parameter_map = parameter_object.GetDefaultParameterMap('rigid') if transform_type == 'similarity': parameter_map['Transform'] = ['SimilarityTransform'] elif transform_type == 'translation': parameter_map['Transform'] = ['TranslationTransform'] - parameter_map['MaximumNumberOfIterations'] = ['1024'] + parameter_map['MaximumNumberOfIterations'] = [str(max_iterations)] parameter_object.AddParameterMap(parameter_map) # Load images diff --git a/Testing/DeepSSMTests/DeepSSMTests.cpp b/Testing/DeepSSMTests/DeepSSMTests.cpp index 05f12e8299..6783e325b9 100644 --- a/Testing/DeepSSMTests/DeepSSMTests.cpp +++ b/Testing/DeepSSMTests/DeepSSMTests.cpp @@ -11,10 +11,9 @@ void run_deepssm_test(const std::string& name) { } //--------------------------------------------------------------------------- +// Run 2 configurations that cover all code paths: +// - default: standard DeepSSM +// - tl_net_fine_tune: TL-DeepSSM with fine tuning (covers both tl_net and fine_tune paths) TEST(DeepSSMTests, defaultTest) { run_deepssm_test("deepssm_default.sh"); } -TEST(DeepSSMTests, tlNetTest) { run_deepssm_test("deepssm_tl_net.sh"); } - -TEST(DeepSSMTests, fineTuneTest) { 
run_deepssm_test("deepssm_fine_tune.sh"); } - TEST(DeepSSMTests, tlNetFineTuneTest) { run_deepssm_test("deepssm_tl_net_fine_tune.sh"); } diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh index c8a7305829..fcdac3e31c 100755 --- a/Testing/DeepSSMTests/deepssm_default.sh +++ b/Testing/DeepSSMTests/deepssm_default.sh @@ -14,4 +14,4 @@ rm -rf deepssm groomed *_particles shapeworks deepssm --name default.swproj --all # Verify results -python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh index 5b991a3f84..c0e96b800a 100755 --- a/Testing/DeepSSMTests/deepssm_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh @@ -14,4 +14,4 @@ rm -rf deepssm groomed *_particles shapeworks deepssm --name fine_tune.swproj --all # Verify results -python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh index f246158782..2ed22c47c1 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net.sh @@ -14,4 +14,4 @@ rm -rf deepssm groomed *_particles shapeworks deepssm --name tl_net.swproj --all # Verify results -python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh index 70ea18f1f8..9a2d154e1a 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh @@ -14,4 +14,4 @@ rm -rf deepssm groomed *_particles shapeworks deepssm --name tl_net_fine_tune.swproj --all # Verify results -python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/run_exact_check.sh b/Testing/DeepSSMTests/run_exact_check.sh index d4bd7ad54c..e31cb61697 100755 --- a/Testing/DeepSSMTests/run_exact_check.sh +++ b/Testing/DeepSSMTests/run_exact_check.sh @@ -33,7 +33,7 @@ for config in $CONFIGS; do shapeworks deepssm --name ${config}.swproj --all # Run exact check with config-specific file - python "${SCRIPT_DIR}/verify_deepssm_results.py" . \ + python3 "${SCRIPT_DIR}/verify_deepssm_results.py" . \ --exact_check "$MODE" \ --baseline_file "exact_check_${config}.txt" diff --git a/Testing/DeepSSMTests/run_extended_tests.sh b/Testing/DeepSSMTests/run_extended_tests.sh index 1fa3f3afa7..46e96e6e6c 100755 --- a/Testing/DeepSSMTests/run_extended_tests.sh +++ b/Testing/DeepSSMTests/run_extended_tests.sh @@ -99,7 +99,7 @@ run_project() { verify_args="--expected 10 --tolerance 1.0" fi - python "${SCRIPT_DIR}/verify_deepssm_results.py" . $verify_args + python3 "${SCRIPT_DIR}/verify_deepssm_results.py" . 
$verify_args echo "" } From fb5e2d7f85547706d436093e86cc289e174f674e Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Thu, 5 Feb 2026 12:45:46 -0700 Subject: [PATCH 19/47] Resolve #2487 - Auto subset size in grooming should pick a smart auto auto (-1) defaults to a subset of 30 to avoid O(n^2) pairwise ICP on large datasets --- Libs/Groom/Groom.cpp | 2 +- Libs/Mesh/MeshUtils.cpp | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Libs/Groom/Groom.cpp b/Libs/Groom/Groom.cpp index e5b2a4ac55..0b0f24ab9a 100644 --- a/Libs/Groom/Groom.cpp +++ b/Libs/Groom/Groom.cpp @@ -560,7 +560,7 @@ bool Groom::run_alignment() { bool any_alignment = false; int reference_index = -1; - int subset_size = -1; + int subset_size = base_params.get_alignment_subset_size(); // per-domain alignment for (size_t domain = 0; domain < num_domains; domain++) { diff --git a/Libs/Mesh/MeshUtils.cpp b/Libs/Mesh/MeshUtils.cpp index 2468230ecb..fdf69958b1 100644 --- a/Libs/Mesh/MeshUtils.cpp +++ b/Libs/Mesh/MeshUtils.cpp @@ -182,6 +182,10 @@ PhysicalRegion MeshUtils::boundingBox(const std::vector& meshes, bool cent } int MeshUtils::findReferenceMesh(std::vector& meshes, int random_subset_size) { + // auto (-1) defaults to a subset of 30 to avoid O(n^2) pairwise ICP on large datasets + if (random_subset_size < 0) { + random_subset_size = 30; + } bool use_random_subset = random_subset_size > 0 && random_subset_size < meshes.size(); int num_meshes = use_random_subset ? random_subset_size : meshes.size(); From c77de4ce3b0a2da2b97a921828ba8986fd090444 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 15:39:27 -0700 Subject: [PATCH 20/47] Refactor DeepSSM: add constants module and reproducible seeding * Add constants.py to centralize magic strings (file names, loader names, device strings) for improved maintainability * Add set_seed() function in net_utils.py for reproducible training by seeding Python random, NumPy, PyTorch CPU/CUDA, and cuDNN * Update loaders.py, trainer.py, model.py, eval.py to use constants * Export constants and set_seed from __init__.py Verified: test outputs are identical before and after refactoring. 
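
For reference only (not part of this patch), a minimal sketch of how the new
seeding helper and constants module are expected to be used together before
training; the loader and model paths below are hypothetical examples:

    import torch
    from DeepSSMUtils import constants as C
    from DeepSSMUtils.net_utils import set_seed

    set_seed(42)  # seeds Python random, NumPy, PyTorch CPU/CUDA, and cuDNN

    loader_dir = "deepssm/torch_loaders/"   # hypothetical path
    train_loader = torch.load(loader_dir + C.TRAIN_LOADER, weights_only=False)
    model_path = "deepssm/model/" + C.BEST_MODEL_FILE  # hypothetical path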
--- .../DeepSSMUtils/__init__.py | 4 + .../DeepSSMUtils/constants.py | 73 +++++++++++++++++++ .../DeepSSMUtilsPackage/DeepSSMUtils/eval.py | 9 ++- .../DeepSSMUtils/loaders.py | 27 +++---- .../DeepSSMUtilsPackage/DeepSSMUtils/model.py | 21 +++--- .../DeepSSMUtils/net_utils.py | 21 +++++- .../DeepSSMUtils/trainer.py | 53 +++++++------- 7 files changed, 154 insertions(+), 54 deletions(-) create mode 100644 Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index 135f4f0a05..738b914861 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -6,6 +6,10 @@ from DeepSSMUtils import train_viz from DeepSSMUtils import image_utils from DeepSSMUtils import run_utils +from DeepSSMUtils import net_utils +from DeepSSMUtils import constants + +from .net_utils import set_seed from .run_utils import create_split, groom_training_shapes, groom_training_images, \ run_data_augmentation, groom_val_test_images, prep_project_for_val_particles, groom_validation_shapes, \ diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py new file mode 100644 index 0000000000..db912adcf6 --- /dev/null +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py @@ -0,0 +1,73 @@ +""" +Constants used throughout DeepSSM. + +This module centralizes magic strings and default values to improve +maintainability and reduce errors from typos. +""" + +# Model file names +BEST_MODEL_FILE = "best_model.torch" +FINAL_MODEL_FILE = "final_model.torch" +BEST_MODEL_FT_FILE = "best_model_ft.torch" +FINAL_MODEL_FT_FILE = "final_model_ft.torch" +FINAL_MODEL_AE_FILE = "final_model_ae.torch" +FINAL_MODEL_TF_FILE = "final_model_tf.torch" + +# Data loader names +TRAIN_LOADER = "train" +VALIDATION_LOADER = "validation" +TEST_LOADER = "test" + +# File names for saved statistics +MEAN_PCA_FILE = "mean_PCA.npy" +STD_PCA_FILE = "std_PCA.npy" +MEAN_IMG_FILE = "mean_img.npy" +STD_IMG_FILE = "std_img.npy" + +# Names files +TRAIN_NAMES_FILE = "train_names.txt" +VALIDATION_NAMES_FILE = "validation_names.txt" +TEST_NAMES_FILE = "test_names.txt" + +# Log and plot files +TRAIN_LOG_FILE = "train_log.csv" +TRAINING_PLOT_FILE = "training_plot.png" +TRAINING_PLOT_FT_FILE = "training_plot_ft.png" +TRAINING_PLOT_AE_FILE = "training_plot_ae.png" +TRAINING_PLOT_TF_FILE = "training_plot_tf.png" +TRAINING_PLOT_JOINT_FILE = "training_plot_joint.png" + +# PCA info directory and files +PCA_INFO_DIR = "PCA_Particle_Info" +PCA_MEAN_FILE = "mean.particles" +PCA_MODE_FILE_TEMPLATE = "pcamode{}.particles" + +# Prediction directories +WORLD_PREDICTIONS_DIR = "world_predictions" +PCA_PREDICTIONS_DIR = "pca_predictions" +LOCAL_PREDICTIONS_DIR = "local_predictions" + +# Examples directory +EXAMPLES_DIR = "examples" +TRAIN_EXAMPLES_PREFIX = "train_" +VALIDATION_EXAMPLES_PREFIX = "validation_" + +# Training stage names (for logging) +class TrainingStage: + BASE = "Base_Training" + FINE_TUNING = "Fine_Tuning" + AUTOENCODER = "AE" + T_FLANK = "T-Flank" + JOINT = "Joint" + +# Default values +class Defaults: + BATCH_SIZE = 1 + DOWN_FACTOR = 1 + TRAIN_SPLIT = 0.80 + NUM_WORKERS = 0 + VAL_FREQ = 1 + +# Device strings +DEVICE_CUDA = "cuda:0" +DEVICE_CPU = "cpu" diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py index 5f7fe30e36..90b4fbd0c1 100644 --- 
a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py @@ -6,6 +6,7 @@ import torch from torch.utils.data import DataLoader from DeepSSMUtils import model, loaders +from DeepSSMUtils import constants as C from shapeworks.utils import sw_message from shapeworks.utils import sw_progress from shapeworks.utils import sw_check_abort @@ -24,9 +25,9 @@ def test(config_file, loader="test"): pred_dir = model_dir + loader + '_predictions/' loaders.make_dir(pred_dir) if parameters["use_best_model"]: - model_path = model_dir + 'best_model.torch' + model_path = model_dir + C.BEST_MODEL_FILE else: - model_path = model_dir + 'final_model.torch' + model_path = model_dir + C.FINAL_MODEL_FILE if parameters["fine_tune"]["enabled"]: model_path_ft = model_path.replace(".torch", "_ft.torch") else: @@ -67,9 +68,9 @@ def test(config_file, loader="test"): index = 0 pred_scores = [] - pred_path = pred_dir + 'world_predictions/' + pred_path = pred_dir + C.WORLD_PREDICTIONS_DIR + '/' loaders.make_dir(pred_path) - pred_path_pca = pred_dir + 'pca_predictions/' + pred_path_pca = pred_dir + C.PCA_PREDICTIONS_DIR + '/' loaders.make_dir(pred_path_pca) predicted_particle_files = [] diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py index 5573b7e9db..b3a1ca8ff7 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader import shapeworks as sw from shapeworks.utils import sw_message +from DeepSSMUtils import constants as C random.seed(1) ######################## Data loading functions #################################### @@ -44,7 +45,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - train_path = loader_dir + 'train' + train_path = loader_dir + C.TRAIN_LOADER torch.save(trainloader, train_path) validationloader = DataLoader( @@ -54,7 +55,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - val_path = loader_dir + 'validation' + val_path = loader_dir + C.VALIDATION_LOADER torch.save(validationloader, val_path) sw_message("Training and validation loaders complete.\n") return train_path, val_path @@ -77,7 +78,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - train_path = loader_dir + 'train' + train_path = loader_dir + C.TRAIN_LOADER torch.save(trainloader, train_path) sw_message("Training loader complete.") return train_path @@ -102,10 +103,10 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 mdl = get_particles(val_particles[index]) models.append(mdl) # Write test names to file so they are saved somewhere - name_file = open(loader_dir + 'validation_names.txt', 'w+') + name_file = open(loader_dir + C.VALIDATION_NAMES_FILE, 'w+') name_file.write(str(names)) name_file.close() - sw_message("Validation names saved to: " + loader_dir + "validation_names.txt") + sw_message("Validation names saved to: " + loader_dir + C.VALIDATION_NAMES_FILE) images = get_images(loader_dir, image_paths, down_factor, down_dir) val_data = DeepSSMdataset(images, scores, models, names) # Make loader @@ -116,7 +117,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, 
down_factor=1 num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - val_path = loader_dir + 'validation' + val_path = loader_dir + C.VALIDATION_LOADER torch.save(val_loader, val_path) sw_message("Validation loader complete.") return val_path @@ -143,10 +144,10 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num images = get_images(loader_dir, image_paths, down_factor, down_dir) test_data = DeepSSMdataset(images, scores, models, test_names) # Write test names to file so they are saved somewhere - name_file = open(loader_dir + 'test_names.txt', 'w+') + name_file = open(loader_dir + C.TEST_NAMES_FILE, 'w+') name_file.write(str(test_names)) name_file.close() - sw_message("Test names saved to: " + loader_dir + "test_names.txt") + sw_message("Test names saved to: " + loader_dir + C.TEST_NAMES_FILE) # Make loader testloader = DataLoader( test_data, @@ -155,7 +156,7 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - test_path = loader_dir + 'test' + test_path = loader_dir + C.TEST_LOADER torch.save(testloader, test_path) sw_message("Test loader complete.") return test_path, test_names @@ -268,8 +269,8 @@ def get_images(loader_dir, image_list, down_factor, down_dir): all_images = np.array(all_images) # get mean and std - mean_path = loader_dir + 'mean_img.npy' - std_path = loader_dir + 'std_img.npy' + mean_path = loader_dir + C.MEAN_IMG_FILE + std_path = loader_dir + C.STD_IMG_FILE mean_image = np.mean(all_images) std_image = np.std(all_images) np.save(mean_path, mean_image) @@ -305,8 +306,8 @@ def whiten_PCA_scores(scores, loader_dir): scores = np.array(scores) mean_score = np.mean(scores, 0) std_score = np.std(scores, 0) - np.save(loader_dir + 'mean_PCA.npy', mean_score) - np.save(loader_dir + 'std_PCA.npy', std_score) + np.save(loader_dir + C.MEAN_PCA_FILE, mean_score) + np.save(loader_dir + C.STD_PCA_FILE, std_score) norm_scores = [] for score in scores: norm_scores.append((score-mean_score)/std_score) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py index f512f2e244..a0ffb81d4b 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py @@ -5,6 +5,7 @@ import numpy as np from collections import OrderedDict from DeepSSMUtils import net_utils +from DeepSSMUtils import constants as C class ConvolutionalBackbone(nn.Module): @@ -61,9 +62,9 @@ class DeterministicEncoder(nn.Module): def __init__(self, num_latent, img_dims, loader_dir): super(DeterministicEncoder, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device self.num_latent = num_latent self.img_dims = img_dims @@ -97,15 +98,15 @@ class DeepSSMNet(nn.Module): def __init__(self, config_file): super(DeepSSMNet, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device - with open(config_file) as json_file: + with open(config_file) as json_file: parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + "validation", weights_only=False) + loader = torch.load(self.loader_dir + C.VALIDATION_LOADER, weights_only=False) self.num_corr = loader.dataset.mdl_target[0].shape[0] 
img_dims = loader.dataset.img[0].shape self.img_dims = img_dims[1:] @@ -169,15 +170,15 @@ class DeepSSMNet_TLNet(nn.Module): def __init__(self, conflict_file): super(DeepSSMNet_TLNet, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device - with open(conflict_file) as json_file: + with open(conflict_file) as json_file: parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + "validation") + loader = torch.load(self.loader_dir + C.VALIDATION_LOADER) self.num_corr = loader.dataset.mdl_target[0].shape[0] img_dims = loader.dataset.img[0].shape self.img_dims = img_dims[1:] diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py index 3ffa0a9014..792bc6ab85 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py @@ -1,6 +1,23 @@ +import random import torch from torch import nn import numpy as np +from DeepSSMUtils import constants as C + + +def set_seed(seed=42): + """ + Set random seeds for reproducibility across all random number generators. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + class Flatten(nn.Module): def forward(self, x): @@ -14,8 +31,8 @@ def poolOutDim(inDim, kernel_size, padding=0, stride=0, dilation=1): return outDim def unwhiten_PCA_scores(torch_loading, loader_dir, device): - mean_score = torch.from_numpy(np.load(loader_dir + '/mean_PCA.npy')).to(device).float() - std_score = torch.from_numpy(np.load(loader_dir + '/std_PCA.npy')).to(device).float() + mean_score = torch.from_numpy(np.load(loader_dir + '/' + C.MEAN_PCA_FILE)).to(device).float() + std_score = torch.from_numpy(np.load(loader_dir + '/' + C.STD_PCA_FILE)).to(device).float() mean_score = mean_score.unsqueeze(0).repeat(torch_loading.shape[0], 1) std_score = std_score.unsqueeze(0).repeat(torch_loading.shape[0], 1) pca_new = torch_loading*(std_score) + mean_score diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index f73e26fb34..c80a776244 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -13,6 +13,8 @@ from DeepSSMUtils import losses from DeepSSMUtils import train_viz from DeepSSMUtils import loaders +from DeepSSMUtils import net_utils +from DeepSSMUtils import constants as C import DeepSSMUtils from shapeworks.utils import * @@ -68,6 +70,7 @@ def set_scheduler(opt, sched_params): def train(project, config_file): + net_utils.set_seed(42) sw.utils.initialize_project_mesh_warper(project) with open(config_file) as json_file: @@ -101,8 +104,8 @@ def supervised_train(config_file): fine_tune = parameters['fine_tune']['enabled'] loss_func = method_to_call = getattr(losses, parameters["loss"]["function"]) # load the loaders - train_loader_path = loader_dir + "train" - validation_loader_path = loader_dir + "validation" + train_loader_path = loader_dir + C.TRAIN_LOADER + validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") train_loader = torch.load(train_loader_path, weights_only=False) 
val_loader = torch.load(validation_loader_path, weights_only=False) @@ -119,8 +122,8 @@ def supervised_train(config_file): net.apply(weight_init(module=nn.Linear, initf=nn.init.xavier_normal_)) # these lines are for the fine tuning layer initialization - whiten_mean = np.load(loader_dir + '/mean_PCA.npy') - whiten_std = np.load(loader_dir + '/std_PCA.npy') + whiten_mean = np.load(loader_dir + '/' + C.MEAN_PCA_FILE) + whiten_std = np.load(loader_dir + '/' + C.STD_PCA_FILE) orig_mean = np.loadtxt(aug_dir + '/PCA_Particle_Info/mean.particles') orig_pc = np.zeros([num_pca, num_corr * 3]) for i in range(num_pca): @@ -146,7 +149,7 @@ def supervised_train(config_file): # train print("Beginning training on device = " + device + '\n') # Initialize logger - logger = open(model_dir + "train_log.csv", "w+", buffering=1) + logger = open(model_dir + C.TRAIN_LOG_FILE, "w+", buffering=1) log_print(logger, ["Training_Stage", "Epoch", "LR", "Train_Err", "Train_Rel_Err", "Val_Err", "Val_Rel_Err", "Sec"]) # Initialize training plot train_plot = plt.figure() @@ -158,7 +161,7 @@ def supervised_train(config_file): axe.set_xlim(0, num_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_FILE, dpi=300) # initialize epochs = [] plot_train_losses = [] @@ -241,17 +244,17 @@ def supervised_train(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_FILE) # save if val_rel_err < best_val_rel_error: best_val_rel_error = val_rel_err best_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FILE)) t0 = time.time() if decay_lr: scheduler.step() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_FILE)) parameters['best_model_epochs'] = best_epoch with open(config_file, "w") as json_file: json.dump(parameters, json_file, indent=2) @@ -290,7 +293,7 @@ def supervised_train(config_file): axe.set_xlim(0, ft_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_ft.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_FT_FILE, dpi=300) epochs = [] plot_train_losses = [] plot_val_losses = [] @@ -355,7 +358,7 @@ def supervised_train(config_file): if val_rel_loss < best_ft_val_rel_error: best_ft_val_rel_error = val_rel_loss best_ft_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model_ft.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FT_FILE)) pred_particles.extend(pred_mdl.detach().cpu().numpy()) true_particles.extend(mdl.detach().cpu().numpy()) train_viz.write_examples(np.array(pred_particles), np.array(true_particles), val_names, @@ -376,12 +379,12 @@ def supervised_train(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_ft.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_FT_FILE) t0 = time.time() logger.close() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_ft.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, 
C.FINAL_MODEL_FT_FILE)) parameters['best_ft_model_epochs'] = best_ft_epoch with open(config_file, "w") as json_file: @@ -411,8 +414,8 @@ def supervised_train_tl(config_file): a_lat = parameters["tl_net"]["a_lat"] c_lat = parameters["tl_net"]["c_lat"] # load the loaders - train_loader_path = loader_dir + "train" - validation_loader_path = loader_dir + "validation" + train_loader_path = loader_dir + C.TRAIN_LOADER + validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") train_loader = torch.load(train_loader_path) val_loader = torch.load(validation_loader_path) @@ -447,7 +450,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, ae_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_ae.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_AE_FILE, dpi=300) # initialize epochs = [] plot_train_losses = [] @@ -540,10 +543,10 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_ae.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_AE_FILE) t0 = time.time() # save - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_ae.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_AE_FILE)) # fix the autoencoder and train the TL-net for param in net.CorrespondenceDecoder.parameters(): param.requires_grad = False @@ -563,7 +566,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, tf_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_tf.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_TF_FILE, dpi=300) # initialize t0 = time.time() epochs = [] @@ -650,10 +653,10 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_tf.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_TF_FILE) t0 = time.time() # save - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_tf.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_TF_FILE)) # jointly train the model joint_epochs = parameters['tl_net']['joint_epochs'] alpha = parameters['tl_net']['alpha'] @@ -673,7 +676,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, joint_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_joint.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_JOINT_FILE, dpi=300) # initialize epochs = [] plot_train_losses = [] @@ -771,19 +774,19 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_joint.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_JOINT_FILE) # save val_rel_err = val_rel_ae_err + alpha * val_rel_tf_err if val_rel_err < best_val_rel_error: best_val_rel_error = val_rel_err best_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FILE)) t0 = time.time() if decay_lr: scheduler.step() logger.close() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model.torch')) + 
torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_FILE)) parameters['best_model_epochs'] = best_epoch with open(config_file, "w") as json_file: json.dump(parameters, json_file, indent=2) From 3554d87cd970326404d863f6c53b6f0be7f092f1 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 18:18:54 -0700 Subject: [PATCH 21/47] Add type hints to DeepSSM public API functions --- .../DeepSSMUtils/__init__.py | 134 ++++++++++++++++-- .../DeepSSMUtils/net_utils.py | 49 ++++++- 2 files changed, 166 insertions(+), 17 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index 738b914861..561b798162 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -1,3 +1,5 @@ +from typing import List, Optional, Tuple, Any + from DeepSSMUtils import trainer from DeepSSMUtils import loaders from DeepSSMUtils import eval @@ -20,65 +22,171 @@ import torch -def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): +def getTrainValLoaders( + loader_dir: str, + aug_data_csv: str, + batch_size: int = 1, + down_factor: float = 1, + down_dir: Optional[str] = None, + train_split: float = 0.80, + num_workers: int = 0 +) -> None: + """Create training and validation data loaders from augmented data CSV.""" testPytorch() loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers) -def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): +def getTrainLoader( + loader_dir: str, + aug_data_csv: str, + batch_size: int = 1, + down_factor: float = 1, + down_dir: Optional[str] = None, + train_split: float = 0.80, + num_workers: int = 0 +) -> None: + """Create training data loader from augmented data CSV.""" testPytorch() loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers) -def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0): +def getValidationLoader( + loader_dir: str, + val_img_list: List[str], + val_particles: List[str], + down_factor: float = 1, + down_dir: Optional[str] = None, + num_workers: int = 0 +) -> None: + """Create validation data loader from image and particle lists.""" loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir, num_workers) -def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0): +def getTestLoader( + loader_dir: str, + test_img_list: List[str], + down_factor: float = 1, + down_dir: Optional[str] = None, + num_workers: int = 0 +) -> None: + """Create test data loader from image list.""" loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir, num_workers) -def prepareConfigFile(config_filename, model_name, embedded_dim, out_dir, loader_dir, aug_dir, epochs, learning_rate, - decay_lr, fine_tune, fine_tune_epochs, fine_tune_learning_rate): +def prepareConfigFile( + config_filename: str, + model_name: str, + embedded_dim: int, + out_dir: str, + loader_dir: str, + aug_dir: str, + epochs: int, + learning_rate: float, + decay_lr: bool, + fine_tune: bool, + fine_tune_epochs: int, + fine_tune_learning_rate: float +) -> None: + """Prepare a DeepSSM configuration file with the specified parameters.""" config_file.prepare_config_file(config_filename, 
model_name, embedded_dim, out_dir, loader_dir, aug_dir, epochs, learning_rate, decay_lr, fine_tune, fine_tune_epochs, fine_tune_learning_rate) -def trainDeepSSM(project, config_file): +def trainDeepSSM(project: Any, config_file: str) -> None: + """Train a DeepSSM model using the given project and configuration file.""" testPytorch() trainer.train(project, config_file) return -def testDeepSSM(config_file, loader="test"): +def testDeepSSM(config_file: str, loader: str = "test") -> List[str]: + """ + Test a trained DeepSSM model and return predicted particle files. + + Args: + config_file: Path to the configuration JSON file + loader: Which loader to use ("test" or "validation") + + Returns: + List of paths to predicted particle files + """ predicted_particle_files = eval.test(config_file, loader) return predicted_particle_files -def analyzeMSE(predicted_particles, true_particles): +def analyzeMSE( + predicted_particles: List[str], + true_particles: List[str] +) -> Tuple[float, float]: + """ + Analyze mean squared error between predicted and true particles. + + Returns: + Tuple of (mean_MSE, std_MSE) + """ mean_MSE, STD_MSE = eval_utils.get_MSE(predicted_particles, true_particles) return mean_MSE, STD_MSE -def analyzeMeshDistance(predicted_particles, mesh_files, template_particles, template_mesh, out_dir, planes=None): +def analyzeMeshDistance( + predicted_particles: List[str], + mesh_files: List[str], + template_particles: str, + template_mesh: str, + out_dir: str, + planes: Optional[Any] = None +) -> float: + """ + Analyze mesh distance between predicted particles and ground truth meshes. + + Returns: + Mean surface-to-surface distance + """ mean_distance = eval_utils.get_mesh_distance(predicted_particles, mesh_files, template_particles, template_mesh, out_dir, planes) return mean_distance -def analyzeResults(out_dir, DT_dir, prediction_dir, mean_prefix): +def analyzeResults( + out_dir: str, + DT_dir: str, + prediction_dir: str, + mean_prefix: str +) -> float: + """ + Analyze results by computing distance between predicted and ground truth meshes. + + Returns: + Average surface distance + """ avg_distance = eval_utils.get_distance_meshes(out_dir, DT_dir, prediction_dir, mean_prefix) return avg_distance -def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid'): +def get_image_registration_transform( + fixed_image_file: str, + moving_image_file: str, + transform_type: str = 'rigid' +) -> Any: + """ + Compute image registration transform between two images. + + Args: + fixed_image_file: Path to the fixed/reference image + moving_image_file: Path to the moving image to be registered + transform_type: Type of transform ('rigid', 'affine', etc.) 
+ + Returns: + ITK transform object + """ itk_transform = image_utils.get_image_registration_transform(fixed_image_file, moving_image_file, transform_type=transform_type) return itk_transform -def testPytorch(): +def testPytorch() -> None: + """Check if PyTorch is using GPU and print a warning if not.""" if torch.cuda.is_available(): print("Running on GPU.") else: diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py index 792bc6ab85..447d60bdb9 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py @@ -5,9 +5,12 @@ from DeepSSMUtils import constants as C -def set_seed(seed=42): +def set_seed(seed: int = 42) -> None: """ Set random seeds for reproducibility across all random number generators. + + Args: + seed: Integer seed value for random number generators """ random.seed(seed) np.random.seed(seed) @@ -20,17 +23,55 @@ def set_seed(seed=42): class Flatten(nn.Module): - def forward(self, x): + """Flatten layer to reshape tensor for fully connected layers.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.view(x.size(0), -1) -def poolOutDim(inDim, kernel_size, padding=0, stride=0, dilation=1): + +def poolOutDim( + inDim: int, + kernel_size: int, + padding: int = 0, + stride: int = 0, + dilation: int = 1 +) -> int: + """ + Calculate output dimension after pooling operation. + + Args: + inDim: Input dimension size + kernel_size: Size of the pooling kernel + padding: Padding applied to input + stride: Stride of pooling (defaults to kernel_size if 0) + dilation: Dilation factor + + Returns: + Output dimension size after pooling + """ if stride == 0: stride = kernel_size num = inDim + 2*padding - dilation*(kernel_size - 1) - 1 outDim = int(np.floor(num/stride + 1)) return outDim -def unwhiten_PCA_scores(torch_loading, loader_dir, device): + +def unwhiten_PCA_scores( + torch_loading: torch.Tensor, + loader_dir: str, + device: str +) -> torch.Tensor: + """ + Unwhiten (denormalize) PCA scores using saved mean and std. 
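+    Roughly, the original scale is recovered as whitened_scores * std + mean,
+    using the statistics saved alongside the training loader.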
+ + Args: + torch_loading: Whitened PCA scores tensor + loader_dir: Directory containing mean_PCA.npy and std_PCA.npy + device: Device to load tensors to ('cuda:0' or 'cpu') + + Returns: + Unwhitened PCA scores tensor + """ mean_score = torch.from_numpy(np.load(loader_dir + '/' + C.MEAN_PCA_FILE)).to(device).float() std_score = torch.from_numpy(np.load(loader_dir + '/' + C.STD_PCA_FILE)).to(device).float() mean_score = mean_score.unsqueeze(0).repeat(torch_loading.shape[0], 1) From 5b8c663387220939776f43372a50184495fde942 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 18:29:54 -0700 Subject: [PATCH 22/47] Add config schema validation for DeepSSM --- .../DeepSSMUtils/__init__.py | 1 + .../DeepSSMUtils/config_validation.py | 205 ++++++++++++++++++ .../DeepSSMUtils/trainer.py | 5 +- 3 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index 561b798162..eb51ab3904 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -10,6 +10,7 @@ from DeepSSMUtils import run_utils from DeepSSMUtils import net_utils from DeepSSMUtils import constants +from DeepSSMUtils import config_validation from .net_utils import set_seed diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py new file mode 100644 index 0000000000..0725711558 --- /dev/null +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py @@ -0,0 +1,205 @@ +""" +Configuration file validation for DeepSSM. + +This module provides validation for DeepSSM config files to catch +errors early with clear error messages. 
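+
+Typical usage (illustrative; the config path below is hypothetical):
+
+    from DeepSSMUtils import config_validation
+
+    try:
+        parameters = config_validation.validate_config("deepssm_config.json")
+    except config_validation.ConfigValidationError as err:
+        print(err)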
+""" +import os +import json +from typing import Any, Dict, List, Optional + + +class ConfigValidationError(Exception): + """Raised when config validation fails.""" + pass + + +# Schema definition for DeepSSM config +CONFIG_SCHEMA = { + "model_name": {"type": str, "required": True}, + "num_latent_dim": {"type": int, "required": True, "min": 1}, + "paths": { + "type": dict, + "required": True, + "children": { + "out_dir": {"type": str, "required": True}, + "loader_dir": {"type": str, "required": True}, + "aug_dir": {"type": str, "required": True}, + } + }, + "encoder": { + "type": dict, + "required": True, + "children": { + "deterministic": {"type": bool, "required": True}, + } + }, + "decoder": { + "type": dict, + "required": True, + "children": { + "deterministic": {"type": bool, "required": True}, + "linear": {"type": bool, "required": True}, + } + }, + "loss": { + "type": dict, + "required": True, + "children": { + "function": {"type": str, "required": True, "choices": ["MSE", "Focal"]}, + "supervised_latent": {"type": bool, "required": True}, + } + }, + "trainer": { + "type": dict, + "required": True, + "children": { + "epochs": {"type": int, "required": True, "min": 1}, + "learning_rate": {"type": (int, float), "required": True, "min": 0}, + "val_freq": {"type": int, "required": True, "min": 1}, + "decay_lr": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "type": {"type": str, "required": False, "choices": ["Step", "CosineAnnealing"]}, + "parameters": {"type": dict, "required": False}, + } + }, + } + }, + "fine_tune": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "loss": {"type": str, "required": False, "choices": ["MSE", "Focal"]}, + "epochs": {"type": int, "required": False, "min": 1}, + "learning_rate": {"type": (int, float), "required": False, "min": 0}, + "val_freq": {"type": int, "required": False, "min": 1}, + } + }, + "use_best_model": {"type": bool, "required": True}, + "tl_net": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "ae_epochs": {"type": int, "required": False, "min": 1}, + "tf_epochs": {"type": int, "required": False, "min": 1}, + "joint_epochs": {"type": int, "required": False, "min": 1}, + "alpha": {"type": (int, float), "required": False}, + "a_ae": {"type": (int, float), "required": False}, + "c_ae": {"type": (int, float), "required": False}, + "a_lat": {"type": (int, float), "required": False}, + "c_lat": {"type": (int, float), "required": False}, + } + }, +} + + +def validate_config(config_path: str) -> Dict[str, Any]: + """ + Validate a DeepSSM configuration file. 
+ + Args: + config_path: Path to the JSON configuration file + + Returns: + Validated configuration dictionary + + Raises: + ConfigValidationError: If validation fails + FileNotFoundError: If config file doesn't exist + json.JSONDecodeError: If config file is not valid JSON + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + + with open(config_path) as f: + try: + config = json.load(f) + except json.JSONDecodeError as e: + raise ConfigValidationError(f"Invalid JSON in config file: {e}") + + errors = _validate_dict(config, CONFIG_SCHEMA, "config") + + if errors: + error_msg = "Config validation failed:\n" + "\n".join(f" - {e}" for e in errors) + raise ConfigValidationError(error_msg) + + return config + + +def _validate_dict( + data: Dict[str, Any], + schema: Dict[str, Any], + path: str +) -> List[str]: + """ + Recursively validate a dictionary against a schema. + + Returns: + List of error messages (empty if valid) + """ + errors = [] + + for key, rules in schema.items(): + full_path = f"{path}.{key}" + value = data.get(key) + + # Check required fields + if rules.get("required", False) and key not in data: + errors.append(f"Missing required field: {full_path}") + continue + + if key not in data: + continue + + # Check type + expected_type = rules.get("type") + if expected_type and not isinstance(value, expected_type): + type_name = expected_type.__name__ if isinstance(expected_type, type) else str(expected_type) + errors.append(f"Invalid type for {full_path}: expected {type_name}, got {type(value).__name__}") + continue + + # Check min value + if "min" in rules and isinstance(value, (int, float)): + if value < rules["min"]: + errors.append(f"Value too small for {full_path}: {value} < {rules['min']}") + + # Check choices + if "choices" in rules and value not in rules["choices"]: + errors.append(f"Invalid value for {full_path}: '{value}' not in {rules['choices']}") + + # Recurse into nested dicts + if expected_type == dict and "children" in rules: + errors.extend(_validate_dict(value, rules["children"], full_path)) + + return errors + + +def validate_paths_exist(config: Dict[str, Any], check_loader_dir: bool = True) -> List[str]: + """ + Validate that required paths in config exist. 
+ + Args: + config: Configuration dictionary + check_loader_dir: Whether to check if loader_dir exists + + Returns: + List of warning messages for missing paths + """ + warnings = [] + paths = config.get("paths", {}) + + if check_loader_dir: + loader_dir = paths.get("loader_dir", "") + if loader_dir and not os.path.exists(loader_dir): + warnings.append(f"Loader directory does not exist: {loader_dir}") + + aug_dir = paths.get("aug_dir", "") + if aug_dir and not os.path.exists(aug_dir): + warnings.append(f"Augmentation directory does not exist: {aug_dir}") + + return warnings diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index c80a776244..869e37b8a8 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -15,6 +15,7 @@ from DeepSSMUtils import loaders from DeepSSMUtils import net_utils from DeepSSMUtils import constants as C +from DeepSSMUtils import config_validation import DeepSSMUtils from shapeworks.utils import * @@ -73,8 +74,8 @@ def train(project, config_file): net_utils.set_seed(42) sw.utils.initialize_project_mesh_warper(project) - with open(config_file) as json_file: - parameters = json.load(json_file) + # Validate config file before training + parameters = config_validation.validate_config(config_file) if parameters["tl_net"]["enabled"]: supervised_train_tl(config_file) else: From 5c815663ce2e6d6e3370fff7a271faf38db88f7e Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 18:41:12 -0700 Subject: [PATCH 23/47] Improve error handling in DeepSSM data loaders --- .../DeepSSMUtils/loaders.py | 114 ++++++++++++------ 1 file changed, 79 insertions(+), 35 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py index b3a1ca8ff7..7a6661e064 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py @@ -12,6 +12,11 @@ from DeepSSMUtils import constants as C random.seed(1) + +class DataLoadingError(Exception): + """Raised when data loading fails.""" + pass + ######################## Data loading functions #################################### ''' @@ -88,6 +93,14 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir ''' def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0): sw_message("Creating validation torch loader:") + if not val_img_list: + raise DataLoadingError("Validation image list is empty") + if not val_particles: + raise DataLoadingError("Validation particle list is empty") + if len(val_img_list) != len(val_particles): + raise DataLoadingError( + f"Mismatched validation data: {len(val_img_list)} images but {len(val_particles)} particle files" + ) # Get data image_paths = [] scores = [] @@ -127,6 +140,8 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 ''' def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0): sw_message("Creating test torch loader...") + if not test_img_list: + raise DataLoadingError("Test image list is empty") # get data image_paths = [] scores = [] @@ -167,37 +182,47 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num returns images, scores, models, prefixes from CSV ''' def get_all_train_data(loader_dir, data_csv, down_factor, down_dir): + if not os.path.exists(data_csv): + raise 
DataLoadingError(f"CSV file not found: {data_csv}") # get all data and targets image_paths = [] scores = [] models = [] prefixes = [] - with open(data_csv, newline='') as csvfile: - datareader = csv.reader(csvfile) - index = 0 - for row in datareader: - image_path = row[0] - model_path = row[1] - pca_scores = row[2:] - # add name - prefix = get_prefix(image_path) - # data error check - # if prefix not in get_prefix(model_path): - # print("Error: Images and particles are mismatched in csv.") - # print(f"index: {index}") - # print(f"prefix: {prefix}") - # print(f"get_prefix(model_path): {get_prefix(model_path)}}") - # exit() - prefixes.append(prefix) - # add image path - image_paths.append(image_path) - # add score (un-normalized) - pca_scores = [float(i) for i in pca_scores] - scores.append(pca_scores) - # add model - mdl = get_particles(model_path) - models.append(mdl) - index += 1 + try: + with open(data_csv, newline='') as csvfile: + datareader = csv.reader(csvfile) + for row_num, row in enumerate(datareader, 1): + if len(row) < 3: + raise DataLoadingError( + f"Invalid row {row_num} in {data_csv}: expected at least 3 columns " + f"(image_path, model_path, pca_scores), got {len(row)}" + ) + image_path = row[0] + model_path = row[1] + pca_scores = row[2:] + # add name + prefix = get_prefix(image_path) + prefixes.append(prefix) + # add image path + image_paths.append(image_path) + # add score (un-normalized) + try: + pca_scores = [float(i) for i in pca_scores] + except ValueError as e: + raise DataLoadingError( + f"Invalid PCA scores in {data_csv} at row {row_num}: {e}" + ) + scores.append(pca_scores) + # add model + mdl = get_particles(model_path) + models.append(mdl) + except csv.Error as e: + raise DataLoadingError(f"Error parsing CSV file {data_csv}: {e}") + + if not image_paths: + raise DataLoadingError(f"CSV file is empty: {data_csv}") + images = get_images(loader_dir, image_paths, down_factor, down_dir) scores = whiten_PCA_scores(scores, loader_dir) return images, scores, models, prefixes @@ -241,18 +266,32 @@ def get_prefix(path): get list from .particles format ''' def get_particles(model_path): - f = open(model_path, "r") - data = [] - for line in f.readlines(): - points = line.split() - points = [float(i) for i in points] - data.append(points) - return(data) + if not os.path.exists(model_path): + raise DataLoadingError(f"Particle file not found: {model_path}") + try: + with open(model_path, "r") as f: + data = [] + for line_num, line in enumerate(f.readlines(), 1): + points = line.split() + try: + points = [float(i) for i in points] + except ValueError as e: + raise DataLoadingError( + f"Invalid particle data in {model_path} at line {line_num}: {e}" + ) + data.append(points) + if not data: + raise DataLoadingError(f"Particle file is empty: {model_path}") + return data + except IOError as e: + raise DataLoadingError(f"Error reading particle file {model_path}: {e}") ''' reads .nrrd files and returns whitened data ''' def get_images(loader_dir, image_list, down_factor, down_dir): + if not image_list: + raise DataLoadingError("Image list is empty") # get all images all_images = [] for image_path in image_list: @@ -263,8 +302,13 @@ def get_images(loader_dir, image_list, down_factor, down_dir): if not os.path.exists(res_img): apply_down_sample(image_path, res_img, down_factor) image_path = res_img - # for_viewing returns 'F' order, i.e., transpose, needed for this array - img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + if not os.path.exists(image_path): + raise 
DataLoadingError(f"Image file not found: {image_path}") + try: + # for_viewing returns 'F' order, i.e., transpose, needed for this array + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") all_images.append(img) all_images = np.array(all_images) From fc58e5262fb86fd37b340ed53cb7a6e41eab86eb Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Tue, 6 Jan 2026 23:07:20 -0700 Subject: [PATCH 24/47] Add error handling to loaders and --exact_check option * Add DataLoadingError exception with descriptive messages including file paths and line numbers for debugging * Validate inputs in get_particles, get_images, get_all_train_data, get_validation_loader, and get_test_loader * Add --exact_check flag with save/verify modes for platform-specific refactoring verification * Return mean_distance from process_test_predictions for exact checking --- Examples/Python/RunUseCase.py | 2 + Examples/Python/deep_ssm.py | 37 +++++++++++++++---- .../DeepSSMUtils/run_utils.py | 5 ++- 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 9394235a8f..24965ea137 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -69,6 +69,8 @@ parser.add_argument("--tiny_test", help="Run as a short test", action="store_true") parser.add_argument("--verify", help="Run as a full test", action="store_true") parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true") + parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification (platform-specific)", + choices=["save", "verify"]) args = parser.parse_args() type = "" diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index a1fa04b330..eeaa74f2c3 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -529,10 +529,10 @@ def Run_Pipeline(args): planes=test_planes) print("Test mean mesh surface-to-surface distance: " + str(mean_dist)) - DeepSSMUtils.process_test_predictions(project, config_file) - + final_mean_dist = DeepSSMUtils.process_test_predictions(project, config_file) + # If tiny test or verify, check results and exit - check_results(args, mean_dist) + check_results(args, final_mean_dist, output_directory) open(status_dir + "step_12.txt", 'w').close() @@ -540,12 +540,35 @@ def Run_Pipeline(args): # Verification -def check_results(args, mean_dist): +def check_results(args, mean_dist, output_directory): if args.tiny_test: print("\nVerifying use case results.") - if not math.isclose(mean_dist, 10, rel_tol=1): - print("Test failed.") - exit(-1) + + exact_check_file = output_directory + "exact_check_value.txt" + + # Exact check for refactoring verification (platform-specific) + if args.exact_check == "save": + with open(exact_check_file, "w") as f: + f.write(str(mean_dist)) + print(f"Saved exact check value to: {exact_check_file}") + print(f"Value: {mean_dist}") + elif args.exact_check == "verify": + if not os.path.exists(exact_check_file): + print(f"Error: No saved value found at {exact_check_file}") + print("Run with --exact_check save first to create baseline.") + exit(-1) + with open(exact_check_file, "r") as f: + expected_mean_dist = float(f.read().strip()) + if mean_dist != expected_mean_dist: + print(f"Exact check failed: expected {expected_mean_dist}, got {mean_dist}") + exit(-1) + print(f"Exact check passed: {mean_dist}") + else: + # Relaxed check 
for CI/cross-platform + if not math.isclose(mean_dist, 10, rel_tol=1): + print("Test failed.") + exit(-1) + print("Done with test, verification succeeded.") exit(0) else: diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index 7de1bd1e2c..723795d882 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -538,7 +538,8 @@ def process_test_predictions(project, config_file): template_particles, template_mesh, pred_dir) print("Distances: ", distances) - print("Mean distance: ", np.mean(distances)) + mean_distance = np.mean(distances) + print("Mean distance: ", mean_distance) # write to csv file in deepssm_dir csv_file = f"{deepssm_dir}/test_distances.csv" @@ -561,3 +562,5 @@ def process_test_predictions(project, config_file): mesh = sw.Mesh(local_mesh_file) mesh.applyTransform(transform) mesh.write(world_mesh_file) + + return mean_distance From 3b2e4ae9f2bbdb39295697ebc5631c4b6906e691 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 7 Jan 2026 01:16:27 -0700 Subject: [PATCH 25/47] Add --tl_net flag and fix TL-DeepSSM bugs - Add --tl_net flag to enable TL-DeepSSM network testing - Fix PyTorch 2.6 compatibility: add weights_only=False to torch.load calls in trainer.py and model.py for DataLoader loading - Fix eval.py returning wrong file path for tl_net mode - Fix deep_ssm.py path handling for local predictions directory --- Examples/Python/RunUseCase.py | 1 + Examples/Python/deep_ssm.py | 26 ++++++++++++------- .../DeepSSMUtilsPackage/DeepSSMUtils/eval.py | 10 +++---- .../DeepSSMUtilsPackage/DeepSSMUtils/model.py | 2 +- .../DeepSSMUtils/trainer.py | 4 +-- 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 24965ea137..98a30edfab 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -71,6 +71,7 @@ parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true") parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification (platform-specific)", choices=["save", "verify"]) + parser.add_argument("--tl_net", help="Enable TL-DeepSSM network (deep_ssm use case only)", action="store_true") args = parser.parse_args() type = "" diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index eeaa74f2c3..9b618057c8 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -385,7 +385,7 @@ def Run_Pipeline(args): }, "use_best_model": True, "tl_net": { - "enabled": False, + "enabled": args.tl_net, "ae_epochs": 100, "tf_epochs": 100, "joint_epochs": 25, @@ -398,6 +398,10 @@ def Run_Pipeline(args): } if args.tiny_test: model_parameters["trainer"]["epochs"] = 1 + if args.tl_net: + model_parameters["tl_net"]["ae_epochs"] = 1 + model_parameters["tl_net"]["tf_epochs"] = 1 + model_parameters["tl_net"]["joint_epochs"] = 1 # Save config file with open(config_file, "w") as outfile: json.dump(model_parameters, outfile, indent=2) @@ -436,17 +440,17 @@ def Run_Pipeline(args): val_world_particles.append(project_path + subjects[index].get_world_particle_filenames()[0]) val_mesh_files.append(project_path + subjects[index].get_groomed_filenames()[0]) - val_out_dir = output_directory + model_name + '/validation_predictions/' predicted_val_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='validation') print("Validation world predictions saved.") - # 
Generate local predictions - local_val_prediction_dir = val_out_dir + 'local_predictions/' + # Generate local predictions - create directory next to world_predictions + world_pred_dir = os.path.dirname(predicted_val_world_particles[0]) + local_val_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions") if not os.path.exists(local_val_prediction_dir): os.makedirs(local_val_prediction_dir) predicted_val_local_particles = [] for particle_file, transform in zip(predicted_val_world_particles, val_transforms): particles = np.loadtxt(particle_file) - local_particle_file = particle_file.replace("world_predictions/", "local_predictions/") + local_particle_file = particle_file.replace("world_predictions", "local_predictions") local_particles = sw.utils.transformParticles(particles, transform, inverse=True) np.savetxt(local_particle_file, local_particles) predicted_val_local_particles.append(local_particle_file) @@ -462,6 +466,8 @@ def Run_Pipeline(args): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] # Get distance between clipped true and predicted meshes + # Get the validation output directory from the predictions path + val_out_dir = os.path.dirname(local_val_prediction_dir.rstrip('/')) + '/' mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_val_local_particles, val_mesh_files, template_particles, template_mesh, val_out_dir, planes=val_planes) @@ -500,17 +506,17 @@ def Run_Pipeline(args): with open(plane_file) as json_file: test_planes.append(json.load(json_file)['planes'][0]['points']) - test_out_dir = output_directory + model_name + '/test_predictions/' predicted_test_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='test') print("Test world predictions saved.") - # Generate local predictions - local_test_prediction_dir = test_out_dir + 'local_predictions/' + # Generate local predictions - create directory next to world_predictions + world_pred_dir = os.path.dirname(predicted_test_world_particles[0]) + local_test_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions") if not os.path.exists(local_test_prediction_dir): os.makedirs(local_test_prediction_dir) predicted_test_local_particles = [] for particle_file, transform in zip(predicted_test_world_particles, test_transforms): particles = np.loadtxt(particle_file) - local_particle_file = particle_file.replace("world_predictions/", "local_predictions/") + local_particle_file = particle_file.replace("world_predictions", "local_predictions") local_particles = sw.utils.transformParticles(particles, transform, inverse=True) np.savetxt(local_particle_file, local_particles) predicted_test_local_particles.append(local_particle_file) @@ -524,6 +530,8 @@ def Run_Pipeline(args): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] + # Get the test output directory from the predictions path + test_out_dir = os.path.dirname(local_test_prediction_dir.rstrip('/')) + '/' mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_test_local_particles, test_mesh_files, template_particles, template_mesh, test_out_dir, planes=test_planes) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py index 90b4fbd0c1..a850d10bd1 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py +++ 
b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py @@ -87,18 +87,18 @@ def test(config_file, loader="test"): [pred_tf, pred_mdl_tl] = model_tl(mdl, img) pred_scores.append(pred_tf.cpu().data.numpy()) # save the AE latent space as shape descriptors - filename = pred_path + test_names[index] + '.npy' - np.save(filename, pred_tf.squeeze().detach().cpu().numpy()) + latent_filename = pred_path + test_names[index] + '.npy' + np.save(latent_filename, pred_tf.squeeze().detach().cpu().numpy()) np.savetxt(particle_filename, pred_mdl_tl.squeeze().detach().cpu().numpy()) else: [pred, pred_mdl_pca] = model_pca(img) [pred, pred_mdl_ft] = model_ft(img) pred_scores.append(pred.cpu().data.numpy()[0]) - filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles' - np.savetxt(filename, pred_mdl_pca.squeeze().detach().cpu().numpy()) + pca_filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles' + np.savetxt(pca_filename, pred_mdl_pca.squeeze().detach().cpu().numpy()) np.savetxt(particle_filename, pred_mdl_ft.squeeze().detach().cpu().numpy()) print("Predicted particle file: ", particle_filename) - predicted_particle_files.append(filename) + predicted_particle_files.append(particle_filename) index += 1 sw_message("Test completed.") return predicted_particle_files diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py index a0ffb81d4b..7d684ee62d 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py @@ -178,7 +178,7 @@ def __init__(self, conflict_file): parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + C.VALIDATION_LOADER) + loader = torch.load(self.loader_dir + C.VALIDATION_LOADER, weights_only=False) self.num_corr = loader.dataset.mdl_target[0].shape[0] img_dims = loader.dataset.img[0].shape self.img_dims = img_dims[1:] diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index 869e37b8a8..1dd9fcc575 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -418,8 +418,8 @@ def supervised_train_tl(config_file): train_loader_path = loader_dir + C.TRAIN_LOADER validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") - train_loader = torch.load(train_loader_path) - val_loader = torch.load(validation_loader_path) + train_loader = torch.load(train_loader_path, weights_only=False) + val_loader = torch.load(validation_loader_path, weights_only=False) print("Done.") print("Defining model...") net = model.DeepSSMNet_TLNet(config_file) From 6566ecf58a662785055bc02994191a25a5688494 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 7 Jan 2026 01:27:09 -0700 Subject: [PATCH 26/47] Validate --exact_check and --tl_net are only used with deep_ssm --- Examples/Python/RunUseCase.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 98a30edfab..1648a6d8bb 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -74,6 +74,12 @@ parser.add_argument("--tl_net", help="Enable TL-DeepSSM network (deep_ssm use case only)", action="store_true") args = parser.parse_args() + # Validate deep_ssm-specific arguments + if args.exact_check and args.use_case != "deep_ssm": + parser.error("--exact_check 
is only supported for the deep_ssm use case") + if args.tl_net and args.use_case != "deep_ssm": + parser.error("--tl_net is only supported for the deep_ssm use case") + type = "" if args.tiny_test: type = "tiny_test_" From 83d7789525e8e7fad8f16820f794447c3bbd0e95 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 7 Jan 2026 01:33:12 -0700 Subject: [PATCH 27/47] Use separate exact_check files for standard and tl_net modes --- Examples/Python/deep_ssm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index 9b618057c8..d6234d1b54 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -552,7 +552,8 @@ def check_results(args, mean_dist, output_directory): if args.tiny_test: print("\nVerifying use case results.") - exact_check_file = output_directory + "exact_check_value.txt" + suffix = "_tl_net" if args.tl_net else "" + exact_check_file = output_directory + f"exact_check_value{suffix}.txt" # Exact check for refactoring verification (platform-specific) if args.exact_check == "save": From 36a499cda4b238504772f547a8e2c4e37622ac0d Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Thu, 8 Jan 2026 12:27:40 -0700 Subject: [PATCH 28/47] Add GTest-based tests for DeepSSM that use shapeworks project files. - Add Testing/DeepSSMTests/ with C++ test harness and shell scripts - Add deepssm_test_data.zip (6MB) containing femur meshes, images, constraints, and pre-configured project files - Fix bug in Commands.cpp where DeepSSM command returned false (exit code 1) on success instead of true (exit code 0) - Remove --tl_net argument from Python use case since testing different DeepSSM configurations is now done via project files --- Applications/shapeworks/Commands.cpp | 2 +- Examples/Python/RunUseCase.py | 3 --- Examples/Python/deep_ssm.py | 10 +++------- Testing/CMakeLists.txt | 1 + Testing/DeepSSMTests/CMakeLists.txt | 13 ++++++++++++ Testing/DeepSSMTests/DeepSSMTests.cpp | 20 +++++++++++++++++++ Testing/DeepSSMTests/deepssm_default.sh | 12 +++++++++++ Testing/DeepSSMTests/deepssm_fine_tune.sh | 12 +++++++++++ Testing/DeepSSMTests/deepssm_tl_net.sh | 12 +++++++++++ .../DeepSSMTests/deepssm_tl_net_fine_tune.sh | 12 +++++++++++ Testing/data/deepssm_test_data.zip | 3 +++ 11 files changed, 89 insertions(+), 11 deletions(-) create mode 100644 Testing/DeepSSMTests/CMakeLists.txt create mode 100644 Testing/DeepSSMTests/DeepSSMTests.cpp create mode 100755 Testing/DeepSSMTests/deepssm_default.sh create mode 100755 Testing/DeepSSMTests/deepssm_fine_tune.sh create mode 100755 Testing/DeepSSMTests/deepssm_tl_net.sh create mode 100755 Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh create mode 100644 Testing/data/deepssm_test_data.zip diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp index da59fe0c67..e161d07256 100644 --- a/Applications/shapeworks/Commands.cpp +++ b/Applications/shapeworks/Commands.cpp @@ -499,7 +499,7 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& project->save(); - return false; + return true; } } // namespace shapeworks diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 1648a6d8bb..aa1a175f50 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -71,14 +71,11 @@ parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true") parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification 
(platform-specific)", choices=["save", "verify"]) - parser.add_argument("--tl_net", help="Enable TL-DeepSSM network (deep_ssm use case only)", action="store_true") args = parser.parse_args() # Validate deep_ssm-specific arguments if args.exact_check and args.use_case != "deep_ssm": parser.error("--exact_check is only supported for the deep_ssm use case") - if args.tl_net and args.use_case != "deep_ssm": - parser.error("--tl_net is only supported for the deep_ssm use case") type = "" if args.tiny_test: diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index d6234d1b54..3e696997b4 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -385,7 +385,7 @@ def Run_Pipeline(args): }, "use_best_model": True, "tl_net": { - "enabled": args.tl_net, + "enabled": False, "ae_epochs": 100, "tf_epochs": 100, "joint_epochs": 25, @@ -396,12 +396,9 @@ def Run_Pipeline(args): "c_lat": 6.3 } } + if args.tiny_test: model_parameters["trainer"]["epochs"] = 1 - if args.tl_net: - model_parameters["tl_net"]["ae_epochs"] = 1 - model_parameters["tl_net"]["tf_epochs"] = 1 - model_parameters["tl_net"]["joint_epochs"] = 1 # Save config file with open(config_file, "w") as outfile: json.dump(model_parameters, outfile, indent=2) @@ -552,8 +549,7 @@ def check_results(args, mean_dist, output_directory): if args.tiny_test: print("\nVerifying use case results.") - suffix = "_tl_net" if args.tl_net else "" - exact_check_file = output_directory + f"exact_check_value{suffix}.txt" + exact_check_file = output_directory + "exact_check_value.txt" # Exact check for refactoring verification (platform-specific) if args.exact_check == "save": diff --git a/Testing/CMakeLists.txt b/Testing/CMakeLists.txt index c03ca89ef2..58dfb2fe16 100644 --- a/Testing/CMakeLists.txt +++ b/Testing/CMakeLists.txt @@ -77,3 +77,4 @@ add_subdirectory(ProjectTests) add_subdirectory(UseCaseTests) add_subdirectory(shapeworksTests) add_subdirectory(UtilsTests) +add_subdirectory(DeepSSMTests) diff --git a/Testing/DeepSSMTests/CMakeLists.txt b/Testing/DeepSSMTests/CMakeLists.txt new file mode 100644 index 0000000000..7119af3cef --- /dev/null +++ b/Testing/DeepSSMTests/CMakeLists.txt @@ -0,0 +1,13 @@ +set(TEST_SRCS + DeepSSMTests.cpp + ) + +add_executable(DeepSSMTests + ${TEST_SRCS} + ) + +target_link_libraries(DeepSSMTests + Testing + ) + +add_test(NAME DeepSSMTests COMMAND DeepSSMTests) diff --git a/Testing/DeepSSMTests/DeepSSMTests.cpp b/Testing/DeepSSMTests/DeepSSMTests.cpp new file mode 100644 index 0000000000..05f12e8299 --- /dev/null +++ b/Testing/DeepSSMTests/DeepSSMTests.cpp @@ -0,0 +1,20 @@ +#include "Testing.h" + +using namespace shapeworks; + +//--------------------------------------------------------------------------- +void run_deepssm_test(const std::string& name) { + setupenv(std::string(TEST_DATA_DIR) + "/../DeepSSMTests"); + + std::string command = "bash " + name; + ASSERT_FALSE(system(command.c_str())); +} + +//--------------------------------------------------------------------------- +TEST(DeepSSMTests, defaultTest) { run_deepssm_test("deepssm_default.sh"); } + +TEST(DeepSSMTests, tlNetTest) { run_deepssm_test("deepssm_tl_net.sh"); } + +TEST(DeepSSMTests, fineTuneTest) { run_deepssm_test("deepssm_fine_tune.sh"); } + +TEST(DeepSSMTests, tlNetFineTuneTest) { run_deepssm_test("deepssm_tl_net_fine_tune.sh"); } diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh new file mode 100755 index 0000000000..93bb64296e --- /dev/null +++ 
b/Testing/DeepSSMTests/deepssm_default.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Test DeepSSM with default settings (no tl_net, no fine_tune) +set -e + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name default.swproj --all diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh new file mode 100755 index 0000000000..cc2b9095a6 --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Test DeepSSM with fine tuning enabled +set -e + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name fine_tune.swproj --all diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh new file mode 100755 index 0000000000..42450340fa --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_tl_net.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Test DeepSSM with TL-DeepSSM network enabled +set -e + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name tl_net.swproj --all diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh new file mode 100755 index 0000000000..36083e3d88 --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Test DeepSSM with both TL-DeepSSM and fine tuning enabled +set -e + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name tl_net_fine_tune.swproj --all diff --git a/Testing/data/deepssm_test_data.zip b/Testing/data/deepssm_test_data.zip new file mode 100644 index 0000000000..621d3a1556 --- /dev/null +++ b/Testing/data/deepssm_test_data.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99c6a0a3f6bfa91cc00095db64cf9155fe037a9a56afd918aee25b9c3f4770d5 +size 6196905 From 0728968c445657bb29cd8fd15e6c9548b5435724 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Thu, 8 Jan 2026 12:51:37 -0700 Subject: [PATCH 29/47] Add result verification to DeepSSM tests Add verify_deepssm_results.py script that validates test output by checking mean surface-to-surface distance from test_distances.csv. Uses loose tolerance (0-300) for quick 1-epoch tests to catch catastrophic failures while keeping tests fast. Supports --exact_check save/verify for platform-specific refactoring verification with tighter tolerances. 
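For reference, the relaxed check reduces to a single relative-tolerance
comparison; a minimal sketch using the script's defaults (expected=150.0,
rel_tol=1.0):

    import math
    # mean_dist is the average of the Distance column in test_distances.csv
    ok = math.isclose(mean_dist, 150.0, rel_tol=1.0)
    # i.e. |mean_dist - 150| <= 1.0 * max(|mean_dist|, 150)

The --exact_check mode instead stores the raw mean distance and requires an
exact match on re-run, so baselines should only be compared on the same machine.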
--- Testing/DeepSSMTests/deepssm_default.sh | 5 ++ Testing/DeepSSMTests/deepssm_fine_tune.sh | 5 ++ Testing/DeepSSMTests/deepssm_tl_net.sh | 5 ++ .../DeepSSMTests/deepssm_tl_net_fine_tune.sh | 5 ++ .../DeepSSMTests/verify_deepssm_results.py | 86 +++++++++++++++++++ 5 files changed, 106 insertions(+) create mode 100644 Testing/DeepSSMTests/verify_deepssm_results.py diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh index 93bb64296e..c8a7305829 100755 --- a/Testing/DeepSSMTests/deepssm_default.sh +++ b/Testing/DeepSSMTests/deepssm_default.sh @@ -2,6 +2,8 @@ # Test DeepSSM with default settings (no tl_net, no fine_tune) set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + # Unzip test data if not already extracted if [ ! -d "${DATA}/deepssm" ]; then unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" @@ -10,3 +12,6 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles shapeworks deepssm --name default.swproj --all + +# Verify results +python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh index cc2b9095a6..5b991a3f84 100755 --- a/Testing/DeepSSMTests/deepssm_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh @@ -2,6 +2,8 @@ # Test DeepSSM with fine tuning enabled set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + # Unzip test data if not already extracted if [ ! -d "${DATA}/deepssm" ]; then unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" @@ -10,3 +12,6 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles shapeworks deepssm --name fine_tune.swproj --all + +# Verify results +python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh index 42450340fa..f246158782 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net.sh @@ -2,6 +2,8 @@ # Test DeepSSM with TL-DeepSSM network enabled set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + # Unzip test data if not already extracted if [ ! -d "${DATA}/deepssm" ]; then unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" @@ -10,3 +12,6 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles shapeworks deepssm --name tl_net.swproj --all + +# Verify results +python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh index 36083e3d88..70ea18f1f8 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh @@ -2,6 +2,8 @@ # Test DeepSSM with both TL-DeepSSM and fine tuning enabled set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + # Unzip test data if not already extracted if [ ! 
-d "${DATA}/deepssm" ]; then unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" @@ -10,3 +12,6 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles shapeworks deepssm --name tl_net_fine_tune.swproj --all + +# Verify results +python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/verify_deepssm_results.py b/Testing/DeepSSMTests/verify_deepssm_results.py new file mode 100644 index 0000000000..4152f2f407 --- /dev/null +++ b/Testing/DeepSSMTests/verify_deepssm_results.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Verify DeepSSM test results by checking the mean distance from test_distances.csv. + +Usage: + python verify_deepssm_results.py [--exact_check save|verify] [--expected ] + +The script checks that the mean surface-to-surface distance is reasonable (roughly 10, within tolerance). +For exact refactoring verification, use --exact_check save/verify to save or compare exact values. +""" + +import argparse +import csv +import math +import os +import sys + + +def get_mean_distance(project_dir: str) -> float: + """Read mean distance from test_distances.csv.""" + csv_path = os.path.join(project_dir, "deepssm", "test_distances.csv") + if not os.path.exists(csv_path): + raise FileNotFoundError(f"Results file not found: {csv_path}") + + distances = [] + with open(csv_path, 'r') as f: + reader = csv.DictReader(f) + for row in reader: + distances.append(float(row['Distance'])) + + if not distances: + raise ValueError(f"No distances found in {csv_path}") + + return sum(distances) / len(distances) + + +def main(): + parser = argparse.ArgumentParser(description="Verify DeepSSM test results") + parser.add_argument("project_dir", help="Path to the project directory containing deepssm/ output") + parser.add_argument("--exact_check", choices=["save", "verify"], + help="Save or verify exact values for refactoring verification") + parser.add_argument("--expected", type=float, default=150.0, + help="Expected mean distance for relaxed check (default: 150.0)") + parser.add_argument("--tolerance", type=float, default=1.0, + help="Relative tolerance for relaxed check (default: 1.0 = 100%%)") + args = parser.parse_args() + + try: + mean_dist = get_mean_distance(args.project_dir) + print(f"Mean distance: {mean_dist}") + except (FileNotFoundError, ValueError) as e: + print(f"Error: {e}") + sys.exit(1) + + exact_check_file = os.path.join(args.project_dir, "exact_check_value.txt") + + if args.exact_check == "save": + with open(exact_check_file, "w") as f: + f.write(str(mean_dist)) + print(f"Saved exact check value to: {exact_check_file}") + sys.exit(0) + + elif args.exact_check == "verify": + if not os.path.exists(exact_check_file): + print(f"Error: No saved value found at {exact_check_file}") + print("Run with --exact_check save first to create baseline.") + sys.exit(1) + with open(exact_check_file, "r") as f: + expected = float(f.read().strip()) + if mean_dist != expected: + print(f"Exact check FAILED: expected {expected}, got {mean_dist}") + sys.exit(1) + print(f"Exact check PASSED: {mean_dist}") + sys.exit(0) + + else: + # Relaxed check for CI/cross-platform + if not math.isclose(mean_dist, args.expected, rel_tol=args.tolerance): + print(f"FAILED: mean distance {mean_dist} not close to {args.expected} (tolerance {args.tolerance})") + sys.exit(1) + print(f"PASSED: mean distance {mean_dist} is within tolerance of {args.expected}") + sys.exit(0) + + +if __name__ == "__main__": + main() From ba4a1936503d06e1256c1ddfbd5a1e9cba45acb3 
Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Fri, 9 Jan 2026 01:34:14 -0700 Subject: [PATCH 30/47] Add documentation and extended test infrastructure for DeepSSM - Add README.md with instructions for running tests and exact check mode - Add run_exact_check.sh to verify all quick test configurations - Add run_extended_tests.sh to run tests on a directory of projects - Add --baseline_file option to verify script for per-project baselines --- Testing/DeepSSMTests/README.md | 107 +++++++++++++ Testing/DeepSSMTests/run_exact_check.sh | 45 ++++++ Testing/DeepSSMTests/run_extended_tests.sh | 145 ++++++++++++++++++ .../DeepSSMTests/verify_deepssm_results.py | 4 +- 4 files changed, 300 insertions(+), 1 deletion(-) create mode 100644 Testing/DeepSSMTests/README.md create mode 100755 Testing/DeepSSMTests/run_exact_check.sh create mode 100755 Testing/DeepSSMTests/run_extended_tests.sh diff --git a/Testing/DeepSSMTests/README.md b/Testing/DeepSSMTests/README.md new file mode 100644 index 0000000000..e326fa8525 --- /dev/null +++ b/Testing/DeepSSMTests/README.md @@ -0,0 +1,107 @@ +# DeepSSM Tests + +Automated tests for DeepSSM using ShapeWorks project files (.swproj). + +## Test Configurations + +| Test | Description | +|------|-------------| +| `deepssm_default` | Standard DeepSSM (no TL-Net, no fine-tuning) | +| `deepssm_tl_net` | TL-DeepSSM network enabled | +| `deepssm_fine_tune` | Fine-tuning enabled | +| `deepssm_tl_net_fine_tune` | Both TL-DeepSSM and fine-tuning enabled | + +## Running Tests + +### Run all DeepSSM tests: +```bash +cd /path/to/build +ctest -R DeepSSMTests -V +``` + +### Run a specific test: +```bash +./bin/DeepSSMTests --gtest_filter="*default*" +./bin/DeepSSMTests --gtest_filter="*tl_net*" +``` + +### Run tests directly via shell scripts: +```bash +export DATA=/path/to/Testing/data +bash Testing/DeepSSMTests/deepssm_default.sh +``` + +## Test Data + +Test data is stored in `Testing/data/deepssm_test_data.zip` and automatically extracted on first run. Contains: +- 5 femur meshes, CT images, and constraint files +- Pre-configured project files for each test configuration + +## Result Verification + +Tests verify that the mean surface-to-surface distance is within tolerance. The default tolerance is loose (0-300) for quick 1-epoch tests. + +### Exact Check Mode (for refactoring verification) + +When refactoring DeepSSM code, you can verify results are identical before and after changes. + +**Run all configurations:** +```bash +# Save baselines (before refactoring) +bash Testing/DeepSSMTests/run_exact_check.sh save + +# Verify after refactoring +bash Testing/DeepSSMTests/run_exact_check.sh verify +``` + +**Run a single configuration:** +```bash +cd Testing/data/deepssm/projects +rm -rf deepssm groomed *_particles +shapeworks deepssm --name default.swproj --all + +# Save or verify +python Testing/DeepSSMTests/verify_deepssm_results.py . --exact_check save +python Testing/DeepSSMTests/verify_deepssm_results.py . --exact_check verify +``` + +Baseline values are saved to `exact_check_*.txt` in the project directory. + +**Note:** Exact check is platform-specific due to floating-point differences. Only compare results from the same machine. + +## Extended Tests (Manual) + +Extended tests run on a directory of projects for meaningful accuracy checks. These are not part of automated CI. + +### Directory Structure + +``` +/path/to/projects/ + project1/ + project1.swproj + femur/... + project2/ + project2.swproj + data/... 
+``` + +Each subdirectory should contain a `.swproj` file and its associated data. + +### Running Extended Tests + +```bash +# Run all projects with relaxed tolerance +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects + +# Save baselines for exact check +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects save + +# Verify against baselines +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects verify + +# Run specific project only +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects save femur +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects verify femur +``` + +Baseline values are saved to `exact_check_.txt` in each project directory. diff --git a/Testing/DeepSSMTests/run_exact_check.sh b/Testing/DeepSSMTests/run_exact_check.sh new file mode 100755 index 0000000000..d4bd7ad54c --- /dev/null +++ b/Testing/DeepSSMTests/run_exact_check.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Run exact check for all DeepSSM test configurations +# Usage: ./run_exact_check.sh save|verify + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${DATA:-$(dirname "$SCRIPT_DIR")/data}" + +if [ "$1" != "save" ] && [ "$1" != "verify" ]; then + echo "Usage: $0 save|verify" + echo " save - Save baseline values (run before refactoring)" + echo " verify - Verify against saved values (run after refactoring)" + exit 1 +fi + +MODE="$1" +CONFIGS="default tl_net fine_tune tl_net_fine_tune" + +# Unzip test data if not already extracted +if [ ! -d "${DATA_DIR}/deepssm" ]; then + unzip -q "${DATA_DIR}/deepssm_test_data.zip" -d "${DATA_DIR}/deepssm" +fi + +cd "${DATA_DIR}/deepssm/projects" + +for config in $CONFIGS; do + echo "========================================" + echo "Running $config..." + echo "========================================" + + rm -rf deepssm groomed *_particles + shapeworks deepssm --name ${config}.swproj --all + + # Run exact check with config-specific file + python "${SCRIPT_DIR}/verify_deepssm_results.py" . \ + --exact_check "$MODE" \ + --baseline_file "exact_check_${config}.txt" + + echo "" +done + +echo "========================================" +echo "All configurations: $MODE complete!" +echo "========================================" diff --git a/Testing/DeepSSMTests/run_extended_tests.sh b/Testing/DeepSSMTests/run_extended_tests.sh new file mode 100755 index 0000000000..1fa3f3afa7 --- /dev/null +++ b/Testing/DeepSSMTests/run_extended_tests.sh @@ -0,0 +1,145 @@ +#!/bin/bash +# Run extended DeepSSM tests on a directory of projects +# +# Usage: ./run_extended_tests.sh [save|verify|relaxed] [project] +# +# Arguments: +# base_dir - Directory containing project subdirectories +# mode - save: save baseline values +# verify: verify against saved baselines +# relaxed: run with loose tolerance (default) +# project - Optional: run only this project (default: all) +# +# Examples: +# ./run_extended_tests.sh /path/to/projects # Run all with relaxed check +# ./run_extended_tests.sh /path/to/projects save # Save baselines for all +# ./run_extended_tests.sh /path/to/projects verify # Verify all against baselines +# ./run_extended_tests.sh /path/to/projects save femur # Save baseline for femur only +# +# Directory structure: +# base_dir/ +# project1/ +# *.swproj +# femur/ (or other data) +# project2/ +# *.swproj +# ... 
+ +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +usage() { + echo "Usage: $0 [save|verify|relaxed] [project]" + echo "" + echo "Arguments:" + echo " base_dir - Directory containing project subdirectories" + echo " mode - save|verify|relaxed (default: relaxed)" + echo " project - Run only this project (default: all)" + echo "" + echo "Examples:" + echo " $0 /path/to/projects" + echo " $0 /path/to/projects save" + echo " $0 /path/to/projects verify" + echo " $0 /path/to/projects save femur" +} + +if [ $# -lt 1 ] || [ "$1" = "-h" ] || [ "$1" = "--help" ]; then + usage + exit 0 +fi + +BASE_DIR="$1" +MODE="${2:-relaxed}" +PROJECT="${3:-all}" + +if [ ! -d "$BASE_DIR" ]; then + echo "Error: Directory not found: $BASE_DIR" + exit 1 +fi + +if [ "$MODE" != "save" ] && [ "$MODE" != "verify" ] && [ "$MODE" != "relaxed" ]; then + echo "Error: Unknown mode: $MODE" + usage + exit 1 +fi + +run_project() { + local project_dir="$1" + local project_name="$(basename "$project_dir")" + + echo "========================================" + echo "Project: $project_name" + echo "========================================" + + # Find .swproj file + local swproj=$(find "$project_dir" -maxdepth 1 -name "*.swproj" | head -1) + if [ -z "$swproj" ]; then + echo "Warning: No .swproj file found in $project_dir, skipping" + return 0 + fi + + local swproj_name="$(basename "$swproj")" + echo "Using project file: $swproj_name" + + cd "$project_dir" + rm -rf deepssm groomed *_particles + + shapeworks deepssm --name "$swproj_name" --all + + # Verify results + local baseline_file="exact_check_${project_name}.txt" + local verify_args="" + + if [ "$MODE" = "save" ]; then + verify_args="--exact_check save --baseline_file $baseline_file" + elif [ "$MODE" = "verify" ]; then + verify_args="--exact_check verify --baseline_file $baseline_file" + else + verify_args="--expected 10 --tolerance 1.0" + fi + + python "${SCRIPT_DIR}/verify_deepssm_results.py" . $verify_args + + echo "" +} + +echo "Extended DeepSSM Tests" +echo "Base directory: ${BASE_DIR}" +echo "Mode: ${MODE}" +echo "" + +# Find all project directories (directories containing .swproj files) +ran_any=false +for project_dir in "$BASE_DIR"/*/; do + if [ ! -d "$project_dir" ]; then + continue + fi + + project_name="$(basename "$project_dir")" + + # Skip if specific project requested and this isn't it + if [ "$PROJECT" != "all" ] && [ "$PROJECT" != "$project_name" ]; then + continue + fi + + # Check if this directory has a .swproj file + if ls "$project_dir"/*.swproj 1>/dev/null 2>&1; then + run_project "$project_dir" + ran_any=true + fi +done + +if [ "$ran_any" = false ]; then + if [ "$PROJECT" = "all" ]; then + echo "Error: No projects found in $BASE_DIR" + echo "Each project should be a subdirectory containing a .swproj file." + else + echo "Error: Project not found: $PROJECT" + fi + exit 1 +fi + +echo "========================================" +echo "All projects complete!" 
+echo "========================================" diff --git a/Testing/DeepSSMTests/verify_deepssm_results.py b/Testing/DeepSSMTests/verify_deepssm_results.py index 4152f2f407..6375b4df27 100644 --- a/Testing/DeepSSMTests/verify_deepssm_results.py +++ b/Testing/DeepSSMTests/verify_deepssm_results.py @@ -43,6 +43,8 @@ def main(): help="Expected mean distance for relaxed check (default: 150.0)") parser.add_argument("--tolerance", type=float, default=1.0, help="Relative tolerance for relaxed check (default: 1.0 = 100%%)") + parser.add_argument("--baseline_file", type=str, default="exact_check_value.txt", + help="Filename for exact check baseline (default: exact_check_value.txt)") args = parser.parse_args() try: @@ -52,7 +54,7 @@ def main(): print(f"Error: {e}") sys.exit(1) - exact_check_file = os.path.join(args.project_dir, "exact_check_value.txt") + exact_check_file = os.path.join(args.project_dir, args.baseline_file) if args.exact_check == "save": with open(exact_check_file, "w") as f: From 22cab67aa60504abda32443bb06480bb671b4512 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Fri, 9 Jan 2026 12:01:09 -0700 Subject: [PATCH 31/47] Fix DeepSSM command arg parsing after return value fix --- Applications/shapeworks/Command.h | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Applications/shapeworks/Command.h b/Applications/shapeworks/Command.h index 8c6db366ef..a6e8b3a6a1 100644 --- a/Applications/shapeworks/Command.h +++ b/Applications/shapeworks/Command.h @@ -33,7 +33,7 @@ class Command { const std::string desc() const { return parser.description(); } /// parses the arguments for this command, saving them in the parser and returning the leftovers - std::vector parse_args(const std::vector &arguments); + virtual std::vector parse_args(const std::vector &arguments); /// calls execute for this command using the parsed args, returning system exit value int run(SharedCommandData &sharedData); @@ -108,6 +108,12 @@ class DeepSSMCommandGroup : public Command public: const std::string type() override { return "DeepSSM"; } + // DeepSSM is a terminal command - don't pass remaining args to other commands + std::vector parse_args(const std::vector &arguments) override { + Command::parse_args(arguments); + return {}; // return empty - DeepSSM consumes all args + } + private: }; From 3e05885d4aad31ca74f14b902c74eca4d5746687 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 14 Jan 2026 13:17:02 -0700 Subject: [PATCH 32/47] Fix toMesh pipeline and add empty mesh validation - Improve toMesh() pipeline in Image.cpp: add TriangleFilter to handle degenerate cells from vtkContourFilter, CleanPolyData to remove duplicates, and ConnectivityFilter to extract largest region - Add empty mesh validation in Groom after toMesh() - Add empty segmentation check before crop operation - Check both source and reference mesh in ICP transforms - Add validation in Mesh::extractLargestComponent() for empty/degenerate cells --- Libs/Groom/Groom.cpp | 15 +++++++++++++- Libs/Image/Image.cpp | 49 +++++++++++++++++++++++++++++++++++++++++--- Libs/Mesh/Mesh.cpp | 26 +++++++++++++++++++++++ 3 files changed, 86 insertions(+), 4 deletions(-) diff --git a/Libs/Groom/Groom.cpp b/Libs/Groom/Groom.cpp index 0b0f24ab9a..9b36a96ab7 100644 --- a/Libs/Groom/Groom.cpp +++ b/Libs/Groom/Groom.cpp @@ -186,7 +186,17 @@ bool Groom::image_pipeline(std::shared_ptr subject, size_t domain) { std::string groomed_name = get_output_filename(original, DomainType::Image); if (params.get_convert_to_mesh()) { + // Use isovalue 0.0 for 
distance transforms (the zero level set is the surface) Mesh mesh = image.toMesh(0.0); + if (mesh.numPoints() == 0) { + throw std::runtime_error("Empty mesh generated from segmentation - segmentation may have no valid data"); + } + // Check for valid cells + auto poly_data = mesh.getVTKMesh(); + if (poly_data->GetNumberOfCells() == 0) { + throw std::runtime_error("Mesh has no cells - segmentation may have no valid surface"); + } + SW_DEBUG("Mesh after toMesh: {} points, {} cells", poly_data->GetNumberOfPoints(), poly_data->GetNumberOfCells()); run_mesh_pipeline(mesh, params, original); groomed_name = get_output_filename(original, DomainType::Mesh); // save the groomed mesh @@ -239,6 +249,9 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) { // crop if (params.get_crop()) { PhysicalRegion region = image.physicalBoundingBox(0.5); + if (!region.valid()) { + throw std::runtime_error("Empty segmentation - no voxels found above threshold for cropping"); + } image.crop(region); increment_progress(); } @@ -1336,7 +1349,7 @@ std::vector> Groom::get_icp_transforms(const std::vectorIdentity(); Mesh source = meshes[i]; - if (source.getVTKMesh()->GetNumberOfPoints() != 0) { + if (source.getVTKMesh()->GetNumberOfPoints() != 0 && reference.getVTKMesh()->GetNumberOfPoints() != 0) { // create copies for thread safety auto poly_data1 = vtkSmartPointer::New(); poly_data1->DeepCopy(source.getVTKMesh()); diff --git a/Libs/Image/Image.cpp b/Libs/Image/Image.cpp index fc4788daa7..26937abdf5 100644 --- a/Libs/Image/Image.cpp +++ b/Libs/Image/Image.cpp @@ -32,10 +32,13 @@ #include #include #include +#include #include #include #include #include +#include +#include #include #include @@ -1019,7 +1022,40 @@ Mesh Image::toMesh(PixelType isoValue) const { targetContour->SetValue(0, isoValue); targetContour->Update(); - return Mesh(targetContour->GetOutput()); + auto contourOutput = targetContour->GetOutput(); + + // Use vtkTriangleFilter FIRST to convert all polygons to proper triangles + // This removes degenerate cells that can crash downstream filters + auto triangleFilter = vtkSmartPointer::New(); + triangleFilter->SetInputData(contourOutput); + triangleFilter->PassVertsOff(); + triangleFilter->PassLinesOff(); + triangleFilter->Update(); + + // Clean the mesh to remove degenerate points and merge duplicates + auto clean = vtkSmartPointer::New(); + clean->SetInputData(triangleFilter->GetOutput()); + clean->ConvertPolysToLinesOff(); + clean->ConvertLinesToPointsOff(); + clean->ConvertStripsToPolysOff(); + clean->PointMergingOn(); + clean->SetTolerance(0.0); + clean->Update(); + + // Check if we have any data to process + auto cleanOutput = clean->GetOutput(); + if (cleanOutput->GetNumberOfPoints() == 0 || cleanOutput->GetNumberOfCells() == 0) { + // Return empty mesh + return Mesh(cleanOutput); + } + + // Use connectivity filter to extract only connected surface regions + auto connectivity = vtkSmartPointer::New(); + connectivity->SetInputData(cleanOutput); + connectivity->SetExtractionModeToLargestRegion(); + connectivity->Update(); + + return Mesh(connectivity->GetOutput()); } Image::PixelType Image::evaluate(Point p) { @@ -1170,11 +1206,18 @@ TransformPtr Image::createRigidRegistrationTransform(const Image& target_dt, flo Mesh sourceContour = toMesh(isoValue); Mesh targetContour = target_dt.toMesh(isoValue); + // Check for empty meshes before attempting ICP + if (sourceContour.numPoints() == 0 || targetContour.numPoints() == 0) { + SW_WARN("Cannot create ICP transform: source has {} points, 
target has {} points", + sourceContour.numPoints(), targetContour.numPoints()); + return AffineTransform::New(); + } + try { auto mat = MeshUtils::createICPTransform(sourceContour, targetContour, Mesh::Rigid, iterations); return shapeworks::createTransform(ShapeWorksUtils::convert_matrix(mat), ShapeWorksUtils::get_offset(mat)); - } catch (std::invalid_argument) { - std::cerr << "failed to create ICP transform.\n"; + } catch (std::invalid_argument& e) { + std::cerr << "failed to create ICP transform: " << e.what() << "\n"; if (sourceContour.numPoints() == 0) { std::cerr << "\tspecified isoValue (" << isoValue << ") results in an empty mesh for source\n"; } diff --git a/Libs/Mesh/Mesh.cpp b/Libs/Mesh/Mesh.cpp index 42df05e6fb..6023bbab74 100644 --- a/Libs/Mesh/Mesh.cpp +++ b/Libs/Mesh/Mesh.cpp @@ -606,6 +606,24 @@ Mesh& Mesh::fixNonManifold() { } Mesh& Mesh::extractLargestComponent() { + // Check for valid cells before attempting connectivity filter + if (poly_data_->GetNumberOfCells() == 0) { + SW_WARN("extractLargestComponent: mesh has no cells"); + return *this; + } + + // Verify mesh has at least some valid cells + bool hasValidCells = false; + for (vtkIdType i = 0; i < poly_data_->GetNumberOfCells() && !hasValidCells; i++) { + if (poly_data_->GetCellType(i) != 0) { // VTK_EMPTY_CELL = 0 + hasValidCells = true; + } + } + if (!hasValidCells) { + SW_WARN("extractLargestComponent: mesh has no valid cells (all cells are type 0)"); + return *this; + } + auto connectivityFilter = vtkSmartPointer::New(); connectivityFilter->SetExtractionModeToLargestRegion(); connectivityFilter->SetInputData(poly_data_); @@ -1603,6 +1621,14 @@ bool Mesh::compare(const Mesh& other, const double eps) const { MeshTransform Mesh::createRegistrationTransform(const Mesh& target, Mesh::AlignmentType align, unsigned iterations) const { + // Check for empty meshes before attempting ICP + if (numPoints() == 0 || target.numPoints() == 0) { + SW_WARN("Cannot create registration transform: source has {} points, target has {} points", + numPoints(), target.numPoints()); + vtkSmartPointer identity = vtkSmartPointer::New(); + identity->Identity(); + return createMeshTransform(identity); + } const vtkSmartPointer mat( MeshUtils::createICPTransform(this->poly_data_, target.getVTKMesh(), align, iterations, true)); return createMeshTransform(mat); From 3951ffe526aa2d2b85eb961f3040c2b0a0530611 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 14 Jan 2026 13:02:30 -0700 Subject: [PATCH 33/47] Return identity transform for empty meshes in ICP When createICPTransform receives empty source or target meshes, return an identity transform with a warning instead of throwing an exception. This allows batch processing to continue gracefully when some shapes fail to generate valid meshes. 
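For batch drivers that want to report which shapes were skipped, the fallback can be detected after the call. A minimal caller-side sketch follows; it is illustrative only (the namespace and call are inferred from the signatures in this patch), and an identity matrix can also be a legitimate ICP result for already-aligned inputs, so the SW_WARN message remains the authoritative signal.

```cpp
#include <vtkMatrix4x4.h>
#include <vtkSmartPointer.h>

// Sketch: true if the 4x4 matrix is exactly the identity, which is what
// createICPTransform now returns when either input mesh is empty.
static bool is_identity(vtkMatrix4x4* m) {
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) {
      if (m->GetElement(r, c) != (r == c ? 1.0 : 0.0)) {
        return false;
      }
    }
  }
  return true;
}

// Hypothetical caller (names assumed from the patch context):
//   auto mat = shapeworks::MeshUtils::createICPTransform(source, target, Mesh::Rigid, 100, false);
//   if (is_identity(mat)) { /* record that this shape was skipped */ }
```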
--- Libs/Mesh/MeshUtils.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Libs/Mesh/MeshUtils.cpp b/Libs/Mesh/MeshUtils.cpp index fdf69958b1..1bbae90c9f 100644 --- a/Libs/Mesh/MeshUtils.cpp +++ b/Libs/Mesh/MeshUtils.cpp @@ -71,7 +71,11 @@ const vtkSmartPointer MeshUtils::createICPTransform(const Mesh sou Mesh::AlignmentType align, const unsigned iterations, bool meshTransform) { if (source.numPoints() == 0 || target.numPoints() == 0) { - throw std::invalid_argument("empty mesh passed to MeshUtils::createICPTransform"); + SW_WARN("Empty mesh in createICPTransform: source has {} points, target has {} points - returning identity", + source.numPoints(), target.numPoints()); + vtkSmartPointer identity = vtkSmartPointer::New(); + identity->Identity(); + return identity; } vtkSmartPointer icp = vtkSmartPointer::New(); From d8dba8858e31acae8b5c7d8757cadf097ec3e9ef Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 14 Jan 2026 13:09:10 -0700 Subject: [PATCH 34/47] Add streaming data loaders to reduce DeepSSM memory usage Instead of loading all images into memory when creating DataLoaders, use streaming datasets that load images on-demand during training. This significantly reduces memory usage for large datasets. Key changes: - DeepSSMdatasetStreaming class loads images lazily from disk - Training/validation/test loaders save metadata instead of full data - load_data_loader() reconstructs loaders from metadata - get_loader_info() extracts dimensions without loading full dataset - Backward compatible with legacy pre-loaded loaders --- .../DeepSSMUtilsPackage/DeepSSMUtils/eval.py | 2 +- .../DeepSSMUtils/loaders.py | 446 ++++++++++++++++-- .../DeepSSMUtilsPackage/DeepSSMUtils/model.py | 15 +- .../DeepSSMUtils/trainer.py | 8 +- 4 files changed, 420 insertions(+), 51 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py index a850d10bd1..ee64b568d8 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py @@ -36,7 +36,7 @@ def test(config_file, loader="test"): # load the loaders sw_message("Loading " + loader + " data loader...") - test_loader = torch.load(loader_dir + loader, weights_only=False) + test_loader = loaders.load_data_loader(loader_dir + loader, loader_type='test') # initialization sw_message("Loading trained model...") diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py index 7a6661e064..48391df834 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py @@ -6,12 +6,15 @@ import subprocess import torch from torch import nn -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset import shapeworks as sw from shapeworks.utils import sw_message from DeepSSMUtils import constants as C random.seed(1) +# Use streaming data loading to avoid loading all images into memory +USE_STREAMING = True + class DataLoadingError(Exception): """Raised when data loading fails.""" @@ -26,6 +29,83 @@ def make_dir(dirPath): if not os.path.exists(dirPath): os.makedirs(dirPath) + +''' +Load a DataLoader from a saved file. Handles both streaming (metadata) and legacy (full loader) formats. 
+''' +def load_data_loader(loader_path, loader_type='train'): + data = torch.load(loader_path, weights_only=False) + + # Check if it's streaming metadata or a full DataLoader + if isinstance(data, dict) and data.get('streaming', False): + # Reconstruct streaming DataLoader from metadata + if loader_type == 'train': + dataset = DeepSSMdatasetStreaming( + data['image_paths'], + data['scores'], + data['models'], + data['prefixes'], + data['mean_img'], + data['std_img'] + ) + return DataLoader( + dataset, + batch_size=data.get('batch_size', 1), + shuffle=True, + num_workers=data.get('num_workers', 0), + pin_memory=torch.cuda.is_available() + ) + else: + # Validation or test + dataset = DeepSSMdatasetStreaming( + data['image_paths'], + data['scores'], + data['models'], + data['names'], + data['mean_img'], + data['std_img'] + ) + return DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=data.get('num_workers', 0), + pin_memory=torch.cuda.is_available() + ) + else: + # Legacy format - data is already a DataLoader + return data + + +''' +Get dataset info (image dimensions, num_corr) from a loader file. +Works with both streaming and legacy formats. +''' +def get_loader_info(loader_path): + data = torch.load(loader_path, weights_only=False) + + if isinstance(data, dict) and data.get('streaming', False): + # Streaming format - load one image to get dimensions + image_path = data['image_paths'][0] + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + img_dims = img.shape + num_corr = len(data['models'][0]) + num_pca = len(data['scores'][0]) if data['scores'][0] != [1] else data.get('num_pca', 0) + return { + 'img_dims': img_dims, + 'num_corr': num_corr, + 'num_pca': num_pca, + 'streaming': True + } + else: + # Legacy format + return { + 'img_dims': data.dataset.img[0].shape[1:], + 'num_corr': data.dataset.mdl_target[0].shape[0], + 'num_pca': data.dataset.pca_target[0].shape[0], + 'streaming': False + } + ''' Reads csv and makes both train and validation data loaders from it ''' @@ -70,23 +150,66 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow ''' def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): sw_message("Creating training torch loader...") - # Get data make_dir(loader_dir) - images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir) - images, scores, models, prefixes = shuffle_data(images, scores, models, prefixes) - train_data = DeepSSMdataset(images, scores, models, prefixes) - # Save - trainloader = DataLoader( + + if USE_STREAMING: + # Streaming approach - don't load all images into memory + image_paths, scores, models, prefixes = get_all_train_data_streaming( + loader_dir, data_csv, down_factor, down_dir + ) + image_paths, scores, models, prefixes = shuffle_data(image_paths, scores, models, prefixes) + + # Load saved mean/std + mean_img = np.load(loader_dir + C.MEAN_IMG_FILE) + std_img = np.load(loader_dir + C.STD_IMG_FILE) + + train_data = DeepSSMdatasetStreaming( + list(image_paths), list(scores), list(models), list(prefixes), + float(mean_img), float(std_img) + ) + + # For streaming, we don't save the full DataLoader (it would try to pickle the dataset) + # Instead, save metadata that can be used to reconstruct the loader + trainloader = DataLoader( train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - train_path = loader_dir + C.TRAIN_LOADER - 
torch.save(trainloader, train_path) - sw_message("Training loader complete.") - return train_path + + # Save metadata for reconstruction + train_meta = { + 'image_paths': list(image_paths), + 'scores': list(scores), + 'models': list(models), + 'prefixes': list(prefixes), + 'mean_img': float(mean_img), + 'std_img': float(std_img), + 'batch_size': batch_size, + 'num_workers': num_workers, + 'streaming': True + } + train_path = loader_dir + C.TRAIN_LOADER + torch.save(train_meta, train_path) + sw_message("Training loader complete.") + return train_path + else: + # Legacy approach - load all into memory + images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir) + images, scores, models, prefixes = shuffle_data(images, scores, models, prefixes) + train_data = DeepSSMdataset(images, scores, models, prefixes) + trainloader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + train_path = loader_dir + C.TRAIN_LOADER + torch.save(trainloader, train_path) + sw_message("Training loader complete.") + return train_path ''' Makes validation data loader @@ -101,6 +224,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 raise DataLoadingError( f"Mismatched validation data: {len(val_img_list)} images but {len(val_particles)} particle files" ) + # Get data image_paths = [] scores = [] @@ -108,32 +232,67 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 names = [] for index in range(len(val_img_list)): image_path = val_img_list[index] - # add name prefix = get_prefix(image_path) names.append(prefix) image_paths.append(image_path) - scores.append([1]) # placeholder + scores.append([1]) # placeholder mdl = get_particles(val_particles[index]) models.append(mdl) - # Write test names to file so they are saved somewhere + + # Write validation names to file name_file = open(loader_dir + C.VALIDATION_NAMES_FILE, 'w+') name_file.write(str(names)) name_file.close() sw_message("Validation names saved to: " + loader_dir + C.VALIDATION_NAMES_FILE) - images = get_images(loader_dir, image_paths, down_factor, down_dir) - val_data = DeepSSMdataset(images, scores, models, names) - # Make loader - val_loader = DataLoader( + + if USE_STREAMING: + # Prepare image paths + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # Load mean/std from training (should already exist) + mean_img = float(np.load(loader_dir + C.MEAN_IMG_FILE)) + std_img = float(np.load(loader_dir + C.STD_IMG_FILE)) + + val_data = DeepSSMdatasetStreaming(image_paths, scores, models, names, mean_img, std_img) + + val_loader = DataLoader( val_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - val_path = loader_dir + C.VALIDATION_LOADER - torch.save(val_loader, val_path) - sw_message("Validation loader complete.") - return val_path + + # Save metadata + val_meta = { + 'image_paths': image_paths, + 'scores': scores, + 'models': models, + 'names': names, + 'mean_img': mean_img, + 'std_img': std_img, + 'num_workers': num_workers, + 'streaming': True + } + val_path = loader_dir + C.VALIDATION_LOADER + torch.save(val_meta, val_path) + sw_message("Validation loader complete.") + return val_path + else: + # Legacy approach + images = get_images(loader_dir, image_paths, down_factor, down_dir) + val_data = DeepSSMdataset(images, scores, models, names) + val_loader = DataLoader( + val_data, + batch_size=1, 
+ shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + val_path = loader_dir + C.VALIDATION_LOADER + torch.save(val_loader, val_path) + sw_message("Validation loader complete.") + return val_path ''' Makes test data loader @@ -142,44 +301,141 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num sw_message("Creating test torch loader...") if not test_img_list: raise DataLoadingError("Test image list is empty") - # get data + + # Get data image_paths = [] scores = [] models = [] test_names = [] for index in range(len(test_img_list)): image_path = test_img_list[index] - # add name prefix = get_prefix(image_path) test_names.append(prefix) image_paths.append(image_path) - # add label placeholders - scores.append([1]) - models.append([1]) - images = get_images(loader_dir, image_paths, down_factor, down_dir) - test_data = DeepSSMdataset(images, scores, models, test_names) - # Write test names to file so they are saved somewhere + scores.append([1]) # placeholder + models.append([1]) # placeholder + + # Write test names to file name_file = open(loader_dir + C.TEST_NAMES_FILE, 'w+') name_file.write(str(test_names)) name_file.close() sw_message("Test names saved to: " + loader_dir + C.TEST_NAMES_FILE) - # Make loader - testloader = DataLoader( + + if USE_STREAMING: + # Prepare image paths + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # Load mean/std from training + mean_img = float(np.load(loader_dir + C.MEAN_IMG_FILE)) + std_img = float(np.load(loader_dir + C.STD_IMG_FILE)) + + test_data = DeepSSMdatasetStreaming(image_paths, scores, models, test_names, mean_img, std_img) + + testloader = DataLoader( test_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - test_path = loader_dir + C.TEST_LOADER - torch.save(testloader, test_path) - sw_message("Test loader complete.") - return test_path, test_names + + # Save metadata + test_meta = { + 'image_paths': image_paths, + 'scores': scores, + 'models': models, + 'names': test_names, + 'mean_img': mean_img, + 'std_img': std_img, + 'num_workers': num_workers, + 'streaming': True + } + test_path = loader_dir + C.TEST_LOADER + torch.save(test_meta, test_path) + sw_message("Test loader complete.") + return test_path, test_names + else: + # Legacy approach + images = get_images(loader_dir, image_paths, down_factor, down_dir) + test_data = DeepSSMdataset(images, scores, models, test_names) + testloader = DataLoader( + test_data, + batch_size=1, + shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + test_path = loader_dir + C.TEST_LOADER + torch.save(testloader, test_path) + sw_message("Test loader complete.") + return test_path, test_names ################################ Helper functions ###################################### ''' -returns images, scores, models, prefixes from CSV +Returns image_paths, scores, models, prefixes from CSV for streaming. +Computes mean/std incrementally without loading all images. 
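+Raises DataLoadingError if the CSV is missing or empty, a row has fewer than three columns, or PCA scores are not numeric.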
+''' +def get_all_train_data_streaming(loader_dir, data_csv, down_factor, down_dir): + if not os.path.exists(data_csv): + raise DataLoadingError(f"CSV file not found: {data_csv}") + + image_paths = [] + scores = [] + models = [] + prefixes = [] + + try: + with open(data_csv, newline='') as csvfile: + datareader = csv.reader(csvfile) + for row_num, row in enumerate(datareader, 1): + if len(row) < 3: + raise DataLoadingError( + f"Invalid row {row_num} in {data_csv}: expected at least 3 columns " + f"(image_path, model_path, pca_scores), got {len(row)}" + ) + image_path = row[0] + model_path = row[1] + pca_scores = row[2:] + + prefix = get_prefix(image_path) + prefixes.append(prefix) + image_paths.append(image_path) + + try: + pca_scores = [float(i) for i in pca_scores] + except ValueError as e: + raise DataLoadingError( + f"Invalid PCA scores in {data_csv} at row {row_num}: {e}" + ) + scores.append(pca_scores) + + mdl = get_particles(model_path) + models.append(mdl) + except csv.Error as e: + raise DataLoadingError(f"Error parsing CSV file {data_csv}: {e}") + + if not image_paths: + raise DataLoadingError(f"CSV file is empty: {data_csv}") + + # Prepare image paths (apply downsampling if needed) + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # Compute mean/std incrementally + sw_message("Computing image statistics incrementally...") + mean_img, std_img = compute_image_stats_incremental(image_paths, down_factor, down_dir) + np.save(loader_dir + C.MEAN_IMG_FILE, mean_img) + np.save(loader_dir + C.STD_IMG_FILE, std_img) + sw_message(f"Image stats: mean={mean_img:.4f}, std={std_img:.4f}") + + # Whiten PCA scores + scores = whiten_PCA_scores(scores, loader_dir) + + return image_paths, scores, models, prefixes + + +''' +returns images, scores, models, prefixes from CSV (legacy - loads all into memory) ''' def get_all_train_data(loader_dir, data_csv, down_factor, down_dir): if not os.path.exists(data_csv): @@ -238,6 +494,7 @@ def shuffle_data(images, scores, models, prefixes): ''' Class for DeepSSM datasets that works with Pytorch DataLoader +Loads all images into memory upfront (legacy approach). ''' class DeepSSMdataset(): def __init__(self, img, pca_target, mdl_target, names): @@ -254,6 +511,40 @@ def __getitem__(self, index): def __len__(self): return len(self.img) + +''' +Streaming dataset that loads images on-demand to minimize memory usage. +Only keeps file paths in memory, loads each image when accessed. 
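+Images are normalized on the fly using the training mean/std passed to the constructor.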
+''' +class DeepSSMdatasetStreaming(Dataset): + def __init__(self, image_paths, pca_target, mdl_target, names, mean_img, std_img): + self.image_paths = image_paths + self.pca_target = torch.FloatTensor(np.array(pca_target)) + self.mdl_target = torch.FloatTensor(np.array(mdl_target)) + self.names = names + self.mean_img = mean_img + self.std_img = std_img + + def __getitem__(self, index): + # Load image on-demand + image_path = self.image_paths[index] + try: + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") + + # Normalize + img = (img - self.mean_img) / self.std_img + x = torch.FloatTensor(img).unsqueeze(0) # Add channel dimension + + y1 = self.pca_target[index] + y2 = self.mdl_target[index] + name = self.names[index] + return x, y1, y2, name + + def __len__(self): + return len(self.image_paths) + ''' returns sample prefix from path string ''' @@ -287,7 +578,86 @@ def get_particles(model_path): raise DataLoadingError(f"Error reading particle file {model_path}: {e}") ''' -reads .nrrd files and returns whitened data +Compute image mean and std incrementally without loading all images into memory. +Uses Welford's online algorithm for numerical stability. +''' +def compute_image_stats_incremental(image_list, down_factor=1, down_dir=None): + if not image_list: + raise DataLoadingError("Image list is empty") + + n = 0 + mean = 0.0 + M2 = 0.0 # Sum of squared differences from mean + + for i, image_path in enumerate(image_list): + # Handle downsampling + if down_dir is not None: + make_dir(down_dir) + img_name = os.path.basename(image_path) + res_img = os.path.join(down_dir, img_name) + if not os.path.exists(res_img): + apply_down_sample(image_path, res_img, down_factor) + image_path = res_img + + if not os.path.exists(image_path): + raise DataLoadingError(f"Image file not found: {image_path}") + + try: + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") + + # Welford's online algorithm for each pixel value + for val in img.flat: + n += 1 + delta = val - mean + mean += delta / n + delta2 = val - mean + M2 += delta * delta2 + + # Free memory + del img + + if (i + 1) % 10 == 0: + sw_message(f" Computing stats: {i + 1}/{len(image_list)} images processed") + + if n < 2: + raise DataLoadingError("Need at least 2 pixel values to compute statistics") + + variance = M2 / n + std = np.sqrt(variance) + + return mean, std + + +''' +Prepare image paths, applying downsampling if needed. +Returns list of paths to use (either original or downsampled). 
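+Raises DataLoadingError if an image file is missing after optional downsampling.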
+''' +def prepare_image_paths(image_list, down_factor=1, down_dir=None): + if not image_list: + raise DataLoadingError("Image list is empty") + + prepared_paths = [] + for image_path in image_list: + if down_dir is not None: + make_dir(down_dir) + img_name = os.path.basename(image_path) + res_img = os.path.join(down_dir, img_name) + if not os.path.exists(res_img): + apply_down_sample(image_path, res_img, down_factor) + image_path = res_img + + if not os.path.exists(image_path): + raise DataLoadingError(f"Image file not found: {image_path}") + + prepared_paths.append(image_path) + + return prepared_paths + + +''' +reads .nrrd files and returns whitened data (legacy - loads all into memory) ''' def get_images(loader_dir, image_list, down_factor, down_dir): if not image_list: diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py index 7d684ee62d..51d9514368 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py @@ -6,6 +6,7 @@ from collections import OrderedDict from DeepSSMUtils import net_utils from DeepSSMUtils import constants as C +from DeepSSMUtils import loaders class ConvolutionalBackbone(nn.Module): @@ -106,10 +107,9 @@ def __init__(self, config_file): parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + C.VALIDATION_LOADER, weights_only=False) - self.num_corr = loader.dataset.mdl_target[0].shape[0] - img_dims = loader.dataset.img[0].shape - self.img_dims = img_dims[1:] + loader_info = loaders.get_loader_info(self.loader_dir + C.VALIDATION_LOADER) + self.num_corr = loader_info['num_corr'] + self.img_dims = loader_info['img_dims'] # encoder if parameters['encoder']['deterministic']: self.encoder = DeterministicEncoder(self.num_latent, self.img_dims, self.loader_dir ) @@ -178,10 +178,9 @@ def __init__(self, conflict_file): parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + C.VALIDATION_LOADER, weights_only=False) - self.num_corr = loader.dataset.mdl_target[0].shape[0] - img_dims = loader.dataset.img[0].shape - self.img_dims = img_dims[1:] + loader_info = loaders.get_loader_info(self.loader_dir + C.VALIDATION_LOADER) + self.num_corr = loader_info['num_corr'] + self.img_dims = loader_info['img_dims'] self.CorrespondenceEncoder = CorrespondenceEncoder(self.num_latent, self.num_corr) self.CorrespondenceDecoder = CorrespondenceDecoder(self.num_latent, self.num_corr) self.ImageEncoder = DeterministicEncoder(self.num_latent, self.img_dims, self.loader_dir) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index 1dd9fcc575..0151710b81 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -108,8 +108,8 @@ def supervised_train(config_file): train_loader_path = loader_dir + C.TRAIN_LOADER validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") - train_loader = torch.load(train_loader_path, weights_only=False) - val_loader = torch.load(validation_loader_path, weights_only=False) + train_loader = loaders.load_data_loader(train_loader_path, loader_type='train') + val_loader = loaders.load_data_loader(validation_loader_path, loader_type='validation') print("Done.") # 
initializations num_pca = train_loader.dataset.pca_target[0].shape[0] @@ -418,8 +418,8 @@ def supervised_train_tl(config_file): train_loader_path = loader_dir + C.TRAIN_LOADER validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") - train_loader = torch.load(train_loader_path, weights_only=False) - val_loader = torch.load(validation_loader_path, weights_only=False) + train_loader = loaders.load_data_loader(train_loader_path, loader_type='train') + val_loader = loaders.load_data_loader(validation_loader_path, loader_type='validation') print("Done.") print("Defining model...") net = model.DeepSSMNet_TLNet(config_file) From 29f416524fd3bc018ee1b1cfca6ae72f8fc42260 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 14 Jan 2026 13:11:09 -0700 Subject: [PATCH 35/47] Fix bounding box calculation and add error handling in run_utils - Use world particle positions for bounding box calculation instead of transformed groomed meshes. World particles reflect actual aligned positions including optimization transforms. - Add periodic garbage collection during training image grooming - Add try/except around validation/test image registration to continue processing even if individual subjects fail - Skip missing validation/test images gracefully with warnings - Skip test subjects without predictions during post-processing --- .../DeepSSMUtils/run_utils.py | 242 +++++++++++------- 1 file changed, 155 insertions(+), 87 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index 723795d882..b969fe2d99 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -1,6 +1,7 @@ import random import math import os +import gc import numpy as np import json @@ -155,21 +156,33 @@ def get_training_indices(project): def get_training_bounding_box(project): - """ Get the bounding box of the training subjects. """ + """ Get the bounding box of the training subjects. + + Uses world particle positions to compute the bounding box. This ensures + consistency with the actual aligned particle positions used during training, + which may include additional transforms applied during optimization that + aren't captured by get_groomed_transforms() alone. 
+ """ subjects = project.get_subjects() training_indices = get_training_indices(project) - training_bounding_box = None - train_mesh_list = [] + + # Compute bounding box from world particles + min_pt = np.array([np.inf, np.inf, np.inf]) + max_pt = np.array([-np.inf, -np.inf, -np.inf]) + for i in training_indices: subject = subjects[i] - mesh = subject.get_groomed_clipped_mesh() - # apply transform - alignment = convert_transform_to_numpy(subject.get_groomed_transforms()[0]) - mesh.applyTransform(alignment) - train_mesh_list.append(mesh) + world_particle_files = subject.get_world_particle_filenames() + if world_particle_files: + particles = np.loadtxt(world_particle_files[0]) + min_pt = np.minimum(min_pt, particles.min(axis=0)) + max_pt = np.maximum(max_pt, particles.max(axis=0)) + + # Create bounding box from particle extents + # PhysicalRegion takes two sequences: min point and max point + bounding_box = sw.PhysicalRegion(min_pt.tolist(), max_pt.tolist()) - bounding_box = sw.MeshUtils.boundingBox(train_mesh_list).pad(10) - return bounding_box + return bounding_box.pad(10) def convert_transform_to_numpy(transform): @@ -229,14 +242,15 @@ def groom_training_images(project): f.write(bounding_box_string) sw_message("Grooming training images") - for i in get_training_indices(project): + training_indices = get_training_indices(project) + for count, i in enumerate(training_indices): if sw_check_abort(): sw_message("Aborted") return image_name = sw.utils.get_image_filename(subjects[i]) - sw_progress(i / (len(subjects) + 1), f"Grooming Training Image: {image_name}") + sw_progress(count / (len(training_indices) + 1), f"Grooming Training Image: {image_name}") image = sw.Image(image_name) subject = subjects[i] # get alignment transform @@ -257,6 +271,15 @@ def groom_training_images(project): # write image using the index of the subject image.write(deepssm_dir + f"/train_images/{i}.nrrd") + # Explicitly delete the image and run garbage collection periodically + # to prevent memory accumulation + del image + if count % 50 == 0: + gc.collect() + + # Final cleanup after processing all training images + gc.collect() + def run_data_augmentation(project, num_samples, num_dim, percent_variability, sampler, mixture_num=0, processes=1): """ Run data augmentation on the training images. """ @@ -362,86 +385,105 @@ def groom_val_test_images(project, indices): val_test_transforms = [] val_test_image_files = [] + failed_indices = [] - count = 1 - for i in val_test_indices: + for count, i in enumerate(val_test_indices): if sw_check_abort(): sw_message("Aborted") return image_name = sw.utils.get_image_filename(subjects[i]) sw_progress(count / (len(val_test_indices) + 1), - f"Grooming val/test image {image_name} ({count}/{len(val_test_indices)})") - count = count + 1 - image = sw.Image(image_name) + f"Grooming val/test image {image_name} ({count + 1}/{len(val_test_indices)})") + + try: + image = sw.Image(image_name) + + image_file = val_test_images_dir + f"{i}.nrrd" + + # check if this subject needs reflection + needs_reflection, axis = does_subject_need_reflection(project, subjects[i]) + + # 1. Apply reflection + reflection = np.eye(4) + if needs_reflection: + reflection[axis, axis] = -1 + # account for offset + reflection[-1][0] = 2 * image.center()[0] + + image.applyTransform(reflection) + transform = sw.utils.getVTKtransform(reflection) + + # 2. 
Translate to have ref center to make rigid registration easier + translation = ref_center - image.center() + image.setOrigin(image.origin() + translation).write(image_file) + transform[:3, -1] += translation + + # 3. Translate with respect to slightly cropped ref + image = sw.Image(image_file).fitRegion(large_bb).write(image_file) + itk_translation_transform = DeepSSMUtils.get_image_registration_transform(large_cropped_ref_image_file, + image_file, + transform_type='translation') + # 4. Apply transform + image.applyTransform(itk_translation_transform, + large_cropped_ref_image.origin(), large_cropped_ref_image.dims(), + large_cropped_ref_image.spacing(), large_cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + vtk_translation_transform = sw.utils.getVTKtransform(itk_translation_transform) + transform = np.matmul(vtk_translation_transform, transform) + + # 5. Crop with medium bounding box and find rigid transform + image.fitRegion(medium_bb).write(image_file) + itk_rigid_transform = DeepSSMUtils.get_image_registration_transform(medium_cropped_ref_image_file, + image_file, transform_type='rigid') + + # 6. Apply transform + image.applyTransform(itk_rigid_transform, + medium_cropped_ref_image.origin(), medium_cropped_ref_image.dims(), + medium_cropped_ref_image.spacing(), medium_cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + vtk_rigid_transform = sw.utils.getVTKtransform(itk_rigid_transform) + transform = np.matmul(vtk_rigid_transform, transform) + + # 7. Get similarity transform from image registration and apply + image.fitRegion(bounding_box).write(image_file) + itk_similarity_transform = DeepSSMUtils.get_image_registration_transform(cropped_ref_image_file, + image_file, + transform_type='similarity') + image.applyTransform(itk_similarity_transform, + cropped_ref_image.origin(), cropped_ref_image.dims(), + cropped_ref_image.spacing(), cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + image.write(image_file) + vtk_similarity_transform = sw.utils.getVTKtransform(itk_similarity_transform) + transform = np.matmul(vtk_similarity_transform, transform) + + # 8. Save transform + val_test_transforms.append(transform) + extra_values = subjects[i].get_extra_values() + extra_values["registration_transform"] = transform_to_string(transform) - image_file = val_test_images_dir + f"{i}.nrrd" - - # check if this subject needs reflection - needs_reflection, axis = does_subject_need_reflection(project, subjects[i]) - - # 1. Apply reflection - reflection = np.eye(4) - if needs_reflection: - reflection[axis, axis] = -1 - # account for offset - reflection[-1][0] = 2 * image.center()[0] - - image.applyTransform(reflection) - transform = sw.utils.getVTKtransform(reflection) - - # 2. Translate to have ref center to make rigid registration easier - translation = ref_center - image.center() - image.setOrigin(image.origin() + translation).write(image_file) - transform[:3, -1] += translation - - # 3. Translate with respect to slightly cropped ref - image = sw.Image(image_file).fitRegion(large_bb).write(image_file) - itk_translation_transform = DeepSSMUtils.get_image_registration_transform(large_cropped_ref_image_file, - image_file, - transform_type='translation') - # 4. 
Apply transform - image.applyTransform(itk_translation_transform, - large_cropped_ref_image.origin(), large_cropped_ref_image.dims(), - large_cropped_ref_image.spacing(), large_cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - vtk_translation_transform = sw.utils.getVTKtransform(itk_translation_transform) - transform = np.matmul(vtk_translation_transform, transform) - - # 5. Crop with medium bounding box and find rigid transform - image.fitRegion(medium_bb).write(image_file) - itk_rigid_transform = DeepSSMUtils.get_image_registration_transform(medium_cropped_ref_image_file, - image_file, transform_type='rigid') - - # 6. Apply transform - image.applyTransform(itk_rigid_transform, - medium_cropped_ref_image.origin(), medium_cropped_ref_image.dims(), - medium_cropped_ref_image.spacing(), medium_cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - vtk_rigid_transform = sw.utils.getVTKtransform(itk_rigid_transform) - transform = np.matmul(vtk_rigid_transform, transform) - - # 7. Get similarity transform from image registration and apply - image.fitRegion(bounding_box).write(image_file) - itk_similarity_transform = DeepSSMUtils.get_image_registration_transform(cropped_ref_image_file, - image_file, - transform_type='similarity') - image.applyTransform(itk_similarity_transform, - cropped_ref_image.origin(), cropped_ref_image.dims(), - cropped_ref_image.spacing(), cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - image.write(image_file) - vtk_similarity_transform = sw.utils.getVTKtransform(itk_similarity_transform) - transform = np.matmul(vtk_similarity_transform, transform) - - # 8. Save transform - val_test_transforms.append(transform) - extra_values = subjects[i].get_extra_values() - extra_values["registration_transform"] = transform_to_string(transform) + subjects[i].set_extra_values(extra_values) - subjects[i].set_extra_values(extra_values) + # Explicitly delete image and run garbage collection periodically + del image + except Exception as e: + sw_message(f"Warning: Failed to process val/test image for subject {i}: {e}") + failed_indices.append(i) + # Clean up partial file if it exists + if os.path.exists(val_test_images_dir + f"{i}.nrrd"): + os.remove(val_test_images_dir + f"{i}.nrrd") + + if count % 20 == 0: + gc.collect() + + # Final cleanup + gc.collect() project.set_subjects(subjects) + if failed_indices: + sw_message(f"Warning: {len(failed_indices)} val/test images failed to process: {failed_indices}") + def prepare_data_loaders(project, batch_size, split="all", num_workers=0): """ Prepare PyTorch laoders """ @@ -454,10 +496,17 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): val_image_files = [] val_world_particles = [] val_indices = get_split_indices(project, "val") + skipped_val = [] for i in val_indices: - val_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd") - particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] - val_world_particles.append(particle_file) + image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" + if os.path.exists(image_file): + val_image_files.append(image_file) + particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] + val_world_particles.append(particle_file) + else: + skipped_val.append(i) + if skipped_val: + sw_message(f"Warning: Skipping {len(skipped_val)} missing validation images: {skipped_val}") DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, 
val_world_particles, num_workers=num_workers) if split == "all" or split == "train": @@ -468,8 +517,15 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): if split == "all" or split == "test": test_image_files = [] test_indices = get_split_indices(project, "test") + skipped_test = [] for i in test_indices: - test_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd") + image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" + if os.path.exists(image_file): + test_image_files.append(image_file) + else: + skipped_test.append(i) + if skipped_test: + sw_message(f"Warning: Skipping {len(skipped_test)} missing test images: {skipped_test}") DeepSSMUtils.getTestLoader(loader_dir, test_image_files, num_workers=num_workers) @@ -508,16 +564,25 @@ def process_test_predictions(project, config_file): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] - test_indices = get_split_indices(project, "test") + all_test_indices = get_split_indices(project, "test") predicted_test_local_particles = [] predicted_test_world_particles = [] test_transforms = [] test_mesh_files = [] + test_indices = [] # Only indices with valid predictions + skipped_indices = [] - for index in test_indices: + for index in all_test_indices: world_particle_file = f"{world_predictions_dir}/{index}.particles" + + # Skip subjects that don't have predictions (e.g., failed during image grooming) + if not os.path.exists(world_particle_file): + skipped_indices.append(index) + continue + print(f"world_particle_file: {world_particle_file}") + test_indices.append(index) predicted_test_world_particles.append(world_particle_file) transform = get_test_alignment_transform(project, index) @@ -534,6 +599,9 @@ def process_test_predictions(project, config_file): np.savetxt(local_particle_file, local_particles) predicted_test_local_particles.append(local_particle_file) + if skipped_indices: + sw_message(f"Warning: Skipping {len(skipped_indices)} test subjects without predictions: {skipped_indices}") + distances = eval_utils.get_mesh_distances(predicted_test_local_particles, test_mesh_files, template_particles, template_mesh, pred_dir) From 22436d1f107d59d1516d92daee18682b4c697e81 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Wed, 4 Feb 2026 10:18:16 -0700 Subject: [PATCH 36/47] Fail with clear errors instead of silently skipping missing files --- .../DeepSSMUtils/run_utils.py | 185 ++++++++---------- 1 file changed, 79 insertions(+), 106 deletions(-) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index b969fe2d99..77d1834fa7 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -385,7 +385,6 @@ def groom_val_test_images(project, indices): val_test_transforms = [] val_test_image_files = [] - failed_indices = [] for count, i in enumerate(val_test_indices): if sw_check_abort(): @@ -396,83 +395,76 @@ def groom_val_test_images(project, indices): sw_progress(count / (len(val_test_indices) + 1), f"Grooming val/test image {image_name} ({count + 1}/{len(val_test_indices)})") - try: - image = sw.Image(image_name) - - image_file = val_test_images_dir + f"{i}.nrrd" - - # check if this subject needs reflection - needs_reflection, axis = does_subject_need_reflection(project, subjects[i]) - - # 1. 
Apply reflection - reflection = np.eye(4) - if needs_reflection: - reflection[axis, axis] = -1 - # account for offset - reflection[-1][0] = 2 * image.center()[0] - - image.applyTransform(reflection) - transform = sw.utils.getVTKtransform(reflection) - - # 2. Translate to have ref center to make rigid registration easier - translation = ref_center - image.center() - image.setOrigin(image.origin() + translation).write(image_file) - transform[:3, -1] += translation - - # 3. Translate with respect to slightly cropped ref - image = sw.Image(image_file).fitRegion(large_bb).write(image_file) - itk_translation_transform = DeepSSMUtils.get_image_registration_transform(large_cropped_ref_image_file, - image_file, - transform_type='translation') - # 4. Apply transform - image.applyTransform(itk_translation_transform, - large_cropped_ref_image.origin(), large_cropped_ref_image.dims(), - large_cropped_ref_image.spacing(), large_cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - vtk_translation_transform = sw.utils.getVTKtransform(itk_translation_transform) - transform = np.matmul(vtk_translation_transform, transform) - - # 5. Crop with medium bounding box and find rigid transform - image.fitRegion(medium_bb).write(image_file) - itk_rigid_transform = DeepSSMUtils.get_image_registration_transform(medium_cropped_ref_image_file, - image_file, transform_type='rigid') - - # 6. Apply transform - image.applyTransform(itk_rigid_transform, - medium_cropped_ref_image.origin(), medium_cropped_ref_image.dims(), - medium_cropped_ref_image.spacing(), medium_cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - vtk_rigid_transform = sw.utils.getVTKtransform(itk_rigid_transform) - transform = np.matmul(vtk_rigid_transform, transform) - - # 7. Get similarity transform from image registration and apply - image.fitRegion(bounding_box).write(image_file) - itk_similarity_transform = DeepSSMUtils.get_image_registration_transform(cropped_ref_image_file, - image_file, - transform_type='similarity') - image.applyTransform(itk_similarity_transform, - cropped_ref_image.origin(), cropped_ref_image.dims(), - cropped_ref_image.spacing(), cropped_ref_image.coordsys(), - sw.InterpolationType.Linear, meshTransform=False) - image.write(image_file) - vtk_similarity_transform = sw.utils.getVTKtransform(itk_similarity_transform) - transform = np.matmul(vtk_similarity_transform, transform) - - # 8. Save transform - val_test_transforms.append(transform) - extra_values = subjects[i].get_extra_values() - extra_values["registration_transform"] = transform_to_string(transform) + image = sw.Image(image_name) - subjects[i].set_extra_values(extra_values) + image_file = val_test_images_dir + f"{i}.nrrd" + + # check if this subject needs reflection + needs_reflection, axis = does_subject_need_reflection(project, subjects[i]) + + # 1. Apply reflection + reflection = np.eye(4) + if needs_reflection: + reflection[axis, axis] = -1 + # account for offset + reflection[-1][0] = 2 * image.center()[0] + + image.applyTransform(reflection) + transform = sw.utils.getVTKtransform(reflection) + + # 2. Translate to have ref center to make rigid registration easier + translation = ref_center - image.center() + image.setOrigin(image.origin() + translation).write(image_file) + transform[:3, -1] += translation + + # 3. 
Translate with respect to slightly cropped ref + image = sw.Image(image_file).fitRegion(large_bb).write(image_file) + itk_translation_transform = DeepSSMUtils.get_image_registration_transform(large_cropped_ref_image_file, + image_file, + transform_type='translation') + # 4. Apply transform + image.applyTransform(itk_translation_transform, + large_cropped_ref_image.origin(), large_cropped_ref_image.dims(), + large_cropped_ref_image.spacing(), large_cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + vtk_translation_transform = sw.utils.getVTKtransform(itk_translation_transform) + transform = np.matmul(vtk_translation_transform, transform) + + # 5. Crop with medium bounding box and find rigid transform + image.fitRegion(medium_bb).write(image_file) + itk_rigid_transform = DeepSSMUtils.get_image_registration_transform(medium_cropped_ref_image_file, + image_file, transform_type='rigid') + + # 6. Apply transform + image.applyTransform(itk_rigid_transform, + medium_cropped_ref_image.origin(), medium_cropped_ref_image.dims(), + medium_cropped_ref_image.spacing(), medium_cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + vtk_rigid_transform = sw.utils.getVTKtransform(itk_rigid_transform) + transform = np.matmul(vtk_rigid_transform, transform) + + # 7. Get similarity transform from image registration and apply + image.fitRegion(bounding_box).write(image_file) + itk_similarity_transform = DeepSSMUtils.get_image_registration_transform(cropped_ref_image_file, + image_file, + transform_type='similarity') + image.applyTransform(itk_similarity_transform, + cropped_ref_image.origin(), cropped_ref_image.dims(), + cropped_ref_image.spacing(), cropped_ref_image.coordsys(), + sw.InterpolationType.Linear, meshTransform=False) + image.write(image_file) + vtk_similarity_transform = sw.utils.getVTKtransform(itk_similarity_transform) + transform = np.matmul(vtk_similarity_transform, transform) + + # 8. 
Save transform + val_test_transforms.append(transform) + extra_values = subjects[i].get_extra_values() + extra_values["registration_transform"] = transform_to_string(transform) + + subjects[i].set_extra_values(extra_values) - # Explicitly delete image and run garbage collection periodically - del image - except Exception as e: - sw_message(f"Warning: Failed to process val/test image for subject {i}: {e}") - failed_indices.append(i) - # Clean up partial file if it exists - if os.path.exists(val_test_images_dir + f"{i}.nrrd"): - os.remove(val_test_images_dir + f"{i}.nrrd") + # Explicitly delete image and run garbage collection periodically + del image if count % 20 == 0: gc.collect() @@ -481,9 +473,6 @@ def groom_val_test_images(project, indices): gc.collect() project.set_subjects(subjects) - if failed_indices: - sw_message(f"Warning: {len(failed_indices)} val/test images failed to process: {failed_indices}") - def prepare_data_loaders(project, batch_size, split="all", num_workers=0): """ Prepare PyTorch laoders """ @@ -496,17 +485,13 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): val_image_files = [] val_world_particles = [] val_indices = get_split_indices(project, "val") - skipped_val = [] for i in val_indices: image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" - if os.path.exists(image_file): - val_image_files.append(image_file) - particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] - val_world_particles.append(particle_file) - else: - skipped_val.append(i) - if skipped_val: - sw_message(f"Warning: Skipping {len(skipped_val)} missing validation images: {skipped_val}") + if not os.path.exists(image_file): + raise FileNotFoundError(f"Missing validation image for subject {i}: {image_file}") + val_image_files.append(image_file) + particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] + val_world_particles.append(particle_file) DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles, num_workers=num_workers) if split == "all" or split == "train": @@ -517,15 +502,11 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): if split == "all" or split == "test": test_image_files = [] test_indices = get_split_indices(project, "test") - skipped_test = [] for i in test_indices: image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" - if os.path.exists(image_file): - test_image_files.append(image_file) - else: - skipped_test.append(i) - if skipped_test: - sw_message(f"Warning: Skipping {len(skipped_test)} missing test images: {skipped_test}") + if not os.path.exists(image_file): + raise FileNotFoundError(f"Missing test image for subject {i}: {image_file}") + test_image_files.append(image_file) DeepSSMUtils.getTestLoader(loader_dir, test_image_files, num_workers=num_workers) @@ -564,25 +545,20 @@ def process_test_predictions(project, config_file): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] - all_test_indices = get_split_indices(project, "test") + test_indices = get_split_indices(project, "test") predicted_test_local_particles = [] predicted_test_world_particles = [] test_transforms = [] test_mesh_files = [] - test_indices = [] # Only indices with valid predictions - skipped_indices = [] - for index in all_test_indices: + for index in test_indices: world_particle_file = f"{world_predictions_dir}/{index}.particles" - # Skip subjects 
that don't have predictions (e.g., failed during image grooming)
         if not os.path.exists(world_particle_file):
-            skipped_indices.append(index)
-            continue
+            raise FileNotFoundError(f"Missing prediction for test subject {index}: {world_particle_file}")
 
         print(f"world_particle_file: {world_particle_file}")
-        test_indices.append(index)
         predicted_test_world_particles.append(world_particle_file)
 
         transform = get_test_alignment_transform(project, index)
@@ -599,9 +575,6 @@ def process_test_predictions(project, config_file):
         np.savetxt(local_particle_file, local_particles)
         predicted_test_local_particles.append(local_particle_file)
 
-    if skipped_indices:
-        sw_message(f"Warning: Skipping {len(skipped_indices)} test subjects without predictions: {skipped_indices}")
-
     distances = eval_utils.get_mesh_distances(predicted_test_local_particles, test_mesh_files, template_particles, template_mesh, pred_dir)

From 3660dd5a0c854254365a5352086d0181e231e510 Mon Sep 17 00:00:00 2001
From: Alan Morris
Date: Thu, 5 Feb 2026 12:16:08 -0700
Subject: [PATCH 37/47] Reduce DeepSSM tests from 4 to 2 configurations

Run only default and tl_net_fine_tune tests, which together cover all
code paths (standard DeepSSM, TL-DeepSSM, and fine tuning). Cuts test
time from ~3 minutes to ~90 seconds.
---
 Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py | 6 +++---
 Testing/DeepSSMTests/DeepSSMTests.cpp                  | 7 +++----
 Testing/DeepSSMTests/deepssm_default.sh                | 2 +-
 Testing/DeepSSMTests/deepssm_fine_tune.sh              | 2 +-
 Testing/DeepSSMTests/deepssm_tl_net.sh                 | 2 +-
 Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh       | 2 +-
 Testing/DeepSSMTests/run_exact_check.sh                | 2 +-
 Testing/DeepSSMTests/run_extended_tests.sh             | 2 +-
 8 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py
index 86e12fc03f..638158a577 100644
--- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py
+++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py
@@ -2,15 +2,15 @@
 import SimpleITK
 import numpy as np
 
-def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid'):
-    # Prepare parameter map
+def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid', max_iterations=1024):
+    # Prepare parameter map
     parameter_object = itk.ParameterObject.New()
     parameter_map = parameter_object.GetDefaultParameterMap('rigid')
     if transform_type == 'similarity':
        parameter_map['Transform'] = ['SimilarityTransform']
    elif transform_type == 'translation':
        parameter_map['Transform'] = ['TranslationTransform']
-    parameter_map['MaximumNumberOfIterations'] = ['1024']
+    parameter_map['MaximumNumberOfIterations'] = [str(max_iterations)]
     parameter_object.AddParameterMap(parameter_map)
 
     # Load images
diff --git a/Testing/DeepSSMTests/DeepSSMTests.cpp b/Testing/DeepSSMTests/DeepSSMTests.cpp
index 05f12e8299..6783e325b9 100644
--- a/Testing/DeepSSMTests/DeepSSMTests.cpp
+++ b/Testing/DeepSSMTests/DeepSSMTests.cpp
@@ -11,10 +11,9 @@ void run_deepssm_test(const std::string& name) {
 }
 
 //---------------------------------------------------------------------------
+// Run 2 configurations that cover all code paths:
+// - default: standard DeepSSM
+// - tl_net_fine_tune: TL-DeepSSM with fine tuning (covers both tl_net and fine_tune paths)
 TEST(DeepSSMTests, defaultTest) { run_deepssm_test("deepssm_default.sh"); }
 
-TEST(DeepSSMTests, tlNetTest) { run_deepssm_test("deepssm_tl_net.sh"); }
-
-TEST(DeepSSMTests, fineTuneTest) {
run_deepssm_test("deepssm_fine_tune.sh"); } - TEST(DeepSSMTests, tlNetFineTuneTest) { run_deepssm_test("deepssm_tl_net_fine_tune.sh"); } diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh index c8a7305829..fcdac3e31c 100755 --- a/Testing/DeepSSMTests/deepssm_default.sh +++ b/Testing/DeepSSMTests/deepssm_default.sh @@ -14,4 +14,4 @@ rm -rf deepssm groomed *_particles shapeworks deepssm --name default.swproj --all # Verify results -python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh index 5b991a3f84..c0e96b800a 100755 --- a/Testing/DeepSSMTests/deepssm_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh @@ -14,4 +14,4 @@ rm -rf deepssm groomed *_particles shapeworks deepssm --name fine_tune.swproj --all # Verify results -python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh index f246158782..2ed22c47c1 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net.sh @@ -14,4 +14,4 @@ rm -rf deepssm groomed *_particles shapeworks deepssm --name tl_net.swproj --all # Verify results -python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh index 70ea18f1f8..9a2d154e1a 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh @@ -14,4 +14,4 @@ rm -rf deepssm groomed *_particles shapeworks deepssm --name tl_net_fine_tune.swproj --all # Verify results -python "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/run_exact_check.sh b/Testing/DeepSSMTests/run_exact_check.sh index d4bd7ad54c..e31cb61697 100755 --- a/Testing/DeepSSMTests/run_exact_check.sh +++ b/Testing/DeepSSMTests/run_exact_check.sh @@ -33,7 +33,7 @@ for config in $CONFIGS; do shapeworks deepssm --name ${config}.swproj --all # Run exact check with config-specific file - python "${SCRIPT_DIR}/verify_deepssm_results.py" . \ + python3 "${SCRIPT_DIR}/verify_deepssm_results.py" . \ --exact_check "$MODE" \ --baseline_file "exact_check_${config}.txt" diff --git a/Testing/DeepSSMTests/run_extended_tests.sh b/Testing/DeepSSMTests/run_extended_tests.sh index 1fa3f3afa7..46e96e6e6c 100755 --- a/Testing/DeepSSMTests/run_extended_tests.sh +++ b/Testing/DeepSSMTests/run_extended_tests.sh @@ -99,7 +99,7 @@ run_project() { verify_args="--expected 10 --tolerance 1.0" fi - python "${SCRIPT_DIR}/verify_deepssm_results.py" . $verify_args + python3 "${SCRIPT_DIR}/verify_deepssm_results.py" . 
$verify_args echo "" } From 122599528126f09ea2a9375b201cca154c4a354f Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Fri, 6 Feb 2026 09:35:51 -0700 Subject: [PATCH 38/47] Update baselines --- Testing/data/femur1_to_2_icp.nrrd | 4 ++-- Testing/data/femur2_to_1_icp.nrrd | 4 ++-- Testing/data/la-bin.vtk | 4 ++-- Testing/data/reconstruct_mean_surface.vtk | 4 ++-- Testing/data/transforms/meshTransformWithImageTransform.vtk | 4 ++-- .../data/transforms/meshTransformWithoutImageTransform.vtk | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Testing/data/femur1_to_2_icp.nrrd b/Testing/data/femur1_to_2_icp.nrrd index a351d7392a..0321448e69 100644 --- a/Testing/data/femur1_to_2_icp.nrrd +++ b/Testing/data/femur1_to_2_icp.nrrd @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b42fb8d7611fbef0b0b3c988505059d5ce4c681b3b0e44d7115bbd96eb8ca8d1 -size 748054 +oid sha256:60ecd61c1b944ff72936c31b0c1550e5b8db9c90bf790e4068b06f92b2d643a4 +size 755252 diff --git a/Testing/data/femur2_to_1_icp.nrrd b/Testing/data/femur2_to_1_icp.nrrd index 4bd8b1a5ce..19b85103ed 100644 --- a/Testing/data/femur2_to_1_icp.nrrd +++ b/Testing/data/femur2_to_1_icp.nrrd @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6757b04eea0cde376666826beeec0328dbd71f47e57cfe0bab6551d29e3f9982 -size 758746 +oid sha256:cf585f5a63f2567caef7de3dc73a19000db835d783752d31b9e598a9c1af2692 +size 752034 diff --git a/Testing/data/la-bin.vtk b/Testing/data/la-bin.vtk index 97a4f88312..58efdb6cb2 100644 --- a/Testing/data/la-bin.vtk +++ b/Testing/data/la-bin.vtk @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:78233434f7745536768b8ff689dcd28f6e2bc4cf7d77fe6a197c1e419601bd3a -size 2196097 +oid sha256:fc7cfe8d712e7a531ca11c1f9cda50f5b4766f81eb38120d6451623e04b9d20d +size 2872943 diff --git a/Testing/data/reconstruct_mean_surface.vtk b/Testing/data/reconstruct_mean_surface.vtk index 2961524227..98f0d7c80b 100644 --- a/Testing/data/reconstruct_mean_surface.vtk +++ b/Testing/data/reconstruct_mean_surface.vtk @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:50faf83cc9d39e44a1932c3f1de5b8392e30b4715512fbd1492a6eecaa9f43cf -size 432903 +oid sha256:40daf835c37e0a1a0bef01ffe466470d4804f3e146312bf1e39f1532c6219e4b +size 432959 diff --git a/Testing/data/transforms/meshTransformWithImageTransform.vtk b/Testing/data/transforms/meshTransformWithImageTransform.vtk index 8a92a9895c..361e63ae92 100644 --- a/Testing/data/transforms/meshTransformWithImageTransform.vtk +++ b/Testing/data/transforms/meshTransformWithImageTransform.vtk @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2f7144cdb9d60dee1da5ef387d281f12e307d90759d572a5d4db3558517d1d9b -size 29754 +oid sha256:0edb7df0897be426d80259a8b72984c95ce6e960bde8199b64b1db7287dd4e7c +size 43445 diff --git a/Testing/data/transforms/meshTransformWithoutImageTransform.vtk b/Testing/data/transforms/meshTransformWithoutImageTransform.vtk index 987a0c0dd9..384b55c73b 100644 --- a/Testing/data/transforms/meshTransformWithoutImageTransform.vtk +++ b/Testing/data/transforms/meshTransformWithoutImageTransform.vtk @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:93f78c2f7b4a25547df75035fb6e035509b401406b9255430bb5c35deb658ddb -size 29754 +oid sha256:ea2d951a91a557d0d96a78bab3ed606a034f3e641e074ff89468e374e5bc8ad4 +size 43443 From 70cd306fd52816409c99a7ef53b27d6b82ed8930 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Fri, 6 Feb 2026 10:11:18 -0700 Subject: [PATCH 39/47] Fix tests --- 
 Testing/MeshTests/MeshTests.cpp | 2 +-
 Testing/data/smoothsinc.vtp     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/Testing/MeshTests/MeshTests.cpp b/Testing/MeshTests/MeshTests.cpp
index 787a65ebd7..d7d343e090 100644
--- a/Testing/MeshTests/MeshTests.cpp
+++ b/Testing/MeshTests/MeshTests.cpp
@@ -633,7 +633,7 @@ TEST(MeshTests, fieldTest2) {
   Mesh mesh(std::string(TEST_DATA_DIR) + "/la-bin.vtk");
   double a = mesh.getFieldValue("scalars", 0);
   double b = mesh.getFieldValue("scalars", 1000);
-  double c = mesh.getFieldValue("Normals", 4231);
+  double c = mesh.getFieldValue("Normals", 12);
   double d = mesh.getFieldValue("Normals", 5634);
 
   ASSERT_TRUE(a == 1);
diff --git a/Testing/data/smoothsinc.vtp b/Testing/data/smoothsinc.vtp
index b4d36516c9..306e3fd036 100644
--- a/Testing/data/smoothsinc.vtp
+++ b/Testing/data/smoothsinc.vtp
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:11dff539015f11fc545839df9eb89ba61d3e3ee531638a0de4e1307d0244abd4
-size 7873251
+oid sha256:b16526bf253c9ff696d888d0c14d4917758afdaa1d662f554ba05a9604c27414
+size 7873240

From ff77d8620f5d40246e728aa994cf3c9075760ed3 Mon Sep 17 00:00:00 2001
From: Alan Morris
Date: Fri, 6 Feb 2026 12:03:34 -0700
Subject: [PATCH 40/47] Fix SW_MAJOR_VERSION
---
 devenv.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/devenv.sh b/devenv.sh
index f4cc5b9177..30b4a9dc96 100644
--- a/devenv.sh
+++ b/devenv.sh
@@ -11,7 +11,7 @@
 # compiled portion of the Python bindings).
 #
 
-SW_MAJOR_VERSION=6.6
+SW_MAJOR_VERSION=6.7
 
 (return 0 2>/dev/null) && sourced=1 || sourced=0

From 53205b7f83c7949fa0a64b11ad53988a6aed1ca0 Mon Sep 17 00:00:00 2001
From: Alan Morris
Date: Fri, 6 Feb 2026 14:05:06 -0700
Subject: [PATCH 41/47] Fixes for deepssm tests
---
 Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py | 12 +++++++-----
 Testing/UseCaseTests/UseCaseTests.cpp                |  2 +-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py
index 77d1834fa7..7d9f9c6677 100644
--- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py
+++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py
@@ -481,6 +481,13 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0):
     if not os.path.exists(loader_dir):
         os.makedirs(loader_dir)
 
+    # Train must run first: it computes and saves mean_img.npy/std_img.npy
+    # which are required by validation and test loaders.
+    if split == "all" or split == "train":
+        aug_dir = deepssm_dir + "augmentation/"
+        aug_data_csv = aug_dir + "TotalData.csv"
+        DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size, num_workers=num_workers)
+
     if split == "all" or split == "val":
         val_image_files = []
         val_world_particles = []
@@ -494,11 +501,6 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0):
             val_world_particles.append(particle_file)
         DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles, num_workers=num_workers)
 
-    if split == "all" or split == "train":
-        aug_dir = deepssm_dir + "augmentation/"
-        aug_data_csv = aug_dir + "TotalData.csv"
-        DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size, num_workers=num_workers)
-
     if split == "all" or split == "test":
         test_image_files = []
         test_indices = get_split_indices(project, "test")
diff --git a/Testing/UseCaseTests/UseCaseTests.cpp b/Testing/UseCaseTests/UseCaseTests.cpp
index 4a4270300c..d9b890b107 100644
--- a/Testing/UseCaseTests/UseCaseTests.cpp
+++ b/Testing/UseCaseTests/UseCaseTests.cpp
@@ -13,7 +13,7 @@ void run_test(const std::string& name) {
   std::remove(outputname.c_str());
 
   // run python
-  std::string command = "python RunUseCase.py " + name + " --tiny_test 1>" + outputname + " 2>&1";
+  std::string command = "python RunUseCase.py " + name + " --tiny_test --clean 1>" + outputname + " 2>&1";
   // use the below instead of there is some problem in getting the output
   // std::string command = "python RunUseCase.py " + name + " --tiny_test";
   std::cerr << "Running command: " << command << "\n";

From 16add774a926a1278c7dd696adefd047e75e0823 Mon Sep 17 00:00:00 2001
From: Alan Morris
Date: Fri, 6 Feb 2026 15:31:30 -0700
Subject: [PATCH 42/47] Fix CI tests
---
 Testing/UseCaseTests/UseCaseTests.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/Testing/UseCaseTests/UseCaseTests.cpp b/Testing/UseCaseTests/UseCaseTests.cpp
index d9b890b107..ca8ee646dc 100644
--- a/Testing/UseCaseTests/UseCaseTests.cpp
+++ b/Testing/UseCaseTests/UseCaseTests.cpp
@@ -13,7 +13,11 @@ void run_test(const std::string& name) {
   std::remove(outputname.c_str());
 
   // run python
-  std::string command = "python RunUseCase.py " + name + " --tiny_test --clean 1>" + outputname + " 2>&1";
+  // Remove status files so all steps re-run from scratch.
+  // Don't use --clean as it deletes pre-downloaded test data.
+  boost::filesystem::remove_all("Output/" + name + "/status");
+  boost::filesystem::remove_all("Output/" + name + "/tiny_test_status");
+  std::string command = "python RunUseCase.py " + name + " --tiny_test 1>" + outputname + " 2>&1";
   // use the below instead of there is some problem in getting the output
   // std::string command = "python RunUseCase.py " + name + " --tiny_test";
   std::cerr << "Running command: " << command << "\n";

From a18ee3a5535204916fce0aaa66470250d0e2e810 Mon Sep 17 00:00:00 2001
From: Alan Morris
Date: Fri, 6 Feb 2026 22:25:02 -0700
Subject: [PATCH 43/47] Set OMP_NUM_THREADS=1 for windows CI deepssm
---
 Examples/Python/deep_ssm.py                      | 4 ++--
 Testing/DeepSSMTests/deepssm_default.sh          | 3 +++
 Testing/DeepSSMTests/deepssm_fine_tune.sh        | 3 +++
 Testing/DeepSSMTests/deepssm_tl_net.sh           | 3 +++
 Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh | 3 +++
 5 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py
index 3e696997b4..e3508b9673 100644
--- a/Examples/Python/deep_ssm.py
+++ b/Examples/Python/deep_ssm.py
@@ -64,8 +64,8 @@ def Run_Pipeline(args):
     This data is comprised of femur meshes and corresponding hip CT scans.
     """
 
-    if platform.system() == "Darwin":
-        # On MacOS, CPU PyTorch is hanging with parallel
+    if platform.system() != "Linux":
+        # CPU PyTorch hangs with OpenMP parallelism on macOS and Windows
         os.environ['OMP_NUM_THREADS'] = "1"
     # If running a tiny_test, then download subset of the data
     if args.tiny_test:
diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh
index fcdac3e31c..377936631b 100755
--- a/Testing/DeepSSMTests/deepssm_default.sh
+++ b/Testing/DeepSSMTests/deepssm_default.sh
@@ -2,6 +2,9 @@
 # Test DeepSSM with default settings (no tl_net, no fine_tune)
 set -e
 
+# Prevent PyTorch/OpenMP deadlock on macOS and Windows
+export OMP_NUM_THREADS=1
+
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 
 # Unzip test data if not already extracted
diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh
index c0e96b800a..ef8771f50a 100755
--- a/Testing/DeepSSMTests/deepssm_fine_tune.sh
+++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh
@@ -2,6 +2,9 @@
 # Test DeepSSM with fine tuning enabled
 set -e
 
+# Prevent PyTorch/OpenMP deadlock on macOS and Windows
+export OMP_NUM_THREADS=1
+
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 
 # Unzip test data if not already extracted
diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh
index 2ed22c47c1..b57dc3b36a 100755
--- a/Testing/DeepSSMTests/deepssm_tl_net.sh
+++ b/Testing/DeepSSMTests/deepssm_tl_net.sh
@@ -2,6 +2,9 @@
 # Test DeepSSM with TL-DeepSSM network enabled
 set -e
 
+# Prevent PyTorch/OpenMP deadlock on macOS and Windows
+export OMP_NUM_THREADS=1
+
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 
 # Unzip test data if not already extracted
diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh
index 9a2d154e1a..b47b1b2e41 100755
--- a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh
+++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh
@@ -2,6 +2,9 @@
 # Test DeepSSM with both TL-DeepSSM and fine tuning enabled
 set -e
 
+# Prevent PyTorch/OpenMP deadlock on macOS and Windows
+export OMP_NUM_THREADS=1
+
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 
 # Unzip test data if not already extracted
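The same OpenMP guard can be applied when driving the DeepSSM Python pipeline directly, mirroring the deep_ssm.py change above. A minimal sketch; the key point is setting the variable before torch is imported anywhere in the process:

    import os
    import platform

    # CPU PyTorch can deadlock with OpenMP parallelism on macOS and Windows,
    # so pin OpenMP to a single thread before importing torch.
    if platform.system() != "Linux":
        os.environ["OMP_NUM_THREADS"] = "1"

    import torch  # imported only after the environment variable is set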
From 5334d33fd44999cbc3355b5131f79c2f319fd477 Mon Sep 17 00:00:00 2001
From: Alan Morris
Date: Sat, 7 Feb 2026 08:08:32 -0700
Subject: [PATCH 44/47] Add --aug_processes option to avoid multiprocessing hang on Windows CI
---
 Applications/shapeworks/Commands.cpp             |  9 +++++++++
 Libs/Application/DeepSSM/DeepSSMJob.cpp          | 11 +++++++++--
 Libs/Application/DeepSSM/DeepSSMJob.h            |  4 ++++
 Libs/Application/Job/PythonWorker.cpp            |  5 ++---
 Testing/DeepSSMTests/deepssm_default.sh          |  2 +-
 Testing/DeepSSMTests/deepssm_fine_tune.sh        |  2 +-
 Testing/DeepSSMTests/deepssm_tl_net.sh           |  2 +-
 Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh |  2 +-
 8 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp
index e161d07256..c3c08df0df 100644
--- a/Applications/shapeworks/Commands.cpp
+++ b/Applications/shapeworks/Commands.cpp
@@ -371,6 +371,12 @@ void DeepSSMCommand::buildParser() {
       .set_default(0)
       .help("Number of data loader workers (default: 0)");
 
+  parser.add_option("--aug_processes")
+      .action("store")
+      .type("int")
+      .set_default(0)
+      .help("Number of augmentation processes (default: 0 = use all cores)");
+
   Command::buildParser();
 }
 
@@ -413,12 +419,14 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
   bool do_test = options.is_set("test") || options.is_set("all");
 
   int num_workers = static_cast<int>(options.get("num_workers"));
+  int aug_processes = static_cast<int>(options.get("aug_processes"));
 
   std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n";
   std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n";
   std::cout << "Train step: " << (do_train ? "on" : "off") << "\n";
   std::cout << "Test step: " << (do_test ? "on" : "off") << "\n";
   std::cout << "Num dataloader workers: " << num_workers << "\n";
+  std::cout << "Augmentation processes: " << (aug_processes == 0 ? QThread::idealThreadCount() : aug_processes) << "\n";
 
   if (!do_prep && !do_augment && !do_train && !do_test) {
     do_prep = true;
@@ -472,6 +480,7 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
   if (do_augment) {
     std::cout << "Running DeepSSM data augmentation...\n";
     auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType);
+    job->set_aug_processes(aug_processes);
     python_worker.run_job(job);
     if (!wait_for_job(job)) {
       return false;
diff --git a/Libs/Application/DeepSSM/DeepSSMJob.cpp b/Libs/Application/DeepSSM/DeepSSMJob.cpp
index ec47638877..b90002dff2 100644
--- a/Libs/Application/DeepSSM/DeepSSMJob.cpp
+++ b/Libs/Application/DeepSSM/DeepSSMJob.cpp
@@ -241,11 +241,12 @@ void DeepSSMJob::run_augmentation() {
   py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils");
   py::object run_data_aug = py_deep_ssm_utils.attr("run_data_augmentation");
 
+  int processes = aug_processes_ > 0 ?
aug_processes_ : QThread::idealThreadCount(); + int aug_dims = run_data_aug(project_, params.get_aug_num_samples(), 0 /* num dims, set to zero to allow percent variability to be used */, params.get_aug_percent_variability(), sampler_type.toStdString(), 0 /* mixture_num */, - QThread::idealThreadCount() /* processes */ - ) + processes) .cast(); params.set_training_num_dims(aug_dims); @@ -394,6 +395,12 @@ void DeepSSMJob::set_num_dataloader_workers(int num_workers) { num_dataloader_wo //--------------------------------------------------------------------------- int DeepSSMJob::get_num_dataloader_workers() { return num_dataloader_workers_; } +//--------------------------------------------------------------------------- +void DeepSSMJob::set_aug_processes(int processes) { aug_processes_ = processes; } + +//--------------------------------------------------------------------------- +int DeepSSMJob::get_aug_processes() { return aug_processes_; } + //--------------------------------------------------------------------------- void DeepSSMJob::update_prep_stage(PrepStep step) { /* diff --git a/Libs/Application/DeepSSM/DeepSSMJob.h b/Libs/Application/DeepSSM/DeepSSMJob.h index b24ba753ec..021c8cd479 100644 --- a/Libs/Application/DeepSSM/DeepSSMJob.h +++ b/Libs/Application/DeepSSM/DeepSSMJob.h @@ -55,6 +55,9 @@ class DeepSSMJob : public Job { void set_num_dataloader_workers(int num_workers); int get_num_dataloader_workers(); + void set_aug_processes(int processes); + int get_aug_processes(); + void set_prep_step(DeepSSMJob::PrepStep step) { std::lock_guard lock(mutex_); prep_step_ = step; @@ -72,6 +75,7 @@ class DeepSSMJob : public Job { DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED}; int num_dataloader_workers_{0}; + int aug_processes_{0}; // mutex std::mutex mutex_; diff --git a/Libs/Application/Job/PythonWorker.cpp b/Libs/Application/Job/PythonWorker.cpp index 00a46e50fa..b062a6bd3b 100644 --- a/Libs/Application/Job/PythonWorker.cpp +++ b/Libs/Application/Job/PythonWorker.cpp @@ -207,11 +207,10 @@ bool PythonWorker::init() { path = QString::fromStdString(line); } file.close(); + qputenv("PATH", path.toUtf8()); + SW_LOG("Setting PATH for Python to: " + path.toStdString()); } - qputenv("PATH", path.toUtf8()); - SW_LOG("Setting PATH for Python to: " + path.toStdString()); - // Python 3.8+ requires explicit DLL directory registration // PATH environment variable is no longer used for DLL search SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_USER_DIRS); diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh index 377936631b..9c8b6b6aa4 100755 --- a/Testing/DeepSSMTests/deepssm_default.sh +++ b/Testing/DeepSSMTests/deepssm_default.sh @@ -14,7 +14,7 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles -shapeworks deepssm --name default.swproj --all +shapeworks deepssm --name default.swproj --all --aug_processes 1 # Verify results python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh index ef8771f50a..c4450ae70c 100755 --- a/Testing/DeepSSMTests/deepssm_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh @@ -14,7 +14,7 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles -shapeworks deepssm --name fine_tune.swproj --all +shapeworks deepssm --name fine_tune.swproj --all --aug_processes 1 # Verify results python3 "${SCRIPT_DIR}/verify_deepssm_results.py" 
"${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh index b57dc3b36a..a36369e5a5 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net.sh @@ -14,7 +14,7 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles -shapeworks deepssm --name tl_net.swproj --all +shapeworks deepssm --name tl_net.swproj --all --aug_processes 1 # Verify results python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh index b47b1b2e41..7eeb725606 100755 --- a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh +++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh @@ -14,7 +14,7 @@ fi cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles -shapeworks deepssm --name tl_net_fine_tune.swproj --all +shapeworks deepssm --name tl_net_fine_tune.swproj --all --aug_processes 1 # Verify results python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" From 56bb670bf42cc4ad82affeccb3fa350fb909de41 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Sat, 7 Feb 2026 09:57:56 -0700 Subject: [PATCH 45/47] Test run to debug windows CI --- Applications/shapeworks/Commands.cpp | 28 ++++++++++++------------- Testing/DeepSSMTests/CMakeLists.txt | 1 + Testing/DeepSSMTests/deepssm_default.sh | 11 ++++++++++ 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp index c3c08df0df..70e23835f1 100644 --- a/Applications/shapeworks/Commands.cpp +++ b/Applications/shapeworks/Commands.cpp @@ -421,12 +421,12 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& int num_workers = static_cast(options.get("num_workers")); int aug_processes = static_cast(options.get("aug_processes")); - std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n"; - std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n"; - std::cout << "Train step: " << (do_train ? "on" : "off") << "\n"; - std::cout << "Test step: " << (do_test ? "on" : "off") << "\n"; - std::cout << "Num dataloader workers: " << num_workers << "\n"; - std::cout << "Augmentation processes: " << (aug_processes == 0 ? QThread::idealThreadCount() : aug_processes) << "\n"; + std::cout << "Prep step: " << (do_prep ? "on" : "off") << std::endl; + std::cout << "Augment step: " << (do_augment ? "on" : "off") << std::endl; + std::cout << "Train step: " << (do_train ? "on" : "off") << std::endl; + std::cout << "Test step: " << (do_test ? "on" : "off") << std::endl; + std::cout << "Num dataloader workers: " << num_workers << std::endl; + std::cout << "Augmentation processes: " << (aug_processes == 0 ? QThread::idealThreadCount() : aug_processes) << std::endl; if (!do_prep && !do_augment && !do_train && !do_test) { do_prep = true; @@ -470,40 +470,40 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& SW_ERROR("Unknown prep step: {}", prep_step); return false; } - std::cout << "Running DeepSSM preparation step...\n"; + std::cerr << "Running DeepSSM preparation step..." << std::endl; python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cout << "DeepSSM preparation step completed.\n"; + std::cerr << "DeepSSM preparation step completed." 
<< std::endl; } if (do_augment) { - std::cout << "Running DeepSSM data augmentation...\n"; + std::cerr << "Running DeepSSM data augmentation..." << std::endl; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType); job->set_aug_processes(aug_processes); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cout << "DeepSSM data augmentation completed.\n"; + std::cerr << "DeepSSM data augmentation completed." << std::endl; } if (do_train) { - std::cout << "Running DeepSSM training...\n"; + std::cerr << "Running DeepSSM training..." << std::endl; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_TrainingType); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cout << "DeepSSM training completed.\n"; + std::cerr << "DeepSSM training completed." << std::endl; } if (do_test) { - std::cout << "Running DeepSSM testing...\n"; + std::cerr << "Running DeepSSM testing..." << std::endl; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_TestingType); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cout << "DeepSSM testing completed.\n"; + std::cerr << "DeepSSM testing completed." << std::endl; } project->save(); diff --git a/Testing/DeepSSMTests/CMakeLists.txt b/Testing/DeepSSMTests/CMakeLists.txt index 7119af3cef..7a0c119de1 100644 --- a/Testing/DeepSSMTests/CMakeLists.txt +++ b/Testing/DeepSSMTests/CMakeLists.txt @@ -11,3 +11,4 @@ target_link_libraries(DeepSSMTests ) add_test(NAME DeepSSMTests COMMAND DeepSSMTests) +set_tests_properties(DeepSSMTests PROPERTIES TIMEOUT 1800) diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh index 9c8b6b6aa4..cf7ff20318 100755 --- a/Testing/DeepSSMTests/deepssm_default.sh +++ b/Testing/DeepSSMTests/deepssm_default.sh @@ -5,16 +5,27 @@ set -e # Prevent PyTorch/OpenMP deadlock on macOS and Windows export OMP_NUM_THREADS=1 +echo "=== DeepSSM default test starting ===" +echo "DATA=${DATA}" +echo "OMP_NUM_THREADS=${OMP_NUM_THREADS}" + SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +echo "SCRIPT_DIR=${SCRIPT_DIR}" # Unzip test data if not already extracted if [ ! -d "${DATA}/deepssm" ]; then + echo "Unzipping test data..." unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" fi +echo "Changing to ${DATA}/deepssm/projects" cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles +echo "Running shapeworks deepssm..." shapeworks deepssm --name default.swproj --all --aug_processes 1 +echo "shapeworks deepssm completed" # Verify results +echo "Verifying results..." 
python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +echo "=== DeepSSM default test complete ===" From aede059d1d7227759dd84a1b73b5ce3d8960c344 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Sun, 8 Feb 2026 01:55:44 -0700 Subject: [PATCH 46/47] Debugging windows CI --- Applications/shapeworks/Commands.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp index 70e23835f1..19ad2c0a6e 100644 --- a/Applications/shapeworks/Commands.cpp +++ b/Applications/shapeworks/Commands.cpp @@ -12,7 +12,7 @@ #include #include -#include +#include #include namespace shapeworks { @@ -381,14 +381,14 @@ void DeepSSMCommand::buildParser() { } bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& sharedData) { - // Create a non-gui QApplication instance - int argc = 3; - char* argv[3]; + // QCoreApplication provides the event loop needed for PythonWorker's QThread, + // without requiring Qt platform plugins (which may not be available on headless CI). + int argc = 1; + char* argv[1]; argv[0] = const_cast("shapeworks"); - argv[1] = const_cast("-platform"); - argv[2] = const_cast("offscreen"); - QApplication app(argc, argv); + QCoreApplication app(argc, argv); + std::cerr << "QCoreApplication initialized." << std::endl; // Handle project file: either from --name or first positional argument std::string project_file; @@ -438,8 +438,10 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& ProjectHandle project = std::make_shared(); project->load(project_file); + std::cerr << "Creating PythonWorker..." << std::endl; PythonWorker python_worker; python_worker.set_cli_mode(true); + std::cerr << "PythonWorker created." << std::endl; auto wait_for_job = [&](auto job) { // This lambda will block until the job is complete From 91dbbb1f8e4c2d1df955f245d10e51efc539b206 Mon Sep 17 00:00:00 2001 From: Alan Morris Date: Sun, 8 Feb 2026 09:55:26 -0700 Subject: [PATCH 47/47] Clean up debugging. --- Applications/shapeworks/Commands.cpp | 31 +++++++++++-------------- Testing/DeepSSMTests/deepssm_default.sh | 11 --------- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp index 19ad2c0a6e..7bb8752786 100644 --- a/Applications/shapeworks/Commands.cpp +++ b/Applications/shapeworks/Commands.cpp @@ -388,7 +388,6 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& argv[0] = const_cast("shapeworks"); QCoreApplication app(argc, argv); - std::cerr << "QCoreApplication initialized." << std::endl; // Handle project file: either from --name or first positional argument std::string project_file; @@ -421,12 +420,12 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& int num_workers = static_cast(options.get("num_workers")); int aug_processes = static_cast(options.get("aug_processes")); - std::cout << "Prep step: " << (do_prep ? "on" : "off") << std::endl; - std::cout << "Augment step: " << (do_augment ? "on" : "off") << std::endl; - std::cout << "Train step: " << (do_train ? "on" : "off") << std::endl; - std::cout << "Test step: " << (do_test ? "on" : "off") << std::endl; - std::cout << "Num dataloader workers: " << num_workers << std::endl; - std::cout << "Augmentation processes: " << (aug_processes == 0 ? QThread::idealThreadCount() : aug_processes) << std::endl; + std::cout << "Prep step: " << (do_prep ? 
"on" : "off") << "\n"; + std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n"; + std::cout << "Train step: " << (do_train ? "on" : "off") << "\n"; + std::cout << "Test step: " << (do_test ? "on" : "off") << "\n"; + std::cout << "Num dataloader workers: " << num_workers << "\n"; + std::cout << "Augmentation processes: " << (aug_processes == 0 ? QThread::idealThreadCount() : aug_processes) << "\n"; if (!do_prep && !do_augment && !do_train && !do_test) { do_prep = true; @@ -438,10 +437,8 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& ProjectHandle project = std::make_shared(); project->load(project_file); - std::cerr << "Creating PythonWorker..." << std::endl; PythonWorker python_worker; python_worker.set_cli_mode(true); - std::cerr << "PythonWorker created." << std::endl; auto wait_for_job = [&](auto job) { // This lambda will block until the job is complete @@ -472,40 +469,40 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& SW_ERROR("Unknown prep step: {}", prep_step); return false; } - std::cerr << "Running DeepSSM preparation step..." << std::endl; + std::cout << "Running DeepSSM preparation step...\n"; python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cerr << "DeepSSM preparation step completed." << std::endl; + std::cout << "DeepSSM preparation step completed.\n"; } if (do_augment) { - std::cerr << "Running DeepSSM data augmentation..." << std::endl; + std::cout << "Running DeepSSM data augmentation...\n"; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType); job->set_aug_processes(aug_processes); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cerr << "DeepSSM data augmentation completed." << std::endl; + std::cout << "DeepSSM data augmentation completed.\n"; } if (do_train) { - std::cerr << "Running DeepSSM training..." << std::endl; + std::cout << "Running DeepSSM training...\n"; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_TrainingType); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cerr << "DeepSSM training completed." << std::endl; + std::cout << "DeepSSM training completed.\n"; } if (do_test) { - std::cerr << "Running DeepSSM testing..." << std::endl; + std::cout << "Running DeepSSM testing...\n"; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_TestingType); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cerr << "DeepSSM testing completed." << std::endl; + std::cout << "DeepSSM testing completed.\n"; } project->save(); diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh index cf7ff20318..9c8b6b6aa4 100755 --- a/Testing/DeepSSMTests/deepssm_default.sh +++ b/Testing/DeepSSMTests/deepssm_default.sh @@ -5,27 +5,16 @@ set -e # Prevent PyTorch/OpenMP deadlock on macOS and Windows export OMP_NUM_THREADS=1 -echo "=== DeepSSM default test starting ===" -echo "DATA=${DATA}" -echo "OMP_NUM_THREADS=${OMP_NUM_THREADS}" - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -echo "SCRIPT_DIR=${SCRIPT_DIR}" # Unzip test data if not already extracted if [ ! -d "${DATA}/deepssm" ]; then - echo "Unzipping test data..." unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" fi -echo "Changing to ${DATA}/deepssm/projects" cd "${DATA}/deepssm/projects" rm -rf deepssm groomed *_particles -echo "Running shapeworks deepssm..." 
 shapeworks deepssm --name default.swproj --all --aug_processes 1
-echo "shapeworks deepssm completed"
 
 # Verify results
-echo "Verifying results..."
 python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects"
-echo "=== DeepSSM default test complete ==="
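For reference, the augmentation entry point that --aug_processes ultimately controls can also be called directly from Python. A rough sketch, assuming the shapeworks Project API; the sample count, variability, and sampler name are illustrative values, with only the argument order taken from the DeepSSMJob::run_augmentation call above:

    import shapeworks as sw
    import DeepSSMUtils

    project = sw.Project()
    project.load("default.swproj")

    # Argument order mirrors DeepSSMJob::run_augmentation:
    # (project, num_samples, num_dims, percent_variability, sampler_type,
    #  mixture_num, processes). num_dims=0 lets percent_variability choose the
    # PCA dimensionality; processes=1 matches --aug_processes 1 in the CI scripts.
    aug_dims = DeepSSMUtils.run_data_augmentation(project, 300, 0, 0.95,
                                                  "Gaussian", 0, 1)
    print("Augmented PCA dimensions:", aug_dims)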