diff --git a/Applications/shapeworks/Command.h b/Applications/shapeworks/Command.h index 8c6db366ef..a6e8b3a6a1 100644 --- a/Applications/shapeworks/Command.h +++ b/Applications/shapeworks/Command.h @@ -33,7 +33,7 @@ class Command { const std::string desc() const { return parser.description(); } /// parses the arguments for this command, saving them in the parser and returning the leftovers - std::vector parse_args(const std::vector &arguments); + virtual std::vector parse_args(const std::vector &arguments); /// calls execute for this command using the parsed args, returning system exit value int run(SharedCommandData &sharedData); @@ -108,6 +108,12 @@ class DeepSSMCommandGroup : public Command public: const std::string type() override { return "DeepSSM"; } + // DeepSSM is a terminal command - don't pass remaining args to other commands + std::vector parse_args(const std::vector &arguments) override { + Command::parse_args(arguments); + return {}; // return empty - DeepSSM consumes all args + } + private: }; diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp index da59fe0c67..70e23835f1 100644 --- a/Applications/shapeworks/Commands.cpp +++ b/Applications/shapeworks/Commands.cpp @@ -371,6 +371,12 @@ void DeepSSMCommand::buildParser() { .set_default(0) .help("Number of data loader workers (default: 0)"); + parser.add_option("--aug_processes") + .action("store") + .type("int") + .set_default(0) + .help("Number of augmentation processes (default: 0 = use all cores)"); + Command::buildParser(); } @@ -413,12 +419,14 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& bool do_test = options.is_set("test") || options.is_set("all"); int num_workers = static_cast(options.get("num_workers")); + int aug_processes = static_cast(options.get("aug_processes")); - std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n"; - std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n"; - std::cout << "Train step: " << (do_train ? "on" : "off") << "\n"; - std::cout << "Test step: " << (do_test ? "on" : "off") << "\n"; - std::cout << "Num dataloader workers: " << num_workers << "\n"; + std::cout << "Prep step: " << (do_prep ? "on" : "off") << std::endl; + std::cout << "Augment step: " << (do_augment ? "on" : "off") << std::endl; + std::cout << "Train step: " << (do_train ? "on" : "off") << std::endl; + std::cout << "Test step: " << (do_test ? "on" : "off") << std::endl; + std::cout << "Num dataloader workers: " << num_workers << std::endl; + std::cout << "Augmentation processes: " << (aug_processes == 0 ? QThread::idealThreadCount() : aug_processes) << std::endl; if (!do_prep && !do_augment && !do_train && !do_test) { do_prep = true; @@ -462,44 +470,45 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& SW_ERROR("Unknown prep step: {}", prep_step); return false; } - std::cout << "Running DeepSSM preparation step...\n"; + std::cerr << "Running DeepSSM preparation step..." << std::endl; python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cout << "DeepSSM preparation step completed.\n"; + std::cerr << "DeepSSM preparation step completed." << std::endl; } if (do_augment) { - std::cout << "Running DeepSSM data augmentation...\n"; + std::cerr << "Running DeepSSM data augmentation..." 
<< std::endl; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType); + job->set_aug_processes(aug_processes); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cout << "DeepSSM data augmentation completed.\n"; + std::cerr << "DeepSSM data augmentation completed." << std::endl; } if (do_train) { - std::cout << "Running DeepSSM training...\n"; + std::cerr << "Running DeepSSM training..." << std::endl; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_TrainingType); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cout << "DeepSSM training completed.\n"; + std::cerr << "DeepSSM training completed." << std::endl; } if (do_test) { - std::cout << "Running DeepSSM testing...\n"; + std::cerr << "Running DeepSSM testing..." << std::endl; auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_TestingType); python_worker.run_job(job); if (!wait_for_job(job)) { return false; } - std::cout << "DeepSSM testing completed.\n"; + std::cerr << "DeepSSM testing completed." << std::endl; } project->save(); - return false; + return true; } } // namespace shapeworks diff --git a/Examples/Python/RunUseCase.py b/Examples/Python/RunUseCase.py index 9394235a8f..aa1a175f50 100644 --- a/Examples/Python/RunUseCase.py +++ b/Examples/Python/RunUseCase.py @@ -69,8 +69,14 @@ parser.add_argument("--tiny_test", help="Run as a short test", action="store_true") parser.add_argument("--verify", help="Run as a full test", action="store_true") parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true") + parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification (platform-specific)", + choices=["save", "verify"]) args = parser.parse_args() + # Validate deep_ssm-specific arguments + if args.exact_check and args.use_case != "deep_ssm": + parser.error("--exact_check is only supported for the deep_ssm use case") + type = "" if args.tiny_test: type = "tiny_test_" diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index a1fa04b330..e3508b9673 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -64,8 +64,8 @@ def Run_Pipeline(args): This data is comprised of femur meshes and corresponding hip CT scans. 
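    For refactoring verification, this use case also supports the --exact_check flag added
    in RunUseCase.py above (hedged example; assumes the use case name is passed positionally
    as in the existing examples):

        python RunUseCase.py deep_ssm --tiny_test --exact_check save    # record baseline mean distance
        python RunUseCase.py deep_ssm --tiny_test --exact_check verify  # fail if the saved value changes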
""" - if platform.system() == "Darwin": - # On MacOS, CPU PyTorch is hanging with parallel + if platform.system() != "Linux": + # CPU PyTorch hangs with OpenMP parallelism on macOS and Windows os.environ['OMP_NUM_THREADS'] = "1" # If running a tiny_test, then download subset of the data if args.tiny_test: @@ -396,6 +396,7 @@ def Run_Pipeline(args): "c_lat": 6.3 } } + if args.tiny_test: model_parameters["trainer"]["epochs"] = 1 # Save config file @@ -436,17 +437,17 @@ def Run_Pipeline(args): val_world_particles.append(project_path + subjects[index].get_world_particle_filenames()[0]) val_mesh_files.append(project_path + subjects[index].get_groomed_filenames()[0]) - val_out_dir = output_directory + model_name + '/validation_predictions/' predicted_val_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='validation') print("Validation world predictions saved.") - # Generate local predictions - local_val_prediction_dir = val_out_dir + 'local_predictions/' + # Generate local predictions - create directory next to world_predictions + world_pred_dir = os.path.dirname(predicted_val_world_particles[0]) + local_val_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions") if not os.path.exists(local_val_prediction_dir): os.makedirs(local_val_prediction_dir) predicted_val_local_particles = [] for particle_file, transform in zip(predicted_val_world_particles, val_transforms): particles = np.loadtxt(particle_file) - local_particle_file = particle_file.replace("world_predictions/", "local_predictions/") + local_particle_file = particle_file.replace("world_predictions", "local_predictions") local_particles = sw.utils.transformParticles(particles, transform, inverse=True) np.savetxt(local_particle_file, local_particles) predicted_val_local_particles.append(local_particle_file) @@ -462,6 +463,8 @@ def Run_Pipeline(args): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] # Get distance between clipped true and predicted meshes + # Get the validation output directory from the predictions path + val_out_dir = os.path.dirname(local_val_prediction_dir.rstrip('/')) + '/' mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_val_local_particles, val_mesh_files, template_particles, template_mesh, val_out_dir, planes=val_planes) @@ -500,17 +503,17 @@ def Run_Pipeline(args): with open(plane_file) as json_file: test_planes.append(json.load(json_file)['planes'][0]['points']) - test_out_dir = output_directory + model_name + '/test_predictions/' predicted_test_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='test') print("Test world predictions saved.") - # Generate local predictions - local_test_prediction_dir = test_out_dir + 'local_predictions/' + # Generate local predictions - create directory next to world_predictions + world_pred_dir = os.path.dirname(predicted_test_world_particles[0]) + local_test_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions") if not os.path.exists(local_test_prediction_dir): os.makedirs(local_test_prediction_dir) predicted_test_local_particles = [] for particle_file, transform in zip(predicted_test_world_particles, test_transforms): particles = np.loadtxt(particle_file) - local_particle_file = particle_file.replace("world_predictions/", "local_predictions/") + local_particle_file = particle_file.replace("world_predictions", "local_predictions") local_particles = 
sw.utils.transformParticles(particles, transform, inverse=True) np.savetxt(local_particle_file, local_particles) predicted_test_local_particles.append(local_particle_file) @@ -524,15 +527,17 @@ def Run_Pipeline(args): template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0] template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0] + # Get the test output directory from the predictions path + test_out_dir = os.path.dirname(local_test_prediction_dir.rstrip('/')) + '/' mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_test_local_particles, test_mesh_files, template_particles, template_mesh, test_out_dir, planes=test_planes) print("Test mean mesh surface-to-surface distance: " + str(mean_dist)) - DeepSSMUtils.process_test_predictions(project, config_file) - + final_mean_dist = DeepSSMUtils.process_test_predictions(project, config_file) + # If tiny test or verify, check results and exit - check_results(args, mean_dist) + check_results(args, final_mean_dist, output_directory) open(status_dir + "step_12.txt", 'w').close() @@ -540,12 +545,35 @@ def Run_Pipeline(args): # Verification -def check_results(args, mean_dist): +def check_results(args, mean_dist, output_directory): if args.tiny_test: print("\nVerifying use case results.") - if not math.isclose(mean_dist, 10, rel_tol=1): - print("Test failed.") - exit(-1) + + exact_check_file = output_directory + "exact_check_value.txt" + + # Exact check for refactoring verification (platform-specific) + if args.exact_check == "save": + with open(exact_check_file, "w") as f: + f.write(str(mean_dist)) + print(f"Saved exact check value to: {exact_check_file}") + print(f"Value: {mean_dist}") + elif args.exact_check == "verify": + if not os.path.exists(exact_check_file): + print(f"Error: No saved value found at {exact_check_file}") + print("Run with --exact_check save first to create baseline.") + exit(-1) + with open(exact_check_file, "r") as f: + expected_mean_dist = float(f.read().strip()) + if mean_dist != expected_mean_dist: + print(f"Exact check failed: expected {expected_mean_dist}, got {mean_dist}") + exit(-1) + print(f"Exact check passed: {mean_dist}") + else: + # Relaxed check for CI/cross-platform + if not math.isclose(mean_dist, 10, rel_tol=1): + print("Test failed.") + exit(-1) + print("Done with test, verification succeeded.") exit(0) else: diff --git a/Libs/Application/DeepSSM/DeepSSMJob.cpp b/Libs/Application/DeepSSM/DeepSSMJob.cpp index ec47638877..b90002dff2 100644 --- a/Libs/Application/DeepSSM/DeepSSMJob.cpp +++ b/Libs/Application/DeepSSM/DeepSSMJob.cpp @@ -241,11 +241,12 @@ void DeepSSMJob::run_augmentation() { py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils"); py::object run_data_aug = py_deep_ssm_utils.attr("run_data_augmentation"); + int processes = aug_processes_ > 0 ? 
aug_processes_ : QThread::idealThreadCount(); + int aug_dims = run_data_aug(project_, params.get_aug_num_samples(), 0 /* num dims, set to zero to allow percent variability to be used */, params.get_aug_percent_variability(), sampler_type.toStdString(), 0 /* mixture_num */, - QThread::idealThreadCount() /* processes */ - ) + processes) .cast(); params.set_training_num_dims(aug_dims); @@ -394,6 +395,12 @@ void DeepSSMJob::set_num_dataloader_workers(int num_workers) { num_dataloader_wo //--------------------------------------------------------------------------- int DeepSSMJob::get_num_dataloader_workers() { return num_dataloader_workers_; } +//--------------------------------------------------------------------------- +void DeepSSMJob::set_aug_processes(int processes) { aug_processes_ = processes; } + +//--------------------------------------------------------------------------- +int DeepSSMJob::get_aug_processes() { return aug_processes_; } + //--------------------------------------------------------------------------- void DeepSSMJob::update_prep_stage(PrepStep step) { /* diff --git a/Libs/Application/DeepSSM/DeepSSMJob.h b/Libs/Application/DeepSSM/DeepSSMJob.h index b24ba753ec..021c8cd479 100644 --- a/Libs/Application/DeepSSM/DeepSSMJob.h +++ b/Libs/Application/DeepSSM/DeepSSMJob.h @@ -55,6 +55,9 @@ class DeepSSMJob : public Job { void set_num_dataloader_workers(int num_workers); int get_num_dataloader_workers(); + void set_aug_processes(int processes); + int get_aug_processes(); + void set_prep_step(DeepSSMJob::PrepStep step) { std::lock_guard lock(mutex_); prep_step_ = step; @@ -72,6 +75,7 @@ class DeepSSMJob : public Job { DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED}; int num_dataloader_workers_{0}; + int aug_processes_{0}; // mutex std::mutex mutex_; diff --git a/Libs/Application/Job/PythonWorker.cpp b/Libs/Application/Job/PythonWorker.cpp index 00a46e50fa..b062a6bd3b 100644 --- a/Libs/Application/Job/PythonWorker.cpp +++ b/Libs/Application/Job/PythonWorker.cpp @@ -207,11 +207,10 @@ bool PythonWorker::init() { path = QString::fromStdString(line); } file.close(); + qputenv("PATH", path.toUtf8()); + SW_LOG("Setting PATH for Python to: " + path.toStdString()); } - qputenv("PATH", path.toUtf8()); - SW_LOG("Setting PATH for Python to: " + path.toStdString()); - // Python 3.8+ requires explicit DLL directory registration // PATH environment variable is no longer used for DLL search SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_USER_DIRS); diff --git a/Libs/Groom/Groom.cpp b/Libs/Groom/Groom.cpp index e5b2a4ac55..9b36a96ab7 100644 --- a/Libs/Groom/Groom.cpp +++ b/Libs/Groom/Groom.cpp @@ -186,7 +186,17 @@ bool Groom::image_pipeline(std::shared_ptr subject, size_t domain) { std::string groomed_name = get_output_filename(original, DomainType::Image); if (params.get_convert_to_mesh()) { + // Use isovalue 0.0 for distance transforms (the zero level set is the surface) Mesh mesh = image.toMesh(0.0); + if (mesh.numPoints() == 0) { + throw std::runtime_error("Empty mesh generated from segmentation - segmentation may have no valid data"); + } + // Check for valid cells + auto poly_data = mesh.getVTKMesh(); + if (poly_data->GetNumberOfCells() == 0) { + throw std::runtime_error("Mesh has no cells - segmentation may have no valid surface"); + } + SW_DEBUG("Mesh after toMesh: {} points, {} cells", poly_data->GetNumberOfPoints(), poly_data->GetNumberOfCells()); run_mesh_pipeline(mesh, params, original); groomed_name = get_output_filename(original, 
DomainType::Mesh); // save the groomed mesh @@ -239,6 +249,9 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) { // crop if (params.get_crop()) { PhysicalRegion region = image.physicalBoundingBox(0.5); + if (!region.valid()) { + throw std::runtime_error("Empty segmentation - no voxels found above threshold for cropping"); + } image.crop(region); increment_progress(); } @@ -560,7 +573,7 @@ bool Groom::run_alignment() { bool any_alignment = false; int reference_index = -1; - int subset_size = -1; + int subset_size = base_params.get_alignment_subset_size(); // per-domain alignment for (size_t domain = 0; domain < num_domains; domain++) { @@ -1336,7 +1349,7 @@ std::vector> Groom::get_icp_transforms(const std::vectorIdentity(); Mesh source = meshes[i]; - if (source.getVTKMesh()->GetNumberOfPoints() != 0) { + if (source.getVTKMesh()->GetNumberOfPoints() != 0 && reference.getVTKMesh()->GetNumberOfPoints() != 0) { // create copies for thread safety auto poly_data1 = vtkSmartPointer::New(); poly_data1->DeepCopy(source.getVTKMesh()); diff --git a/Libs/Image/Image.cpp b/Libs/Image/Image.cpp index fc4788daa7..26937abdf5 100644 --- a/Libs/Image/Image.cpp +++ b/Libs/Image/Image.cpp @@ -32,10 +32,13 @@ #include #include #include +#include #include #include #include #include +#include +#include #include #include @@ -1019,7 +1022,40 @@ Mesh Image::toMesh(PixelType isoValue) const { targetContour->SetValue(0, isoValue); targetContour->Update(); - return Mesh(targetContour->GetOutput()); + auto contourOutput = targetContour->GetOutput(); + + // Use vtkTriangleFilter FIRST to convert all polygons to proper triangles + // This removes degenerate cells that can crash downstream filters + auto triangleFilter = vtkSmartPointer::New(); + triangleFilter->SetInputData(contourOutput); + triangleFilter->PassVertsOff(); + triangleFilter->PassLinesOff(); + triangleFilter->Update(); + + // Clean the mesh to remove degenerate points and merge duplicates + auto clean = vtkSmartPointer::New(); + clean->SetInputData(triangleFilter->GetOutput()); + clean->ConvertPolysToLinesOff(); + clean->ConvertLinesToPointsOff(); + clean->ConvertStripsToPolysOff(); + clean->PointMergingOn(); + clean->SetTolerance(0.0); + clean->Update(); + + // Check if we have any data to process + auto cleanOutput = clean->GetOutput(); + if (cleanOutput->GetNumberOfPoints() == 0 || cleanOutput->GetNumberOfCells() == 0) { + // Return empty mesh + return Mesh(cleanOutput); + } + + // Use connectivity filter to extract only connected surface regions + auto connectivity = vtkSmartPointer::New(); + connectivity->SetInputData(cleanOutput); + connectivity->SetExtractionModeToLargestRegion(); + connectivity->Update(); + + return Mesh(connectivity->GetOutput()); } Image::PixelType Image::evaluate(Point p) { @@ -1170,11 +1206,18 @@ TransformPtr Image::createRigidRegistrationTransform(const Image& target_dt, flo Mesh sourceContour = toMesh(isoValue); Mesh targetContour = target_dt.toMesh(isoValue); + // Check for empty meshes before attempting ICP + if (sourceContour.numPoints() == 0 || targetContour.numPoints() == 0) { + SW_WARN("Cannot create ICP transform: source has {} points, target has {} points", + sourceContour.numPoints(), targetContour.numPoints()); + return AffineTransform::New(); + } + try { auto mat = MeshUtils::createICPTransform(sourceContour, targetContour, Mesh::Rigid, iterations); return shapeworks::createTransform(ShapeWorksUtils::convert_matrix(mat), ShapeWorksUtils::get_offset(mat)); - } catch (std::invalid_argument) { - 
std::cerr << "failed to create ICP transform.\n"; + } catch (std::invalid_argument& e) { + std::cerr << "failed to create ICP transform: " << e.what() << "\n"; if (sourceContour.numPoints() == 0) { std::cerr << "\tspecified isoValue (" << isoValue << ") results in an empty mesh for source\n"; } diff --git a/Libs/Mesh/Mesh.cpp b/Libs/Mesh/Mesh.cpp index 42df05e6fb..6023bbab74 100644 --- a/Libs/Mesh/Mesh.cpp +++ b/Libs/Mesh/Mesh.cpp @@ -606,6 +606,24 @@ Mesh& Mesh::fixNonManifold() { } Mesh& Mesh::extractLargestComponent() { + // Check for valid cells before attempting connectivity filter + if (poly_data_->GetNumberOfCells() == 0) { + SW_WARN("extractLargestComponent: mesh has no cells"); + return *this; + } + + // Verify mesh has at least some valid cells + bool hasValidCells = false; + for (vtkIdType i = 0; i < poly_data_->GetNumberOfCells() && !hasValidCells; i++) { + if (poly_data_->GetCellType(i) != 0) { // VTK_EMPTY_CELL = 0 + hasValidCells = true; + } + } + if (!hasValidCells) { + SW_WARN("extractLargestComponent: mesh has no valid cells (all cells are type 0)"); + return *this; + } + auto connectivityFilter = vtkSmartPointer::New(); connectivityFilter->SetExtractionModeToLargestRegion(); connectivityFilter->SetInputData(poly_data_); @@ -1603,6 +1621,14 @@ bool Mesh::compare(const Mesh& other, const double eps) const { MeshTransform Mesh::createRegistrationTransform(const Mesh& target, Mesh::AlignmentType align, unsigned iterations) const { + // Check for empty meshes before attempting ICP + if (numPoints() == 0 || target.numPoints() == 0) { + SW_WARN("Cannot create registration transform: source has {} points, target has {} points", + numPoints(), target.numPoints()); + vtkSmartPointer identity = vtkSmartPointer::New(); + identity->Identity(); + return createMeshTransform(identity); + } const vtkSmartPointer mat( MeshUtils::createICPTransform(this->poly_data_, target.getVTKMesh(), align, iterations, true)); return createMeshTransform(mat); diff --git a/Libs/Mesh/MeshUtils.cpp b/Libs/Mesh/MeshUtils.cpp index 2468230ecb..1bbae90c9f 100644 --- a/Libs/Mesh/MeshUtils.cpp +++ b/Libs/Mesh/MeshUtils.cpp @@ -71,7 +71,11 @@ const vtkSmartPointer MeshUtils::createICPTransform(const Mesh sou Mesh::AlignmentType align, const unsigned iterations, bool meshTransform) { if (source.numPoints() == 0 || target.numPoints() == 0) { - throw std::invalid_argument("empty mesh passed to MeshUtils::createICPTransform"); + SW_WARN("Empty mesh in createICPTransform: source has {} points, target has {} points - returning identity", + source.numPoints(), target.numPoints()); + vtkSmartPointer identity = vtkSmartPointer::New(); + identity->Identity(); + return identity; } vtkSmartPointer icp = vtkSmartPointer::New(); @@ -182,6 +186,10 @@ PhysicalRegion MeshUtils::boundingBox(const std::vector& meshes, bool cent } int MeshUtils::findReferenceMesh(std::vector& meshes, int random_subset_size) { + // auto (-1) defaults to a subset of 30 to avoid O(n^2) pairwise ICP on large datasets + if (random_subset_size < 0) { + random_subset_size = 30; + } bool use_random_subset = random_subset_size > 0 && random_subset_size < meshes.size(); int num_meshes = use_random_subset ? 
random_subset_size : meshes.size(); diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index 135f4f0a05..eb51ab3904 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -1,3 +1,5 @@ +from typing import List, Optional, Tuple, Any + from DeepSSMUtils import trainer from DeepSSMUtils import loaders from DeepSSMUtils import eval @@ -6,6 +8,11 @@ from DeepSSMUtils import train_viz from DeepSSMUtils import image_utils from DeepSSMUtils import run_utils +from DeepSSMUtils import net_utils +from DeepSSMUtils import constants +from DeepSSMUtils import config_validation + +from .net_utils import set_seed from .run_utils import create_split, groom_training_shapes, groom_training_images, \ run_data_augmentation, groom_val_test_images, prep_project_for_val_particles, groom_validation_shapes, \ @@ -16,65 +23,171 @@ import torch -def getTrainValLoaders(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): +def getTrainValLoaders( + loader_dir: str, + aug_data_csv: str, + batch_size: int = 1, + down_factor: float = 1, + down_dir: Optional[str] = None, + train_split: float = 0.80, + num_workers: int = 0 +) -> None: + """Create training and validation data loaders from augmented data CSV.""" testPytorch() loaders.get_train_val_loaders(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers) -def getTrainLoader(loader_dir, aug_data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): +def getTrainLoader( + loader_dir: str, + aug_data_csv: str, + batch_size: int = 1, + down_factor: float = 1, + down_dir: Optional[str] = None, + train_split: float = 0.80, + num_workers: int = 0 +) -> None: + """Create training data loader from augmented data CSV.""" testPytorch() loaders.get_train_loader(loader_dir, aug_data_csv, batch_size, down_factor, down_dir, train_split, num_workers) -def getValidationLoader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0): +def getValidationLoader( + loader_dir: str, + val_img_list: List[str], + val_particles: List[str], + down_factor: float = 1, + down_dir: Optional[str] = None, + num_workers: int = 0 +) -> None: + """Create validation data loader from image and particle lists.""" loaders.get_validation_loader(loader_dir, val_img_list, val_particles, down_factor, down_dir, num_workers) -def getTestLoader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0): +def getTestLoader( + loader_dir: str, + test_img_list: List[str], + down_factor: float = 1, + down_dir: Optional[str] = None, + num_workers: int = 0 +) -> None: + """Create test data loader from image list.""" loaders.get_test_loader(loader_dir, test_img_list, down_factor, down_dir, num_workers) -def prepareConfigFile(config_filename, model_name, embedded_dim, out_dir, loader_dir, aug_dir, epochs, learning_rate, - decay_lr, fine_tune, fine_tune_epochs, fine_tune_learning_rate): +def prepareConfigFile( + config_filename: str, + model_name: str, + embedded_dim: int, + out_dir: str, + loader_dir: str, + aug_dir: str, + epochs: int, + learning_rate: float, + decay_lr: bool, + fine_tune: bool, + fine_tune_epochs: int, + fine_tune_learning_rate: float +) -> None: + """Prepare a DeepSSM configuration file with the specified parameters.""" config_file.prepare_config_file(config_filename, model_name, embedded_dim, out_dir, loader_dir, 
aug_dir, epochs, learning_rate, decay_lr, fine_tune, fine_tune_epochs, fine_tune_learning_rate) -def trainDeepSSM(project, config_file): +def trainDeepSSM(project: Any, config_file: str) -> None: + """Train a DeepSSM model using the given project and configuration file.""" testPytorch() trainer.train(project, config_file) return -def testDeepSSM(config_file, loader="test"): +def testDeepSSM(config_file: str, loader: str = "test") -> List[str]: + """ + Test a trained DeepSSM model and return predicted particle files. + + Args: + config_file: Path to the configuration JSON file + loader: Which loader to use ("test" or "validation") + + Returns: + List of paths to predicted particle files + """ predicted_particle_files = eval.test(config_file, loader) return predicted_particle_files -def analyzeMSE(predicted_particles, true_particles): +def analyzeMSE( + predicted_particles: List[str], + true_particles: List[str] +) -> Tuple[float, float]: + """ + Analyze mean squared error between predicted and true particles. + + Returns: + Tuple of (mean_MSE, std_MSE) + """ mean_MSE, STD_MSE = eval_utils.get_MSE(predicted_particles, true_particles) return mean_MSE, STD_MSE -def analyzeMeshDistance(predicted_particles, mesh_files, template_particles, template_mesh, out_dir, planes=None): +def analyzeMeshDistance( + predicted_particles: List[str], + mesh_files: List[str], + template_particles: str, + template_mesh: str, + out_dir: str, + planes: Optional[Any] = None +) -> float: + """ + Analyze mesh distance between predicted particles and ground truth meshes. + + Returns: + Mean surface-to-surface distance + """ mean_distance = eval_utils.get_mesh_distance(predicted_particles, mesh_files, template_particles, template_mesh, out_dir, planes) return mean_distance -def analyzeResults(out_dir, DT_dir, prediction_dir, mean_prefix): +def analyzeResults( + out_dir: str, + DT_dir: str, + prediction_dir: str, + mean_prefix: str +) -> float: + """ + Analyze results by computing distance between predicted and ground truth meshes. + + Returns: + Average surface distance + """ avg_distance = eval_utils.get_distance_meshes(out_dir, DT_dir, prediction_dir, mean_prefix) return avg_distance -def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid'): +def get_image_registration_transform( + fixed_image_file: str, + moving_image_file: str, + transform_type: str = 'rigid' +) -> Any: + """ + Compute image registration transform between two images. + + Args: + fixed_image_file: Path to the fixed/reference image + moving_image_file: Path to the moving image to be registered + transform_type: Type of transform ('rigid', 'affine', etc.) + + Returns: + ITK transform object + """ itk_transform = image_utils.get_image_registration_transform(fixed_image_file, moving_image_file, transform_type=transform_type) return itk_transform -def testPytorch(): +def testPytorch() -> None: + """Check if PyTorch is using GPU and print a warning if not.""" if torch.cuda.is_available(): print("Running on GPU.") else: diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py new file mode 100644 index 0000000000..0725711558 --- /dev/null +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/config_validation.py @@ -0,0 +1,205 @@ +""" +Configuration file validation for DeepSSM. + +This module provides validation for DeepSSM config files to catch +errors early with clear error messages. 
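+
+Example usage (illustrative sketch; the config path is a placeholder):
+
+    from DeepSSMUtils.config_validation import validate_config, ConfigValidationError
+
+    try:
+        config = validate_config("model_config.json")
+    except ConfigValidationError as err:
+        print(err)  # lists each missing or invalid field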
+""" +import os +import json +from typing import Any, Dict, List, Optional + + +class ConfigValidationError(Exception): + """Raised when config validation fails.""" + pass + + +# Schema definition for DeepSSM config +CONFIG_SCHEMA = { + "model_name": {"type": str, "required": True}, + "num_latent_dim": {"type": int, "required": True, "min": 1}, + "paths": { + "type": dict, + "required": True, + "children": { + "out_dir": {"type": str, "required": True}, + "loader_dir": {"type": str, "required": True}, + "aug_dir": {"type": str, "required": True}, + } + }, + "encoder": { + "type": dict, + "required": True, + "children": { + "deterministic": {"type": bool, "required": True}, + } + }, + "decoder": { + "type": dict, + "required": True, + "children": { + "deterministic": {"type": bool, "required": True}, + "linear": {"type": bool, "required": True}, + } + }, + "loss": { + "type": dict, + "required": True, + "children": { + "function": {"type": str, "required": True, "choices": ["MSE", "Focal"]}, + "supervised_latent": {"type": bool, "required": True}, + } + }, + "trainer": { + "type": dict, + "required": True, + "children": { + "epochs": {"type": int, "required": True, "min": 1}, + "learning_rate": {"type": (int, float), "required": True, "min": 0}, + "val_freq": {"type": int, "required": True, "min": 1}, + "decay_lr": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "type": {"type": str, "required": False, "choices": ["Step", "CosineAnnealing"]}, + "parameters": {"type": dict, "required": False}, + } + }, + } + }, + "fine_tune": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "loss": {"type": str, "required": False, "choices": ["MSE", "Focal"]}, + "epochs": {"type": int, "required": False, "min": 1}, + "learning_rate": {"type": (int, float), "required": False, "min": 0}, + "val_freq": {"type": int, "required": False, "min": 1}, + } + }, + "use_best_model": {"type": bool, "required": True}, + "tl_net": { + "type": dict, + "required": True, + "children": { + "enabled": {"type": bool, "required": True}, + "ae_epochs": {"type": int, "required": False, "min": 1}, + "tf_epochs": {"type": int, "required": False, "min": 1}, + "joint_epochs": {"type": int, "required": False, "min": 1}, + "alpha": {"type": (int, float), "required": False}, + "a_ae": {"type": (int, float), "required": False}, + "c_ae": {"type": (int, float), "required": False}, + "a_lat": {"type": (int, float), "required": False}, + "c_lat": {"type": (int, float), "required": False}, + } + }, +} + + +def validate_config(config_path: str) -> Dict[str, Any]: + """ + Validate a DeepSSM configuration file. 
+ + Args: + config_path: Path to the JSON configuration file + + Returns: + Validated configuration dictionary + + Raises: + ConfigValidationError: If validation fails + FileNotFoundError: If config file doesn't exist + json.JSONDecodeError: If config file is not valid JSON + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + + with open(config_path) as f: + try: + config = json.load(f) + except json.JSONDecodeError as e: + raise ConfigValidationError(f"Invalid JSON in config file: {e}") + + errors = _validate_dict(config, CONFIG_SCHEMA, "config") + + if errors: + error_msg = "Config validation failed:\n" + "\n".join(f" - {e}" for e in errors) + raise ConfigValidationError(error_msg) + + return config + + +def _validate_dict( + data: Dict[str, Any], + schema: Dict[str, Any], + path: str +) -> List[str]: + """ + Recursively validate a dictionary against a schema. + + Returns: + List of error messages (empty if valid) + """ + errors = [] + + for key, rules in schema.items(): + full_path = f"{path}.{key}" + value = data.get(key) + + # Check required fields + if rules.get("required", False) and key not in data: + errors.append(f"Missing required field: {full_path}") + continue + + if key not in data: + continue + + # Check type + expected_type = rules.get("type") + if expected_type and not isinstance(value, expected_type): + type_name = expected_type.__name__ if isinstance(expected_type, type) else str(expected_type) + errors.append(f"Invalid type for {full_path}: expected {type_name}, got {type(value).__name__}") + continue + + # Check min value + if "min" in rules and isinstance(value, (int, float)): + if value < rules["min"]: + errors.append(f"Value too small for {full_path}: {value} < {rules['min']}") + + # Check choices + if "choices" in rules and value not in rules["choices"]: + errors.append(f"Invalid value for {full_path}: '{value}' not in {rules['choices']}") + + # Recurse into nested dicts + if expected_type == dict and "children" in rules: + errors.extend(_validate_dict(value, rules["children"], full_path)) + + return errors + + +def validate_paths_exist(config: Dict[str, Any], check_loader_dir: bool = True) -> List[str]: + """ + Validate that required paths in config exist. + + Args: + config: Configuration dictionary + check_loader_dir: Whether to check if loader_dir exists + + Returns: + List of warning messages for missing paths + """ + warnings = [] + paths = config.get("paths", {}) + + if check_loader_dir: + loader_dir = paths.get("loader_dir", "") + if loader_dir and not os.path.exists(loader_dir): + warnings.append(f"Loader directory does not exist: {loader_dir}") + + aug_dir = paths.get("aug_dir", "") + if aug_dir and not os.path.exists(aug_dir): + warnings.append(f"Augmentation directory does not exist: {aug_dir}") + + return warnings diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py new file mode 100644 index 0000000000..db912adcf6 --- /dev/null +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/constants.py @@ -0,0 +1,73 @@ +""" +Constants used throughout DeepSSM. + +This module centralizes magic strings and default values to improve +maintainability and reduce errors from typos. 
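+
+Example (mirrors how eval.py and loaders.py use these constants below; model_dir and
+loader_dir are placeholder variables):
+
+    from DeepSSMUtils import constants as C
+
+    model_path = model_dir + C.BEST_MODEL_FILE
+    train_loader_path = loader_dir + C.TRAIN_LOADER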
+""" + +# Model file names +BEST_MODEL_FILE = "best_model.torch" +FINAL_MODEL_FILE = "final_model.torch" +BEST_MODEL_FT_FILE = "best_model_ft.torch" +FINAL_MODEL_FT_FILE = "final_model_ft.torch" +FINAL_MODEL_AE_FILE = "final_model_ae.torch" +FINAL_MODEL_TF_FILE = "final_model_tf.torch" + +# Data loader names +TRAIN_LOADER = "train" +VALIDATION_LOADER = "validation" +TEST_LOADER = "test" + +# File names for saved statistics +MEAN_PCA_FILE = "mean_PCA.npy" +STD_PCA_FILE = "std_PCA.npy" +MEAN_IMG_FILE = "mean_img.npy" +STD_IMG_FILE = "std_img.npy" + +# Names files +TRAIN_NAMES_FILE = "train_names.txt" +VALIDATION_NAMES_FILE = "validation_names.txt" +TEST_NAMES_FILE = "test_names.txt" + +# Log and plot files +TRAIN_LOG_FILE = "train_log.csv" +TRAINING_PLOT_FILE = "training_plot.png" +TRAINING_PLOT_FT_FILE = "training_plot_ft.png" +TRAINING_PLOT_AE_FILE = "training_plot_ae.png" +TRAINING_PLOT_TF_FILE = "training_plot_tf.png" +TRAINING_PLOT_JOINT_FILE = "training_plot_joint.png" + +# PCA info directory and files +PCA_INFO_DIR = "PCA_Particle_Info" +PCA_MEAN_FILE = "mean.particles" +PCA_MODE_FILE_TEMPLATE = "pcamode{}.particles" + +# Prediction directories +WORLD_PREDICTIONS_DIR = "world_predictions" +PCA_PREDICTIONS_DIR = "pca_predictions" +LOCAL_PREDICTIONS_DIR = "local_predictions" + +# Examples directory +EXAMPLES_DIR = "examples" +TRAIN_EXAMPLES_PREFIX = "train_" +VALIDATION_EXAMPLES_PREFIX = "validation_" + +# Training stage names (for logging) +class TrainingStage: + BASE = "Base_Training" + FINE_TUNING = "Fine_Tuning" + AUTOENCODER = "AE" + T_FLANK = "T-Flank" + JOINT = "Joint" + +# Default values +class Defaults: + BATCH_SIZE = 1 + DOWN_FACTOR = 1 + TRAIN_SPLIT = 0.80 + NUM_WORKERS = 0 + VAL_FREQ = 1 + +# Device strings +DEVICE_CUDA = "cuda:0" +DEVICE_CPU = "cpu" diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py index 5f7fe30e36..ee64b568d8 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py @@ -6,6 +6,7 @@ import torch from torch.utils.data import DataLoader from DeepSSMUtils import model, loaders +from DeepSSMUtils import constants as C from shapeworks.utils import sw_message from shapeworks.utils import sw_progress from shapeworks.utils import sw_check_abort @@ -24,9 +25,9 @@ def test(config_file, loader="test"): pred_dir = model_dir + loader + '_predictions/' loaders.make_dir(pred_dir) if parameters["use_best_model"]: - model_path = model_dir + 'best_model.torch' + model_path = model_dir + C.BEST_MODEL_FILE else: - model_path = model_dir + 'final_model.torch' + model_path = model_dir + C.FINAL_MODEL_FILE if parameters["fine_tune"]["enabled"]: model_path_ft = model_path.replace(".torch", "_ft.torch") else: @@ -35,7 +36,7 @@ def test(config_file, loader="test"): # load the loaders sw_message("Loading " + loader + " data loader...") - test_loader = torch.load(loader_dir + loader, weights_only=False) + test_loader = loaders.load_data_loader(loader_dir + loader, loader_type='test') # initialization sw_message("Loading trained model...") @@ -67,9 +68,9 @@ def test(config_file, loader="test"): index = 0 pred_scores = [] - pred_path = pred_dir + 'world_predictions/' + pred_path = pred_dir + C.WORLD_PREDICTIONS_DIR + '/' loaders.make_dir(pred_path) - pred_path_pca = pred_dir + 'pca_predictions/' + pred_path_pca = pred_dir + C.PCA_PREDICTIONS_DIR + '/' loaders.make_dir(pred_path_pca) predicted_particle_files = [] @@ -86,18 +87,18 @@ def test(config_file, loader="test"): 
[pred_tf, pred_mdl_tl] = model_tl(mdl, img) pred_scores.append(pred_tf.cpu().data.numpy()) # save the AE latent space as shape descriptors - filename = pred_path + test_names[index] + '.npy' - np.save(filename, pred_tf.squeeze().detach().cpu().numpy()) + latent_filename = pred_path + test_names[index] + '.npy' + np.save(latent_filename, pred_tf.squeeze().detach().cpu().numpy()) np.savetxt(particle_filename, pred_mdl_tl.squeeze().detach().cpu().numpy()) else: [pred, pred_mdl_pca] = model_pca(img) [pred, pred_mdl_ft] = model_ft(img) pred_scores.append(pred.cpu().data.numpy()[0]) - filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles' - np.savetxt(filename, pred_mdl_pca.squeeze().detach().cpu().numpy()) + pca_filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles' + np.savetxt(pca_filename, pred_mdl_pca.squeeze().detach().cpu().numpy()) np.savetxt(particle_filename, pred_mdl_ft.squeeze().detach().cpu().numpy()) print("Predicted particle file: ", particle_filename) - predicted_particle_files.append(filename) + predicted_particle_files.append(particle_filename) index += 1 sw_message("Test completed.") return predicted_particle_files diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py index 86e12fc03f..638158a577 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py @@ -2,15 +2,15 @@ import SimpleITK import numpy as np -def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid'): - # Prepare parameter map +def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid', max_iterations=1024): + # Prepare parameter map parameter_object = itk.ParameterObject.New() parameter_map = parameter_object.GetDefaultParameterMap('rigid') if transform_type == 'similarity': parameter_map['Transform'] = ['SimilarityTransform'] elif transform_type == 'translation': parameter_map['Transform'] = ['TranslationTransform'] - parameter_map['MaximumNumberOfIterations'] = ['1024'] + parameter_map['MaximumNumberOfIterations'] = [str(max_iterations)] parameter_object.AddParameterMap(parameter_map) # Load images diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py index 5573b7e9db..48391df834 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/loaders.py @@ -6,11 +6,20 @@ import subprocess import torch from torch import nn -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset import shapeworks as sw from shapeworks.utils import sw_message +from DeepSSMUtils import constants as C random.seed(1) +# Use streaming data loading to avoid loading all images into memory +USE_STREAMING = True + + +class DataLoadingError(Exception): + """Raised when data loading fails.""" + pass + ######################## Data loading functions #################################### ''' @@ -20,6 +29,83 @@ def make_dir(dirPath): if not os.path.exists(dirPath): os.makedirs(dirPath) + +''' +Load a DataLoader from a saved file. Handles both streaming (metadata) and legacy (full loader) formats. 
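+
+Illustrative usage (sketch; mirrors the call in eval.py, with loader_dir as a placeholder):
+
+    test_loader = load_data_loader(loader_dir + C.TEST_LOADER, loader_type='test')
+    train_loader = load_data_loader(loader_dir + C.TRAIN_LOADER, loader_type='train')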
+''' +def load_data_loader(loader_path, loader_type='train'): + data = torch.load(loader_path, weights_only=False) + + # Check if it's streaming metadata or a full DataLoader + if isinstance(data, dict) and data.get('streaming', False): + # Reconstruct streaming DataLoader from metadata + if loader_type == 'train': + dataset = DeepSSMdatasetStreaming( + data['image_paths'], + data['scores'], + data['models'], + data['prefixes'], + data['mean_img'], + data['std_img'] + ) + return DataLoader( + dataset, + batch_size=data.get('batch_size', 1), + shuffle=True, + num_workers=data.get('num_workers', 0), + pin_memory=torch.cuda.is_available() + ) + else: + # Validation or test + dataset = DeepSSMdatasetStreaming( + data['image_paths'], + data['scores'], + data['models'], + data['names'], + data['mean_img'], + data['std_img'] + ) + return DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=data.get('num_workers', 0), + pin_memory=torch.cuda.is_available() + ) + else: + # Legacy format - data is already a DataLoader + return data + + +''' +Get dataset info (image dimensions, num_corr) from a loader file. +Works with both streaming and legacy formats. +''' +def get_loader_info(loader_path): + data = torch.load(loader_path, weights_only=False) + + if isinstance(data, dict) and data.get('streaming', False): + # Streaming format - load one image to get dimensions + image_path = data['image_paths'][0] + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + img_dims = img.shape + num_corr = len(data['models'][0]) + num_pca = len(data['scores'][0]) if data['scores'][0] != [1] else data.get('num_pca', 0) + return { + 'img_dims': img_dims, + 'num_corr': num_corr, + 'num_pca': num_pca, + 'streaming': True + } + else: + # Legacy format + return { + 'img_dims': data.dataset.img[0].shape[1:], + 'num_corr': data.dataset.mdl_target[0].shape[0], + 'num_pca': data.dataset.pca_target[0].shape[0], + 'streaming': False + } + ''' Reads csv and makes both train and validation data loaders from it ''' @@ -44,7 +130,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - train_path = loader_dir + 'train' + train_path = loader_dir + C.TRAIN_LOADER torch.save(trainloader, train_path) validationloader = DataLoader( @@ -54,7 +140,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - val_path = loader_dir + 'validation' + val_path = loader_dir + C.VALIDATION_LOADER torch.save(validationloader, val_path) sw_message("Training and validation loaders complete.\n") return train_path, val_path @@ -64,29 +150,81 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow ''' def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir=None, train_split=0.80, num_workers=0): sw_message("Creating training torch loader...") - # Get data make_dir(loader_dir) - images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir) - images, scores, models, prefixes = shuffle_data(images, scores, models, prefixes) - train_data = DeepSSMdataset(images, scores, models, prefixes) - # Save - trainloader = DataLoader( + + if USE_STREAMING: + # Streaming approach - don't load all images into memory + image_paths, scores, models, prefixes = get_all_train_data_streaming( + loader_dir, data_csv, down_factor, down_dir + ) + image_paths, scores, models, 
prefixes = shuffle_data(image_paths, scores, models, prefixes) + + # Load saved mean/std + mean_img = np.load(loader_dir + C.MEAN_IMG_FILE) + std_img = np.load(loader_dir + C.STD_IMG_FILE) + + train_data = DeepSSMdatasetStreaming( + list(image_paths), list(scores), list(models), list(prefixes), + float(mean_img), float(std_img) + ) + + # For streaming, we don't save the full DataLoader (it would try to pickle the dataset) + # Instead, save metadata that can be used to reconstruct the loader + trainloader = DataLoader( train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - train_path = loader_dir + 'train' - torch.save(trainloader, train_path) - sw_message("Training loader complete.") - return train_path + + # Save metadata for reconstruction + train_meta = { + 'image_paths': list(image_paths), + 'scores': list(scores), + 'models': list(models), + 'prefixes': list(prefixes), + 'mean_img': float(mean_img), + 'std_img': float(std_img), + 'batch_size': batch_size, + 'num_workers': num_workers, + 'streaming': True + } + train_path = loader_dir + C.TRAIN_LOADER + torch.save(train_meta, train_path) + sw_message("Training loader complete.") + return train_path + else: + # Legacy approach - load all into memory + images, scores, models, prefixes = get_all_train_data(loader_dir, data_csv, down_factor, down_dir) + images, scores, models, prefixes = shuffle_data(images, scores, models, prefixes) + train_data = DeepSSMdataset(images, scores, models, prefixes) + trainloader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + train_path = loader_dir + C.TRAIN_LOADER + torch.save(trainloader, train_path) + sw_message("Training loader complete.") + return train_path ''' Makes validation data loader ''' def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1, down_dir=None, num_workers=0): sw_message("Creating validation torch loader:") + if not val_img_list: + raise DataLoadingError("Validation image list is empty") + if not val_particles: + raise DataLoadingError("Validation particle list is empty") + if len(val_img_list) != len(val_particles): + raise DataLoadingError( + f"Mismatched validation data: {len(val_img_list)} images but {len(val_particles)} particle files" + ) + # Get data image_paths = [] scores = [] @@ -94,109 +232,253 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1 names = [] for index in range(len(val_img_list)): image_path = val_img_list[index] - # add name prefix = get_prefix(image_path) names.append(prefix) image_paths.append(image_path) - scores.append([1]) # placeholder + scores.append([1]) # placeholder mdl = get_particles(val_particles[index]) models.append(mdl) - # Write test names to file so they are saved somewhere - name_file = open(loader_dir + 'validation_names.txt', 'w+') + + # Write validation names to file + name_file = open(loader_dir + C.VALIDATION_NAMES_FILE, 'w+') name_file.write(str(names)) name_file.close() - sw_message("Validation names saved to: " + loader_dir + "validation_names.txt") - images = get_images(loader_dir, image_paths, down_factor, down_dir) - val_data = DeepSSMdataset(images, scores, models, names) - # Make loader - val_loader = DataLoader( + sw_message("Validation names saved to: " + loader_dir + C.VALIDATION_NAMES_FILE) + + if USE_STREAMING: + # Prepare image paths + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # 
Load mean/std from training (should already exist) + mean_img = float(np.load(loader_dir + C.MEAN_IMG_FILE)) + std_img = float(np.load(loader_dir + C.STD_IMG_FILE)) + + val_data = DeepSSMdatasetStreaming(image_paths, scores, models, names, mean_img, std_img) + + val_loader = DataLoader( val_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - val_path = loader_dir + 'validation' - torch.save(val_loader, val_path) - sw_message("Validation loader complete.") - return val_path + + # Save metadata + val_meta = { + 'image_paths': image_paths, + 'scores': scores, + 'models': models, + 'names': names, + 'mean_img': mean_img, + 'std_img': std_img, + 'num_workers': num_workers, + 'streaming': True + } + val_path = loader_dir + C.VALIDATION_LOADER + torch.save(val_meta, val_path) + sw_message("Validation loader complete.") + return val_path + else: + # Legacy approach + images = get_images(loader_dir, image_paths, down_factor, down_dir) + val_data = DeepSSMdataset(images, scores, models, names) + val_loader = DataLoader( + val_data, + batch_size=1, + shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + val_path = loader_dir + C.VALIDATION_LOADER + torch.save(val_loader, val_path) + sw_message("Validation loader complete.") + return val_path ''' Makes test data loader ''' def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None, num_workers=0): sw_message("Creating test torch loader...") - # get data + if not test_img_list: + raise DataLoadingError("Test image list is empty") + + # Get data image_paths = [] scores = [] models = [] test_names = [] for index in range(len(test_img_list)): image_path = test_img_list[index] - # add name prefix = get_prefix(image_path) test_names.append(prefix) image_paths.append(image_path) - # add label placeholders - scores.append([1]) - models.append([1]) - images = get_images(loader_dir, image_paths, down_factor, down_dir) - test_data = DeepSSMdataset(images, scores, models, test_names) - # Write test names to file so they are saved somewhere - name_file = open(loader_dir + 'test_names.txt', 'w+') + scores.append([1]) # placeholder + models.append([1]) # placeholder + + # Write test names to file + name_file = open(loader_dir + C.TEST_NAMES_FILE, 'w+') name_file.write(str(test_names)) name_file.close() - sw_message("Test names saved to: " + loader_dir + "test_names.txt") - # Make loader - testloader = DataLoader( + sw_message("Test names saved to: " + loader_dir + C.TEST_NAMES_FILE) + + if USE_STREAMING: + # Prepare image paths + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # Load mean/std from training + mean_img = float(np.load(loader_dir + C.MEAN_IMG_FILE)) + std_img = float(np.load(loader_dir + C.STD_IMG_FILE)) + + test_data = DeepSSMdatasetStreaming(image_paths, scores, models, test_names, mean_img, std_img) + + testloader = DataLoader( test_data, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=torch.cuda.is_available() ) - test_path = loader_dir + 'test' - torch.save(testloader, test_path) - sw_message("Test loader complete.") - return test_path, test_names + + # Save metadata + test_meta = { + 'image_paths': image_paths, + 'scores': scores, + 'models': models, + 'names': test_names, + 'mean_img': mean_img, + 'std_img': std_img, + 'num_workers': num_workers, + 'streaming': True + } + test_path = loader_dir + C.TEST_LOADER + torch.save(test_meta, test_path) + sw_message("Test loader complete.") + return test_path, 
test_names + else: + # Legacy approach + images = get_images(loader_dir, image_paths, down_factor, down_dir) + test_data = DeepSSMdataset(images, scores, models, test_names) + testloader = DataLoader( + test_data, + batch_size=1, + shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available() + ) + test_path = loader_dir + C.TEST_LOADER + torch.save(testloader, test_path) + sw_message("Test loader complete.") + return test_path, test_names ################################ Helper functions ###################################### ''' -returns images, scores, models, prefixes from CSV +Returns image_paths, scores, models, prefixes from CSV for streaming. +Computes mean/std incrementally without loading all images. +''' +def get_all_train_data_streaming(loader_dir, data_csv, down_factor, down_dir): + if not os.path.exists(data_csv): + raise DataLoadingError(f"CSV file not found: {data_csv}") + + image_paths = [] + scores = [] + models = [] + prefixes = [] + + try: + with open(data_csv, newline='') as csvfile: + datareader = csv.reader(csvfile) + for row_num, row in enumerate(datareader, 1): + if len(row) < 3: + raise DataLoadingError( + f"Invalid row {row_num} in {data_csv}: expected at least 3 columns " + f"(image_path, model_path, pca_scores), got {len(row)}" + ) + image_path = row[0] + model_path = row[1] + pca_scores = row[2:] + + prefix = get_prefix(image_path) + prefixes.append(prefix) + image_paths.append(image_path) + + try: + pca_scores = [float(i) for i in pca_scores] + except ValueError as e: + raise DataLoadingError( + f"Invalid PCA scores in {data_csv} at row {row_num}: {e}" + ) + scores.append(pca_scores) + + mdl = get_particles(model_path) + models.append(mdl) + except csv.Error as e: + raise DataLoadingError(f"Error parsing CSV file {data_csv}: {e}") + + if not image_paths: + raise DataLoadingError(f"CSV file is empty: {data_csv}") + + # Prepare image paths (apply downsampling if needed) + image_paths = prepare_image_paths(image_paths, down_factor, down_dir) + + # Compute mean/std incrementally + sw_message("Computing image statistics incrementally...") + mean_img, std_img = compute_image_stats_incremental(image_paths, down_factor, down_dir) + np.save(loader_dir + C.MEAN_IMG_FILE, mean_img) + np.save(loader_dir + C.STD_IMG_FILE, std_img) + sw_message(f"Image stats: mean={mean_img:.4f}, std={std_img:.4f}") + + # Whiten PCA scores + scores = whiten_PCA_scores(scores, loader_dir) + + return image_paths, scores, models, prefixes + + +''' +returns images, scores, models, prefixes from CSV (legacy - loads all into memory) ''' def get_all_train_data(loader_dir, data_csv, down_factor, down_dir): + if not os.path.exists(data_csv): + raise DataLoadingError(f"CSV file not found: {data_csv}") # get all data and targets image_paths = [] scores = [] models = [] prefixes = [] - with open(data_csv, newline='') as csvfile: - datareader = csv.reader(csvfile) - index = 0 - for row in datareader: - image_path = row[0] - model_path = row[1] - pca_scores = row[2:] - # add name - prefix = get_prefix(image_path) - # data error check - # if prefix not in get_prefix(model_path): - # print("Error: Images and particles are mismatched in csv.") - # print(f"index: {index}") - # print(f"prefix: {prefix}") - # print(f"get_prefix(model_path): {get_prefix(model_path)}}") - # exit() - prefixes.append(prefix) - # add image path - image_paths.append(image_path) - # add score (un-normalized) - pca_scores = [float(i) for i in pca_scores] - scores.append(pca_scores) - # add model - mdl = 
get_particles(model_path) - models.append(mdl) - index += 1 + try: + with open(data_csv, newline='') as csvfile: + datareader = csv.reader(csvfile) + for row_num, row in enumerate(datareader, 1): + if len(row) < 3: + raise DataLoadingError( + f"Invalid row {row_num} in {data_csv}: expected at least 3 columns " + f"(image_path, model_path, pca_scores), got {len(row)}" + ) + image_path = row[0] + model_path = row[1] + pca_scores = row[2:] + # add name + prefix = get_prefix(image_path) + prefixes.append(prefix) + # add image path + image_paths.append(image_path) + # add score (un-normalized) + try: + pca_scores = [float(i) for i in pca_scores] + except ValueError as e: + raise DataLoadingError( + f"Invalid PCA scores in {data_csv} at row {row_num}: {e}" + ) + scores.append(pca_scores) + # add model + mdl = get_particles(model_path) + models.append(mdl) + except csv.Error as e: + raise DataLoadingError(f"Error parsing CSV file {data_csv}: {e}") + + if not image_paths: + raise DataLoadingError(f"CSV file is empty: {data_csv}") + images = get_images(loader_dir, image_paths, down_factor, down_dir) scores = whiten_PCA_scores(scores, loader_dir) return images, scores, models, prefixes @@ -212,6 +494,7 @@ def shuffle_data(images, scores, models, prefixes): ''' Class for DeepSSM datasets that works with Pytorch DataLoader +Loads all images into memory upfront (legacy approach). ''' class DeepSSMdataset(): def __init__(self, img, pca_target, mdl_target, names): @@ -228,6 +511,40 @@ def __getitem__(self, index): def __len__(self): return len(self.img) + +''' +Streaming dataset that loads images on-demand to minimize memory usage. +Only keeps file paths in memory, loads each image when accessed. +''' +class DeepSSMdatasetStreaming(Dataset): + def __init__(self, image_paths, pca_target, mdl_target, names, mean_img, std_img): + self.image_paths = image_paths + self.pca_target = torch.FloatTensor(np.array(pca_target)) + self.mdl_target = torch.FloatTensor(np.array(mdl_target)) + self.names = names + self.mean_img = mean_img + self.std_img = std_img + + def __getitem__(self, index): + # Load image on-demand + image_path = self.image_paths[index] + try: + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") + + # Normalize + img = (img - self.mean_img) / self.std_img + x = torch.FloatTensor(img).unsqueeze(0) # Add channel dimension + + y1 = self.pca_target[index] + y2 = self.mdl_target[index] + name = self.names[index] + return x, y1, y2, name + + def __len__(self): + return len(self.image_paths) + ''' returns sample prefix from path string ''' @@ -240,18 +557,111 @@ def get_prefix(path): get list from .particles format ''' def get_particles(model_path): - f = open(model_path, "r") - data = [] - for line in f.readlines(): - points = line.split() - points = [float(i) for i in points] - data.append(points) - return(data) + if not os.path.exists(model_path): + raise DataLoadingError(f"Particle file not found: {model_path}") + try: + with open(model_path, "r") as f: + data = [] + for line_num, line in enumerate(f.readlines(), 1): + points = line.split() + try: + points = [float(i) for i in points] + except ValueError as e: + raise DataLoadingError( + f"Invalid particle data in {model_path} at line {line_num}: {e}" + ) + data.append(points) + if not data: + raise DataLoadingError(f"Particle file is empty: {model_path}") + return data + except IOError as e: + raise DataLoadingError(f"Error reading particle file 
{model_path}: {e}") ''' -reads .nrrd files and returns whitened data +Compute image mean and std incrementally without loading all images into memory. +Uses Welford's online algorithm for numerical stability. +''' +def compute_image_stats_incremental(image_list, down_factor=1, down_dir=None): + if not image_list: + raise DataLoadingError("Image list is empty") + + n = 0 + mean = 0.0 + M2 = 0.0 # Sum of squared differences from mean + + for i, image_path in enumerate(image_list): + # Handle downsampling + if down_dir is not None: + make_dir(down_dir) + img_name = os.path.basename(image_path) + res_img = os.path.join(down_dir, img_name) + if not os.path.exists(res_img): + apply_down_sample(image_path, res_img, down_factor) + image_path = res_img + + if not os.path.exists(image_path): + raise DataLoadingError(f"Image file not found: {image_path}") + + try: + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") + + # Welford's online algorithm for each pixel value + for val in img.flat: + n += 1 + delta = val - mean + mean += delta / n + delta2 = val - mean + M2 += delta * delta2 + + # Free memory + del img + + if (i + 1) % 10 == 0: + sw_message(f" Computing stats: {i + 1}/{len(image_list)} images processed") + + if n < 2: + raise DataLoadingError("Need at least 2 pixel values to compute statistics") + + variance = M2 / n + std = np.sqrt(variance) + + return mean, std + + +''' +Prepare image paths, applying downsampling if needed. +Returns list of paths to use (either original or downsampled). +''' +def prepare_image_paths(image_list, down_factor=1, down_dir=None): + if not image_list: + raise DataLoadingError("Image list is empty") + + prepared_paths = [] + for image_path in image_list: + if down_dir is not None: + make_dir(down_dir) + img_name = os.path.basename(image_path) + res_img = os.path.join(down_dir, img_name) + if not os.path.exists(res_img): + apply_down_sample(image_path, res_img, down_factor) + image_path = res_img + + if not os.path.exists(image_path): + raise DataLoadingError(f"Image file not found: {image_path}") + + prepared_paths.append(image_path) + + return prepared_paths + + +''' +reads .nrrd files and returns whitened data (legacy - loads all into memory) ''' def get_images(loader_dir, image_list, down_factor, down_dir): + if not image_list: + raise DataLoadingError("Image list is empty") # get all images all_images = [] for image_path in image_list: @@ -262,14 +672,19 @@ def get_images(loader_dir, image_list, down_factor, down_dir): if not os.path.exists(res_img): apply_down_sample(image_path, res_img, down_factor) image_path = res_img - # for_viewing returns 'F' order, i.e., transpose, needed for this array - img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + if not os.path.exists(image_path): + raise DataLoadingError(f"Image file not found: {image_path}") + try: + # for_viewing returns 'F' order, i.e., transpose, needed for this array + img = sw.Image(image_path).toArray(copy=True, for_viewing=True) + except Exception as e: + raise DataLoadingError(f"Error reading image {image_path}: {e}") all_images.append(img) all_images = np.array(all_images) # get mean and std - mean_path = loader_dir + 'mean_img.npy' - std_path = loader_dir + 'std_img.npy' + mean_path = loader_dir + C.MEAN_IMG_FILE + std_path = loader_dir + C.STD_IMG_FILE mean_image = np.mean(all_images) std_image = np.std(all_images) np.save(mean_path, mean_image) @@ -305,8 +720,8 @@ def 
whiten_PCA_scores(scores, loader_dir): scores = np.array(scores) mean_score = np.mean(scores, 0) std_score = np.std(scores, 0) - np.save(loader_dir + 'mean_PCA.npy', mean_score) - np.save(loader_dir + 'std_PCA.npy', std_score) + np.save(loader_dir + C.MEAN_PCA_FILE, mean_score) + np.save(loader_dir + C.STD_PCA_FILE, std_score) norm_scores = [] for score in scores: norm_scores.append((score-mean_score)/std_score) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py index f512f2e244..51d9514368 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py @@ -5,6 +5,8 @@ import numpy as np from collections import OrderedDict from DeepSSMUtils import net_utils +from DeepSSMUtils import constants as C +from DeepSSMUtils import loaders class ConvolutionalBackbone(nn.Module): @@ -61,9 +63,9 @@ class DeterministicEncoder(nn.Module): def __init__(self, num_latent, img_dims, loader_dir): super(DeterministicEncoder, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device self.num_latent = num_latent self.img_dims = img_dims @@ -97,18 +99,17 @@ class DeepSSMNet(nn.Module): def __init__(self, config_file): super(DeepSSMNet, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device - with open(config_file) as json_file: + with open(config_file) as json_file: parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + "validation", weights_only=False) - self.num_corr = loader.dataset.mdl_target[0].shape[0] - img_dims = loader.dataset.img[0].shape - self.img_dims = img_dims[1:] + loader_info = loaders.get_loader_info(self.loader_dir + C.VALIDATION_LOADER) + self.num_corr = loader_info['num_corr'] + self.img_dims = loader_info['img_dims'] # encoder if parameters['encoder']['deterministic']: self.encoder = DeterministicEncoder(self.num_latent, self.img_dims, self.loader_dir ) @@ -169,18 +170,17 @@ class DeepSSMNet_TLNet(nn.Module): def __init__(self, conflict_file): super(DeepSSMNet_TLNet, self).__init__() if torch.cuda.is_available(): - device = 'cuda:0' + device = C.DEVICE_CUDA else: - device = 'cpu' + device = C.DEVICE_CPU self.device = device - with open(conflict_file) as json_file: + with open(conflict_file) as json_file: parameters = json.load(json_file) self.num_latent = parameters['num_latent_dim'] self.loader_dir = parameters['paths']['loader_dir'] - loader = torch.load(self.loader_dir + "validation") - self.num_corr = loader.dataset.mdl_target[0].shape[0] - img_dims = loader.dataset.img[0].shape - self.img_dims = img_dims[1:] + loader_info = loaders.get_loader_info(self.loader_dir + C.VALIDATION_LOADER) + self.num_corr = loader_info['num_corr'] + self.img_dims = loader_info['img_dims'] self.CorrespondenceEncoder = CorrespondenceEncoder(self.num_latent, self.num_corr) self.CorrespondenceDecoder = CorrespondenceDecoder(self.num_latent, self.num_corr) self.ImageEncoder = DeterministicEncoder(self.num_latent, self.img_dims, self.loader_dir) diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py index 3ffa0a9014..447d60bdb9 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py +++ 
b/Python/DeepSSMUtilsPackage/DeepSSMUtils/net_utils.py @@ -1,21 +1,79 @@ +import random import torch from torch import nn import numpy as np +from DeepSSMUtils import constants as C + + +def set_seed(seed: int = 42) -> None: + """ + Set random seeds for reproducibility across all random number generators. + + Args: + seed: Integer seed value for random number generators + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + class Flatten(nn.Module): - def forward(self, x): + """Flatten layer to reshape tensor for fully connected layers.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.view(x.size(0), -1) -def poolOutDim(inDim, kernel_size, padding=0, stride=0, dilation=1): + +def poolOutDim( + inDim: int, + kernel_size: int, + padding: int = 0, + stride: int = 0, + dilation: int = 1 +) -> int: + """ + Calculate output dimension after pooling operation. + + Args: + inDim: Input dimension size + kernel_size: Size of the pooling kernel + padding: Padding applied to input + stride: Stride of pooling (defaults to kernel_size if 0) + dilation: Dilation factor + + Returns: + Output dimension size after pooling + """ if stride == 0: stride = kernel_size num = inDim + 2*padding - dilation*(kernel_size - 1) - 1 outDim = int(np.floor(num/stride + 1)) return outDim -def unwhiten_PCA_scores(torch_loading, loader_dir, device): - mean_score = torch.from_numpy(np.load(loader_dir + '/mean_PCA.npy')).to(device).float() - std_score = torch.from_numpy(np.load(loader_dir + '/std_PCA.npy')).to(device).float() + +def unwhiten_PCA_scores( + torch_loading: torch.Tensor, + loader_dir: str, + device: str +) -> torch.Tensor: + """ + Unwhiten (denormalize) PCA scores using saved mean and std. + + Args: + torch_loading: Whitened PCA scores tensor + loader_dir: Directory containing mean_PCA.npy and std_PCA.npy + device: Device to load tensors to ('cuda:0' or 'cpu') + + Returns: + Unwhitened PCA scores tensor + """ + mean_score = torch.from_numpy(np.load(loader_dir + '/' + C.MEAN_PCA_FILE)).to(device).float() + std_score = torch.from_numpy(np.load(loader_dir + '/' + C.STD_PCA_FILE)).to(device).float() mean_score = mean_score.unsqueeze(0).repeat(torch_loading.shape[0], 1) std_score = std_score.unsqueeze(0).repeat(torch_loading.shape[0], 1) pca_new = torch_loading*(std_score) + mean_score diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index 7de1bd1e2c..7d9f9c6677 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -1,6 +1,7 @@ import random import math import os +import gc import numpy as np import json @@ -155,21 +156,33 @@ def get_training_indices(project): def get_training_bounding_box(project): - """ Get the bounding box of the training subjects. """ + """ Get the bounding box of the training subjects. + + Uses world particle positions to compute the bounding box. This ensures + consistency with the actual aligned particle positions used during training, + which may include additional transforms applied during optimization that + aren't captured by get_groomed_transforms() alone. 
+ """ subjects = project.get_subjects() training_indices = get_training_indices(project) - training_bounding_box = None - train_mesh_list = [] + + # Compute bounding box from world particles + min_pt = np.array([np.inf, np.inf, np.inf]) + max_pt = np.array([-np.inf, -np.inf, -np.inf]) + for i in training_indices: subject = subjects[i] - mesh = subject.get_groomed_clipped_mesh() - # apply transform - alignment = convert_transform_to_numpy(subject.get_groomed_transforms()[0]) - mesh.applyTransform(alignment) - train_mesh_list.append(mesh) + world_particle_files = subject.get_world_particle_filenames() + if world_particle_files: + particles = np.loadtxt(world_particle_files[0]) + min_pt = np.minimum(min_pt, particles.min(axis=0)) + max_pt = np.maximum(max_pt, particles.max(axis=0)) - bounding_box = sw.MeshUtils.boundingBox(train_mesh_list).pad(10) - return bounding_box + # Create bounding box from particle extents + # PhysicalRegion takes two sequences: min point and max point + bounding_box = sw.PhysicalRegion(min_pt.tolist(), max_pt.tolist()) + + return bounding_box.pad(10) def convert_transform_to_numpy(transform): @@ -229,14 +242,15 @@ def groom_training_images(project): f.write(bounding_box_string) sw_message("Grooming training images") - for i in get_training_indices(project): + training_indices = get_training_indices(project) + for count, i in enumerate(training_indices): if sw_check_abort(): sw_message("Aborted") return image_name = sw.utils.get_image_filename(subjects[i]) - sw_progress(i / (len(subjects) + 1), f"Grooming Training Image: {image_name}") + sw_progress(count / (len(training_indices) + 1), f"Grooming Training Image: {image_name}") image = sw.Image(image_name) subject = subjects[i] # get alignment transform @@ -257,6 +271,15 @@ def groom_training_images(project): # write image using the index of the subject image.write(deepssm_dir + f"/train_images/{i}.nrrd") + # Explicitly delete the image and run garbage collection periodically + # to prevent memory accumulation + del image + if count % 50 == 0: + gc.collect() + + # Final cleanup after processing all training images + gc.collect() + def run_data_augmentation(project, num_samples, num_dim, percent_variability, sampler, mixture_num=0, processes=1): """ Run data augmentation on the training images. 
""" @@ -363,16 +386,15 @@ def groom_val_test_images(project, indices): val_test_transforms = [] val_test_image_files = [] - count = 1 - for i in val_test_indices: + for count, i in enumerate(val_test_indices): if sw_check_abort(): sw_message("Aborted") return image_name = sw.utils.get_image_filename(subjects[i]) sw_progress(count / (len(val_test_indices) + 1), - f"Grooming val/test image {image_name} ({count}/{len(val_test_indices)})") - count = count + 1 + f"Grooming val/test image {image_name} ({count + 1}/{len(val_test_indices)})") + image = sw.Image(image_name) image_file = val_test_images_dir + f"{i}.nrrd" @@ -440,6 +462,15 @@ def groom_val_test_images(project, indices): extra_values["registration_transform"] = transform_to_string(transform) subjects[i].set_extra_values(extra_values) + + # Explicitly delete image and run garbage collection periodically + del image + + if count % 20 == 0: + gc.collect() + + # Final cleanup + gc.collect() project.set_subjects(subjects) @@ -450,26 +481,34 @@ def prepare_data_loaders(project, batch_size, split="all", num_workers=0): if not os.path.exists(loader_dir): os.makedirs(loader_dir) + # Train must run first: it computes and saves mean_img.npy/std_img.npy + # which are required by validation and test loaders. + if split == "all" or split == "train": + aug_dir = deepssm_dir + "augmentation/" + aug_data_csv = aug_dir + "TotalData.csv" + DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size, num_workers=num_workers) + if split == "all" or split == "val": val_image_files = [] val_world_particles = [] val_indices = get_split_indices(project, "val") for i in val_indices: - val_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd") + image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" + if not os.path.exists(image_file): + raise FileNotFoundError(f"Missing validation image for subject {i}: {image_file}") + val_image_files.append(image_file) particle_file = project.get_subjects()[i].get_world_particle_filenames()[0] val_world_particles.append(particle_file) DeepSSMUtils.getValidationLoader(loader_dir, val_image_files, val_world_particles, num_workers=num_workers) - if split == "all" or split == "train": - aug_dir = deepssm_dir + "augmentation/" - aug_data_csv = aug_dir + "TotalData.csv" - DeepSSMUtils.getTrainLoader(loader_dir, aug_data_csv, batch_size, num_workers=num_workers) - if split == "all" or split == "test": test_image_files = [] test_indices = get_split_indices(project, "test") for i in test_indices: - test_image_files.append(deepssm_dir + f"/val_and_test_images/{i}.nrrd") + image_file = deepssm_dir + f"/val_and_test_images/{i}.nrrd" + if not os.path.exists(image_file): + raise FileNotFoundError(f"Missing test image for subject {i}: {image_file}") + test_image_files.append(image_file) DeepSSMUtils.getTestLoader(loader_dir, test_image_files, num_workers=num_workers) @@ -517,6 +556,10 @@ def process_test_predictions(project, config_file): for index in test_indices: world_particle_file = f"{world_predictions_dir}/{index}.particles" + + if not os.path.exists(world_particle_file): + raise FileNotFoundError(f"Missing prediction for test subject {index}: {world_particle_file}") + print(f"world_particle_file: {world_particle_file}") predicted_test_world_particles.append(world_particle_file) @@ -538,7 +581,8 @@ def process_test_predictions(project, config_file): template_particles, template_mesh, pred_dir) print("Distances: ", distances) - print("Mean distance: ", np.mean(distances)) + mean_distance = np.mean(distances) + 
print("Mean distance: ", mean_distance) # write to csv file in deepssm_dir csv_file = f"{deepssm_dir}/test_distances.csv" @@ -561,3 +605,5 @@ def process_test_predictions(project, config_file): mesh = sw.Mesh(local_mesh_file) mesh.applyTransform(transform) mesh.write(world_mesh_file) + + return mean_distance diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index f73e26fb34..0151710b81 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -13,6 +13,9 @@ from DeepSSMUtils import losses from DeepSSMUtils import train_viz from DeepSSMUtils import loaders +from DeepSSMUtils import net_utils +from DeepSSMUtils import constants as C +from DeepSSMUtils import config_validation import DeepSSMUtils from shapeworks.utils import * @@ -68,10 +71,11 @@ def set_scheduler(opt, sched_params): def train(project, config_file): + net_utils.set_seed(42) sw.utils.initialize_project_mesh_warper(project) - with open(config_file) as json_file: - parameters = json.load(json_file) + # Validate config file before training + parameters = config_validation.validate_config(config_file) if parameters["tl_net"]["enabled"]: supervised_train_tl(config_file) else: @@ -101,11 +105,11 @@ def supervised_train(config_file): fine_tune = parameters['fine_tune']['enabled'] loss_func = method_to_call = getattr(losses, parameters["loss"]["function"]) # load the loaders - train_loader_path = loader_dir + "train" - validation_loader_path = loader_dir + "validation" + train_loader_path = loader_dir + C.TRAIN_LOADER + validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") - train_loader = torch.load(train_loader_path, weights_only=False) - val_loader = torch.load(validation_loader_path, weights_only=False) + train_loader = loaders.load_data_loader(train_loader_path, loader_type='train') + val_loader = loaders.load_data_loader(validation_loader_path, loader_type='validation') print("Done.") # initializations num_pca = train_loader.dataset.pca_target[0].shape[0] @@ -119,8 +123,8 @@ def supervised_train(config_file): net.apply(weight_init(module=nn.Linear, initf=nn.init.xavier_normal_)) # these lines are for the fine tuning layer initialization - whiten_mean = np.load(loader_dir + '/mean_PCA.npy') - whiten_std = np.load(loader_dir + '/std_PCA.npy') + whiten_mean = np.load(loader_dir + '/' + C.MEAN_PCA_FILE) + whiten_std = np.load(loader_dir + '/' + C.STD_PCA_FILE) orig_mean = np.loadtxt(aug_dir + '/PCA_Particle_Info/mean.particles') orig_pc = np.zeros([num_pca, num_corr * 3]) for i in range(num_pca): @@ -146,7 +150,7 @@ def supervised_train(config_file): # train print("Beginning training on device = " + device + '\n') # Initialize logger - logger = open(model_dir + "train_log.csv", "w+", buffering=1) + logger = open(model_dir + C.TRAIN_LOG_FILE, "w+", buffering=1) log_print(logger, ["Training_Stage", "Epoch", "LR", "Train_Err", "Train_Rel_Err", "Val_Err", "Val_Rel_Err", "Sec"]) # Initialize training plot train_plot = plt.figure() @@ -158,7 +162,7 @@ def supervised_train(config_file): axe.set_xlim(0, num_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_FILE, dpi=300) # initialize epochs = [] plot_train_losses = [] @@ -241,17 +245,17 @@ def supervised_train(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), 
max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_FILE) # save if val_rel_err < best_val_rel_error: best_val_rel_error = val_rel_err best_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FILE)) t0 = time.time() if decay_lr: scheduler.step() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_FILE)) parameters['best_model_epochs'] = best_epoch with open(config_file, "w") as json_file: json.dump(parameters, json_file, indent=2) @@ -290,7 +294,7 @@ def supervised_train(config_file): axe.set_xlim(0, ft_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_ft.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_FT_FILE, dpi=300) epochs = [] plot_train_losses = [] plot_val_losses = [] @@ -355,7 +359,7 @@ def supervised_train(config_file): if val_rel_loss < best_ft_val_rel_error: best_ft_val_rel_error = val_rel_loss best_ft_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model_ft.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FT_FILE)) pred_particles.extend(pred_mdl.detach().cpu().numpy()) true_particles.extend(mdl.detach().cpu().numpy()) train_viz.write_examples(np.array(pred_particles), np.array(true_particles), val_names, @@ -376,12 +380,12 @@ def supervised_train(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_ft.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_FT_FILE) t0 = time.time() logger.close() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_ft.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_FT_FILE)) parameters['best_ft_model_epochs'] = best_ft_epoch with open(config_file, "w") as json_file: @@ -411,11 +415,11 @@ def supervised_train_tl(config_file): a_lat = parameters["tl_net"]["a_lat"] c_lat = parameters["tl_net"]["c_lat"] # load the loaders - train_loader_path = loader_dir + "train" - validation_loader_path = loader_dir + "validation" + train_loader_path = loader_dir + C.TRAIN_LOADER + validation_loader_path = loader_dir + C.VALIDATION_LOADER print("Loading data loaders...") - train_loader = torch.load(train_loader_path) - val_loader = torch.load(validation_loader_path) + train_loader = loaders.load_data_loader(train_loader_path, loader_type='train') + val_loader = loaders.load_data_loader(validation_loader_path, loader_type='validation') print("Done.") print("Defining model...") net = model.DeepSSMNet_TLNet(config_file) @@ -447,7 +451,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, ae_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_ae.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_AE_FILE, dpi=300) # initialize epochs = [] plot_train_losses = [] @@ -540,10 +544,10 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_ae.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_AE_FILE) t0 = time.time() # save - 
torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_ae.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_AE_FILE)) # fix the autoencoder and train the TL-net for param in net.CorrespondenceDecoder.parameters(): param.requires_grad = False @@ -563,7 +567,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, tf_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_tf.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_TF_FILE, dpi=300) # initialize t0 = time.time() epochs = [] @@ -650,10 +654,10 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_tf.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_TF_FILE) t0 = time.time() # save - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model_tf.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_TF_FILE)) # jointly train the model joint_epochs = parameters['tl_net']['joint_epochs'] alpha = parameters['tl_net']['alpha'] @@ -673,7 +677,7 @@ def supervised_train_tl(config_file): axe.set_xlim(0, joint_epochs + 1) axe.set_ylabel('Particle MSE') axe.legend() - train_plot.savefig(model_dir + "training_plot_joint.png", dpi=300) + train_plot.savefig(model_dir + C.TRAINING_PLOT_JOINT_FILE, dpi=300) # initialize epochs = [] plot_train_losses = [] @@ -771,19 +775,19 @@ def supervised_train_tl(config_file): sp_val.set_data(epochs, plot_val_losses) axe.set_ylim(0, max(max(plot_train_losses), max(plot_val_losses)) + 3) train_plot.canvas.draw() - train_plot.savefig(model_dir + "training_plot_joint.png") + train_plot.savefig(model_dir + C.TRAINING_PLOT_JOINT_FILE) # save val_rel_err = val_rel_ae_err + alpha * val_rel_tf_err if val_rel_err < best_val_rel_error: best_val_rel_error = val_rel_err best_epoch = e - torch.save(net.state_dict(), os.path.join(model_dir, 'best_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.BEST_MODEL_FILE)) t0 = time.time() if decay_lr: scheduler.step() logger.close() - torch.save(net.state_dict(), os.path.join(model_dir, 'final_model.torch')) + torch.save(net.state_dict(), os.path.join(model_dir, C.FINAL_MODEL_FILE)) parameters['best_model_epochs'] = best_epoch with open(config_file, "w") as json_file: json.dump(parameters, json_file, indent=2) diff --git a/Testing/CMakeLists.txt b/Testing/CMakeLists.txt index c03ca89ef2..58dfb2fe16 100644 --- a/Testing/CMakeLists.txt +++ b/Testing/CMakeLists.txt @@ -77,3 +77,4 @@ add_subdirectory(ProjectTests) add_subdirectory(UseCaseTests) add_subdirectory(shapeworksTests) add_subdirectory(UtilsTests) +add_subdirectory(DeepSSMTests) diff --git a/Testing/DeepSSMTests/CMakeLists.txt b/Testing/DeepSSMTests/CMakeLists.txt new file mode 100644 index 0000000000..7a0c119de1 --- /dev/null +++ b/Testing/DeepSSMTests/CMakeLists.txt @@ -0,0 +1,14 @@ +set(TEST_SRCS + DeepSSMTests.cpp + ) + +add_executable(DeepSSMTests + ${TEST_SRCS} + ) + +target_link_libraries(DeepSSMTests + Testing + ) + +add_test(NAME DeepSSMTests COMMAND DeepSSMTests) +set_tests_properties(DeepSSMTests PROPERTIES TIMEOUT 1800) diff --git a/Testing/DeepSSMTests/DeepSSMTests.cpp b/Testing/DeepSSMTests/DeepSSMTests.cpp new file mode 100644 index 0000000000..6783e325b9 --- /dev/null +++ b/Testing/DeepSSMTests/DeepSSMTests.cpp @@ -0,0 +1,19 @@ +#include "Testing.h" + +using namespace shapeworks; + 
+//--------------------------------------------------------------------------- +void run_deepssm_test(const std::string& name) { + setupenv(std::string(TEST_DATA_DIR) + "/../DeepSSMTests"); + + std::string command = "bash " + name; + ASSERT_FALSE(system(command.c_str())); +} + +//--------------------------------------------------------------------------- +// Run 2 configurations that cover all code paths: +// - default: standard DeepSSM +// - tl_net_fine_tune: TL-DeepSSM with fine tuning (covers both tl_net and fine_tune paths) +TEST(DeepSSMTests, defaultTest) { run_deepssm_test("deepssm_default.sh"); } + +TEST(DeepSSMTests, tlNetFineTuneTest) { run_deepssm_test("deepssm_tl_net_fine_tune.sh"); } diff --git a/Testing/DeepSSMTests/README.md b/Testing/DeepSSMTests/README.md new file mode 100644 index 0000000000..e326fa8525 --- /dev/null +++ b/Testing/DeepSSMTests/README.md @@ -0,0 +1,107 @@ +# DeepSSM Tests + +Automated tests for DeepSSM using ShapeWorks project files (.swproj). + +## Test Configurations + +| Test | Description | +|------|-------------| +| `deepssm_default` | Standard DeepSSM (no TL-Net, no fine-tuning) | +| `deepssm_tl_net` | TL-DeepSSM network enabled | +| `deepssm_fine_tune` | Fine-tuning enabled | +| `deepssm_tl_net_fine_tune` | Both TL-DeepSSM and fine-tuning enabled | + +## Running Tests + +### Run all DeepSSM tests: +```bash +cd /path/to/build +ctest -R DeepSSMTests -V +``` + +### Run a specific test: +```bash +./bin/DeepSSMTests --gtest_filter="*default*" +./bin/DeepSSMTests --gtest_filter="*tl_net*" +``` + +### Run tests directly via shell scripts: +```bash +export DATA=/path/to/Testing/data +bash Testing/DeepSSMTests/deepssm_default.sh +``` + +## Test Data + +Test data is stored in `Testing/data/deepssm_test_data.zip` and automatically extracted on first run. Contains: +- 5 femur meshes, CT images, and constraint files +- Pre-configured project files for each test configuration + +## Result Verification + +Tests verify that the mean surface-to-surface distance is within tolerance. The default tolerance is loose (0-300) for quick 1-epoch tests. + +### Exact Check Mode (for refactoring verification) + +When refactoring DeepSSM code, you can verify results are identical before and after changes. + +**Run all configurations:** +```bash +# Save baselines (before refactoring) +bash Testing/DeepSSMTests/run_exact_check.sh save + +# Verify after refactoring +bash Testing/DeepSSMTests/run_exact_check.sh verify +``` + +**Run a single configuration:** +```bash +cd Testing/data/deepssm/projects +rm -rf deepssm groomed *_particles +shapeworks deepssm --name default.swproj --all + +# Save or verify +python Testing/DeepSSMTests/verify_deepssm_results.py . --exact_check save +python Testing/DeepSSMTests/verify_deepssm_results.py . --exact_check verify +``` + +Baseline values are saved to `exact_check_*.txt` in the project directory. + +**Note:** Exact check is platform-specific due to floating-point differences. Only compare results from the same machine. + +## Extended Tests (Manual) + +Extended tests run on a directory of projects for meaningful accuracy checks. These are not part of automated CI. + +### Directory Structure + +``` +/path/to/projects/ + project1/ + project1.swproj + femur/... + project2/ + project2.swproj + data/... +``` + +Each subdirectory should contain a `.swproj` file and its associated data. 
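
The `.swproj` discovery contract described above is what `run_extended_tests.sh` (later in this diff) implements in bash; as a rough illustration only, and not part of the patch, the same contract could be sketched in Python as:

```python
# Hypothetical Python rendering of the project-discovery rule used by
# run_extended_tests.sh; names here are illustrative, not part of the tests.
from pathlib import Path

def find_extended_projects(base_dir, project="all"):
    """Yield (project_name, swproj_path) for each subdirectory holding a .swproj file."""
    for project_dir in sorted(Path(base_dir).iterdir()):
        if not project_dir.is_dir():
            continue
        if project != "all" and project_dir.name != project:
            continue  # a specific project was requested; skip the rest
        swproj_files = sorted(project_dir.glob("*.swproj"))
        if swproj_files:
            yield project_dir.name, swproj_files[0]  # pick a single .swproj, as the script does

# Usage sketch:
#   for name, swproj in find_extended_projects("/path/to/projects"):
#       run `shapeworks deepssm --name <swproj> --all` from that directory
```

Subdirectories without a `.swproj` file are skipped, matching the warning-and-skip behavior of the shell script.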
+ +### Running Extended Tests + +```bash +# Run all projects with relaxed tolerance +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects + +# Save baselines for exact check +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects save + +# Verify against baselines +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects verify + +# Run specific project only +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects save femur +bash Testing/DeepSSMTests/run_extended_tests.sh /path/to/projects verify femur +``` + +Baseline values are saved to `exact_check_.txt` in each project directory. diff --git a/Testing/DeepSSMTests/deepssm_default.sh b/Testing/DeepSSMTests/deepssm_default.sh new file mode 100755 index 0000000000..cf7ff20318 --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_default.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Test DeepSSM with default settings (no tl_net, no fine_tune) +set -e + +# Prevent PyTorch/OpenMP deadlock on macOS and Windows +export OMP_NUM_THREADS=1 + +echo "=== DeepSSM default test starting ===" +echo "DATA=${DATA}" +echo "OMP_NUM_THREADS=${OMP_NUM_THREADS}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +echo "SCRIPT_DIR=${SCRIPT_DIR}" + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + echo "Unzipping test data..." + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +echo "Changing to ${DATA}/deepssm/projects" +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +echo "Running shapeworks deepssm..." +shapeworks deepssm --name default.swproj --all --aug_processes 1 +echo "shapeworks deepssm completed" + +# Verify results +echo "Verifying results..." +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" +echo "=== DeepSSM default test complete ===" diff --git a/Testing/DeepSSMTests/deepssm_fine_tune.sh b/Testing/DeepSSMTests/deepssm_fine_tune.sh new file mode 100755 index 0000000000..c4450ae70c --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_fine_tune.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Test DeepSSM with fine tuning enabled +set -e + +# Prevent PyTorch/OpenMP deadlock on macOS and Windows +export OMP_NUM_THREADS=1 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name fine_tune.swproj --all --aug_processes 1 + +# Verify results +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net.sh b/Testing/DeepSSMTests/deepssm_tl_net.sh new file mode 100755 index 0000000000..a36369e5a5 --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_tl_net.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Test DeepSSM with TL-DeepSSM network enabled +set -e + +# Prevent PyTorch/OpenMP deadlock on macOS and Windows +export OMP_NUM_THREADS=1 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Unzip test data if not already extracted +if [ ! 
-d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name tl_net.swproj --all --aug_processes 1 + +# Verify results +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh new file mode 100755 index 0000000000..7eeb725606 --- /dev/null +++ b/Testing/DeepSSMTests/deepssm_tl_net_fine_tune.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Test DeepSSM with both TL-DeepSSM and fine tuning enabled +set -e + +# Prevent PyTorch/OpenMP deadlock on macOS and Windows +export OMP_NUM_THREADS=1 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Unzip test data if not already extracted +if [ ! -d "${DATA}/deepssm" ]; then + unzip -q "${DATA}/deepssm_test_data.zip" -d "${DATA}/deepssm" +fi + +cd "${DATA}/deepssm/projects" +rm -rf deepssm groomed *_particles +shapeworks deepssm --name tl_net_fine_tune.swproj --all --aug_processes 1 + +# Verify results +python3 "${SCRIPT_DIR}/verify_deepssm_results.py" "${DATA}/deepssm/projects" diff --git a/Testing/DeepSSMTests/run_exact_check.sh b/Testing/DeepSSMTests/run_exact_check.sh new file mode 100755 index 0000000000..e31cb61697 --- /dev/null +++ b/Testing/DeepSSMTests/run_exact_check.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Run exact check for all DeepSSM test configurations +# Usage: ./run_exact_check.sh save|verify + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${DATA:-$(dirname "$SCRIPT_DIR")/data}" + +if [ "$1" != "save" ] && [ "$1" != "verify" ]; then + echo "Usage: $0 save|verify" + echo " save - Save baseline values (run before refactoring)" + echo " verify - Verify against saved values (run after refactoring)" + exit 1 +fi + +MODE="$1" +CONFIGS="default tl_net fine_tune tl_net_fine_tune" + +# Unzip test data if not already extracted +if [ ! -d "${DATA_DIR}/deepssm" ]; then + unzip -q "${DATA_DIR}/deepssm_test_data.zip" -d "${DATA_DIR}/deepssm" +fi + +cd "${DATA_DIR}/deepssm/projects" + +for config in $CONFIGS; do + echo "========================================" + echo "Running $config..." + echo "========================================" + + rm -rf deepssm groomed *_particles + shapeworks deepssm --name ${config}.swproj --all + + # Run exact check with config-specific file + python3 "${SCRIPT_DIR}/verify_deepssm_results.py" . \ + --exact_check "$MODE" \ + --baseline_file "exact_check_${config}.txt" + + echo "" +done + +echo "========================================" +echo "All configurations: $MODE complete!" 
+echo "========================================" diff --git a/Testing/DeepSSMTests/run_extended_tests.sh b/Testing/DeepSSMTests/run_extended_tests.sh new file mode 100755 index 0000000000..46e96e6e6c --- /dev/null +++ b/Testing/DeepSSMTests/run_extended_tests.sh @@ -0,0 +1,145 @@ +#!/bin/bash +# Run extended DeepSSM tests on a directory of projects +# +# Usage: ./run_extended_tests.sh [save|verify|relaxed] [project] +# +# Arguments: +# base_dir - Directory containing project subdirectories +# mode - save: save baseline values +# verify: verify against saved baselines +# relaxed: run with loose tolerance (default) +# project - Optional: run only this project (default: all) +# +# Examples: +# ./run_extended_tests.sh /path/to/projects # Run all with relaxed check +# ./run_extended_tests.sh /path/to/projects save # Save baselines for all +# ./run_extended_tests.sh /path/to/projects verify # Verify all against baselines +# ./run_extended_tests.sh /path/to/projects save femur # Save baseline for femur only +# +# Directory structure: +# base_dir/ +# project1/ +# *.swproj +# femur/ (or other data) +# project2/ +# *.swproj +# ... + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +usage() { + echo "Usage: $0 [save|verify|relaxed] [project]" + echo "" + echo "Arguments:" + echo " base_dir - Directory containing project subdirectories" + echo " mode - save|verify|relaxed (default: relaxed)" + echo " project - Run only this project (default: all)" + echo "" + echo "Examples:" + echo " $0 /path/to/projects" + echo " $0 /path/to/projects save" + echo " $0 /path/to/projects verify" + echo " $0 /path/to/projects save femur" +} + +if [ $# -lt 1 ] || [ "$1" = "-h" ] || [ "$1" = "--help" ]; then + usage + exit 0 +fi + +BASE_DIR="$1" +MODE="${2:-relaxed}" +PROJECT="${3:-all}" + +if [ ! -d "$BASE_DIR" ]; then + echo "Error: Directory not found: $BASE_DIR" + exit 1 +fi + +if [ "$MODE" != "save" ] && [ "$MODE" != "verify" ] && [ "$MODE" != "relaxed" ]; then + echo "Error: Unknown mode: $MODE" + usage + exit 1 +fi + +run_project() { + local project_dir="$1" + local project_name="$(basename "$project_dir")" + + echo "========================================" + echo "Project: $project_name" + echo "========================================" + + # Find .swproj file + local swproj=$(find "$project_dir" -maxdepth 1 -name "*.swproj" | head -1) + if [ -z "$swproj" ]; then + echo "Warning: No .swproj file found in $project_dir, skipping" + return 0 + fi + + local swproj_name="$(basename "$swproj")" + echo "Using project file: $swproj_name" + + cd "$project_dir" + rm -rf deepssm groomed *_particles + + shapeworks deepssm --name "$swproj_name" --all + + # Verify results + local baseline_file="exact_check_${project_name}.txt" + local verify_args="" + + if [ "$MODE" = "save" ]; then + verify_args="--exact_check save --baseline_file $baseline_file" + elif [ "$MODE" = "verify" ]; then + verify_args="--exact_check verify --baseline_file $baseline_file" + else + verify_args="--expected 10 --tolerance 1.0" + fi + + python3 "${SCRIPT_DIR}/verify_deepssm_results.py" . $verify_args + + echo "" +} + +echo "Extended DeepSSM Tests" +echo "Base directory: ${BASE_DIR}" +echo "Mode: ${MODE}" +echo "" + +# Find all project directories (directories containing .swproj files) +ran_any=false +for project_dir in "$BASE_DIR"/*/; do + if [ ! 
-d "$project_dir" ]; then + continue + fi + + project_name="$(basename "$project_dir")" + + # Skip if specific project requested and this isn't it + if [ "$PROJECT" != "all" ] && [ "$PROJECT" != "$project_name" ]; then + continue + fi + + # Check if this directory has a .swproj file + if ls "$project_dir"/*.swproj 1>/dev/null 2>&1; then + run_project "$project_dir" + ran_any=true + fi +done + +if [ "$ran_any" = false ]; then + if [ "$PROJECT" = "all" ]; then + echo "Error: No projects found in $BASE_DIR" + echo "Each project should be a subdirectory containing a .swproj file." + else + echo "Error: Project not found: $PROJECT" + fi + exit 1 +fi + +echo "========================================" +echo "All projects complete!" +echo "========================================" diff --git a/Testing/DeepSSMTests/verify_deepssm_results.py b/Testing/DeepSSMTests/verify_deepssm_results.py new file mode 100644 index 0000000000..6375b4df27 --- /dev/null +++ b/Testing/DeepSSMTests/verify_deepssm_results.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +""" +Verify DeepSSM test results by checking the mean distance from test_distances.csv. + +Usage: + python verify_deepssm_results.py [--exact_check save|verify] [--expected ] + +The script checks that the mean surface-to-surface distance is reasonable (roughly 10, within tolerance). +For exact refactoring verification, use --exact_check save/verify to save or compare exact values. +""" + +import argparse +import csv +import math +import os +import sys + + +def get_mean_distance(project_dir: str) -> float: + """Read mean distance from test_distances.csv.""" + csv_path = os.path.join(project_dir, "deepssm", "test_distances.csv") + if not os.path.exists(csv_path): + raise FileNotFoundError(f"Results file not found: {csv_path}") + + distances = [] + with open(csv_path, 'r') as f: + reader = csv.DictReader(f) + for row in reader: + distances.append(float(row['Distance'])) + + if not distances: + raise ValueError(f"No distances found in {csv_path}") + + return sum(distances) / len(distances) + + +def main(): + parser = argparse.ArgumentParser(description="Verify DeepSSM test results") + parser.add_argument("project_dir", help="Path to the project directory containing deepssm/ output") + parser.add_argument("--exact_check", choices=["save", "verify"], + help="Save or verify exact values for refactoring verification") + parser.add_argument("--expected", type=float, default=150.0, + help="Expected mean distance for relaxed check (default: 150.0)") + parser.add_argument("--tolerance", type=float, default=1.0, + help="Relative tolerance for relaxed check (default: 1.0 = 100%%)") + parser.add_argument("--baseline_file", type=str, default="exact_check_value.txt", + help="Filename for exact check baseline (default: exact_check_value.txt)") + args = parser.parse_args() + + try: + mean_dist = get_mean_distance(args.project_dir) + print(f"Mean distance: {mean_dist}") + except (FileNotFoundError, ValueError) as e: + print(f"Error: {e}") + sys.exit(1) + + exact_check_file = os.path.join(args.project_dir, args.baseline_file) + + if args.exact_check == "save": + with open(exact_check_file, "w") as f: + f.write(str(mean_dist)) + print(f"Saved exact check value to: {exact_check_file}") + sys.exit(0) + + elif args.exact_check == "verify": + if not os.path.exists(exact_check_file): + print(f"Error: No saved value found at {exact_check_file}") + print("Run with --exact_check save first to create baseline.") + sys.exit(1) + with open(exact_check_file, "r") as f: + expected = 
float(f.read().strip()) + if mean_dist != expected: + print(f"Exact check FAILED: expected {expected}, got {mean_dist}") + sys.exit(1) + print(f"Exact check PASSED: {mean_dist}") + sys.exit(0) + + else: + # Relaxed check for CI/cross-platform + if not math.isclose(mean_dist, args.expected, rel_tol=args.tolerance): + print(f"FAILED: mean distance {mean_dist} not close to {args.expected} (tolerance {args.tolerance})") + sys.exit(1) + print(f"PASSED: mean distance {mean_dist} is within tolerance of {args.expected}") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/Testing/MeshTests/MeshTests.cpp b/Testing/MeshTests/MeshTests.cpp index 787a65ebd7..d7d343e090 100644 --- a/Testing/MeshTests/MeshTests.cpp +++ b/Testing/MeshTests/MeshTests.cpp @@ -633,7 +633,7 @@ TEST(MeshTests, fieldTest2) { Mesh mesh(std::string(TEST_DATA_DIR) + "/la-bin.vtk"); double a = mesh.getFieldValue("scalars", 0); double b = mesh.getFieldValue("scalars", 1000); - double c = mesh.getFieldValue("Normals", 4231); + double c = mesh.getFieldValue("Normals", 12); double d = mesh.getFieldValue("Normals", 5634); ASSERT_TRUE(a == 1); diff --git a/Testing/UseCaseTests/UseCaseTests.cpp b/Testing/UseCaseTests/UseCaseTests.cpp index 4a4270300c..ca8ee646dc 100644 --- a/Testing/UseCaseTests/UseCaseTests.cpp +++ b/Testing/UseCaseTests/UseCaseTests.cpp @@ -13,6 +13,10 @@ void run_test(const std::string& name) { std::remove(outputname.c_str()); // run python + // Remove status files so all steps re-run from scratch. + // Don't use --clean as it deletes pre-downloaded test data. + boost::filesystem::remove_all("Output/" + name + "/status"); + boost::filesystem::remove_all("Output/" + name + "/tiny_test_status"); std::string command = "python RunUseCase.py " + name + " --tiny_test 1>" + outputname + " 2>&1"; // use the below instead of there is some problem in getting the output // std::string command = "python RunUseCase.py " + name + " --tiny_test"; diff --git a/Testing/data/deepssm_test_data.zip b/Testing/data/deepssm_test_data.zip new file mode 100644 index 0000000000..621d3a1556 --- /dev/null +++ b/Testing/data/deepssm_test_data.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99c6a0a3f6bfa91cc00095db64cf9155fe037a9a56afd918aee25b9c3f4770d5 +size 6196905 diff --git a/Testing/data/femur1_to_2_icp.nrrd b/Testing/data/femur1_to_2_icp.nrrd index a351d7392a..0321448e69 100644 --- a/Testing/data/femur1_to_2_icp.nrrd +++ b/Testing/data/femur1_to_2_icp.nrrd @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b42fb8d7611fbef0b0b3c988505059d5ce4c681b3b0e44d7115bbd96eb8ca8d1 -size 748054 +oid sha256:60ecd61c1b944ff72936c31b0c1550e5b8db9c90bf790e4068b06f92b2d643a4 +size 755252 diff --git a/Testing/data/femur2_to_1_icp.nrrd b/Testing/data/femur2_to_1_icp.nrrd index 4bd8b1a5ce..19b85103ed 100644 --- a/Testing/data/femur2_to_1_icp.nrrd +++ b/Testing/data/femur2_to_1_icp.nrrd @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6757b04eea0cde376666826beeec0328dbd71f47e57cfe0bab6551d29e3f9982 -size 758746 +oid sha256:cf585f5a63f2567caef7de3dc73a19000db835d783752d31b9e598a9c1af2692 +size 752034 diff --git a/Testing/data/la-bin.vtk b/Testing/data/la-bin.vtk index 97a4f88312..58efdb6cb2 100644 --- a/Testing/data/la-bin.vtk +++ b/Testing/data/la-bin.vtk @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:78233434f7745536768b8ff689dcd28f6e2bc4cf7d77fe6a197c1e419601bd3a -size 2196097 +oid 
sha256:fc7cfe8d712e7a531ca11c1f9cda50f5b4766f81eb38120d6451623e04b9d20d +size 2872943 diff --git a/Testing/data/reconstruct_mean_surface.vtk b/Testing/data/reconstruct_mean_surface.vtk index 2961524227..98f0d7c80b 100644 --- a/Testing/data/reconstruct_mean_surface.vtk +++ b/Testing/data/reconstruct_mean_surface.vtk @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:50faf83cc9d39e44a1932c3f1de5b8392e30b4715512fbd1492a6eecaa9f43cf -size 432903 +oid sha256:40daf835c37e0a1a0bef01ffe466470d4804f3e146312bf1e39f1532c6219e4b +size 432959 diff --git a/Testing/data/smoothsinc.vtp b/Testing/data/smoothsinc.vtp index b4d36516c9..306e3fd036 100644 --- a/Testing/data/smoothsinc.vtp +++ b/Testing/data/smoothsinc.vtp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:11dff539015f11fc545839df9eb89ba61d3e3ee531638a0de4e1307d0244abd4 -size 7873251 +oid sha256:b16526bf253c9ff696d888d0c14d4917758afdaa1d662f554ba05a9604c27414 +size 7873240 diff --git a/Testing/data/transforms/meshTransformWithImageTransform.vtk b/Testing/data/transforms/meshTransformWithImageTransform.vtk index 8a92a9895c..361e63ae92 100644 --- a/Testing/data/transforms/meshTransformWithImageTransform.vtk +++ b/Testing/data/transforms/meshTransformWithImageTransform.vtk @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2f7144cdb9d60dee1da5ef387d281f12e307d90759d572a5d4db3558517d1d9b -size 29754 +oid sha256:0edb7df0897be426d80259a8b72984c95ce6e960bde8199b64b1db7287dd4e7c +size 43445 diff --git a/Testing/data/transforms/meshTransformWithoutImageTransform.vtk b/Testing/data/transforms/meshTransformWithoutImageTransform.vtk index 987a0c0dd9..384b55c73b 100644 --- a/Testing/data/transforms/meshTransformWithoutImageTransform.vtk +++ b/Testing/data/transforms/meshTransformWithoutImageTransform.vtk @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:93f78c2f7b4a25547df75035fb6e035509b401406b9255430bb5c35deb658ddb -size 29754 +oid sha256:ea2d951a91a557d0d96a78bab3ed606a034f3e641e074ff89468e374e5bc8ad4 +size 43443 diff --git a/devenv.sh b/devenv.sh index f4cc5b9177..30b4a9dc96 100644 --- a/devenv.sh +++ b/devenv.sh @@ -11,7 +11,7 @@ # compiled portion of the Python bindings). # -SW_MAJOR_VERSION=6.6 +SW_MAJOR_VERSION=6.7 (return 0 2>/dev/null) && sourced=1 || sourced=0
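
For reference, the incremental image statistics introduced in `DeepSSMUtils` (`compute_image_stats_incremental`) stream pixel values through Welford's online update rather than stacking every image in memory. The following is a minimal standalone sketch of that update, not part of the patch, using only NumPy to confirm it agrees with the two-pass `np.mean`/`np.std`:

```python
import numpy as np

def welford_stats(values):
    """Single-pass mean and (population) standard deviation over scalars."""
    n, mean, m2 = 0, 0.0, 0.0
    for x in values:
        n += 1
        delta = x - mean          # difference from the old mean
        mean += delta / n         # update running mean
        m2 += delta * (x - mean)  # accumulate using old and new mean
    return mean, np.sqrt(m2 / n)

# Sanity check against NumPy's two-pass statistics
data = np.random.default_rng(0).normal(loc=5.0, scale=2.0, size=10_000)
mean, std = welford_stats(data)
assert np.isclose(mean, data.mean())
assert np.isclose(std, data.std())
```

Dividing `m2` by `n` (rather than `n - 1`) gives the population variance, which is what `np.std` computes by default and what the legacy `get_images` path produced over the stacked image array.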