diff --git a/brainles_preprocessing/defacing/defacer.py b/brainles_preprocessing/defacing/defacer.py index 9c54253..0e0e876 100644 --- a/brainles_preprocessing/defacing/defacer.py +++ b/brainles_preprocessing/defacing/defacer.py @@ -1,32 +1,41 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Union - +from typing import Union, Optional +import numpy as np from auxiliary.io import read_image, write_image class Defacer(ABC): - """ - Base class for defacing medical images using brain masks. - - Subclasses should implement the `deface` method to generate a defaced image - based on the provided input image and mask. - """ - - @abstractmethod - def deface( + def __init__( self, - input_image_path: Union[str, Path], - mask_image_path: Union[str, Path], - ) -> None: + masking_value: Optional[Union[int, float]] = None, + ): """ - Generate a defacing mask provided an input image. + Base class for defacing medical images using brain masks. - Args: - input_image_path (str or Path): Path to the input image (NIfTI format). - mask_image_path (str or Path): Path to the output mask image (NIfTI format). + Subclasses should implement the `deface` method to generate a defaced image + based on the provided input image and mask. """ - pass + # Here, masking value functions as a global value across all images and modalities + # If no value is passed, the minimum of a given input image is chosen + # TODO: Consider extending this to modality-specific masking values in the future, this should + # probably be implemented as a property of the the specific modality + self.masking_value = masking_value + + @abstractmethod + def deface( + self, + input_image_path: Union[str, Path], + mask_image_path: Union[str, Path], + ) -> None: + """ + Generate a defacing mask provided an input image. + + Args: + input_image_path (str or Path): Path to the input image (NIfTI format). + mask_image_path (str or Path): Path to the output mask image (NIfTI format). + """ + pass def apply_mask( self, @@ -63,8 +72,17 @@ def apply_mask( if input_data.shape != mask_data.shape: raise ValueError("Input image and mask must have the same dimensions.") - # Apply mask (element-wise multiplication) - masked_data = input_data * mask_data + # check whether a global masking value was passed, otherwise choose minimum + if self.masking_value is None: + current_masking_value = np.min(input_data) + else: + current_masking_value = ( + np.array(self.masking_value).astype(input_data.dtype).item() + ) + # Apply mask (element-wise either input or masking value) + masked_data = np.where( + mask_data.astype(bool), input_data, current_masking_value + ) # Save the defaced image write_image( diff --git a/brainles_preprocessing/defacing/quickshear/quickshear.py b/brainles_preprocessing/defacing/quickshear/quickshear.py index 754b593..a33b82a 100644 --- a/brainles_preprocessing/defacing/quickshear/quickshear.py +++ b/brainles_preprocessing/defacing/quickshear/quickshear.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union +from typing import Union, Optional import nibabel as nib from auxiliary.io import read_image, write_image @@ -35,6 +35,7 @@ def __init__( buffer: float = 10.0, force_atlas_registration: bool = True, atlas_image_path: Union[str, Path, Atlas] = Atlas.SRI24, + masking_value: Optional[Union[int, float]] = None, ): """Initialize Quickshear defacer @@ -42,8 +43,9 @@ def __init__( buffer (float, optional): buffer parameter from quickshear algorithm. Defaults to 10.0. force_atlas_registration (bool, optional): If True, forces atlas registration of the BET mask before defacing to potentially boost quickshear performance. Defaults to True. atlas_image_path (Union[str, Path, Atlas], optional): Path to the atlas image or an Atlas enum value that will be used for the optional atlas registrations. Defaults to Atlas.SRI24. + masking_value (Optional[Union[int, float]], optional): global value to be inserted in the masked areas. Default is None which leads to the minimum of each respective image. """ - super().__init__() + super().__init__(masking_value=masking_value) self.buffer = buffer self.force_atlas_registration = force_atlas_registration self.atlas_image_path = atlas_image_path