Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions brainles_preprocessing/defacing/defacer.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,41 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union

from typing import Union, Optional
import numpy as np
from auxiliary.io import read_image, write_image


class Defacer(ABC):
"""
Base class for defacing medical images using brain masks.

Subclasses should implement the `deface` method to generate a defaced image
based on the provided input image and mask.
"""

@abstractmethod
def deface(
def __init__(
self,
input_image_path: Union[str, Path],
mask_image_path: Union[str, Path],
) -> None:
masking_value: Optional[Union[int, float]] = None,
):
"""
Generate a defacing mask provided an input image.
Base class for defacing medical images using brain masks.

Args:
input_image_path (str or Path): Path to the input image (NIfTI format).
mask_image_path (str or Path): Path to the output mask image (NIfTI format).
Subclasses should implement the `deface` method to generate a defaced image
based on the provided input image and mask.
"""
pass
# Here, masking value functions as a global value across all images and modalities
# If no value is passed, the minimum of a given input image is chosen
# TODO: Consider extending this to modality-specific masking values in the future, this should
# probably be implemented as a property of the the specific modality
self.masking_value = masking_value

@abstractmethod
def deface(
self,
input_image_path: Union[str, Path],
mask_image_path: Union[str, Path],
) -> None:
"""
Generate a defacing mask provided an input image.

Args:
input_image_path (str or Path): Path to the input image (NIfTI format).
mask_image_path (str or Path): Path to the output mask image (NIfTI format).
"""
pass

def apply_mask(
self,
Expand Down Expand Up @@ -63,8 +72,17 @@ def apply_mask(
if input_data.shape != mask_data.shape:
raise ValueError("Input image and mask must have the same dimensions.")

# Apply mask (element-wise multiplication)
masked_data = input_data * mask_data
# check whether a global masking value was passed, otherwise choose minimum
if self.masking_value is None:
current_masking_value = np.min(input_data)
else:
current_masking_value = (
np.array(self.masking_value).astype(input_data.dtype).item()
)
# Apply mask (element-wise either input or masking value)
masked_data = np.where(
mask_data.astype(bool), input_data, current_masking_value
)

# Save the defaced image
write_image(
Expand Down
6 changes: 4 additions & 2 deletions brainles_preprocessing/defacing/quickshear/quickshear.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Union
from typing import Union, Optional

import nibabel as nib
from auxiliary.io import read_image, write_image
Expand Down Expand Up @@ -35,15 +35,17 @@ def __init__(
buffer: float = 10.0,
force_atlas_registration: bool = True,
atlas_image_path: Union[str, Path, Atlas] = Atlas.SRI24,
masking_value: Optional[Union[int, float]] = None,
):
"""Initialize Quickshear defacer

Args:
buffer (float, optional): buffer parameter from quickshear algorithm. Defaults to 10.0.
force_atlas_registration (bool, optional): If True, forces atlas registration of the BET mask before defacing to potentially boost quickshear performance. Defaults to True.
atlas_image_path (Union[str, Path, Atlas], optional): Path to the atlas image or an Atlas enum value that will be used for the optional atlas registrations. Defaults to Atlas.SRI24.
masking_value (Optional[Union[int, float]], optional): global value to be inserted in the masked areas. Default is None which leads to the minimum of each respective image.
"""
super().__init__()
super().__init__(masking_value=masking_value)
self.buffer = buffer
self.force_atlas_registration = force_atlas_registration
self.atlas_image_path = atlas_image_path
Expand Down