diff --git a/dlclive/check_install/check_install.py b/dlclive/check_install/check_install.py index 2bc4e65..bdbd0f8 100755 --- a/dlclive/check_install/check_install.py +++ b/dlclive/check_install/check_install.py @@ -9,13 +9,15 @@ import urllib.request import argparse import shutil -import sys + import urllib.request import warnings from pathlib import Path +from dlclive.utils import download_file from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model +import dlclive from dlclive.benchmark import benchmark_videos from dlclive.engine import Engine @@ -23,14 +25,6 @@ SNAPSHOT_NAME = "snapshot-700000.pb" -def urllib_pbar(count, blockSize, totalSize): - percent = int(count * blockSize * 100 / totalSize) - outstr = f"{round(percent)}%" - sys.stdout.write(outstr) - sys.stdout.write("\b" * len(outstr)) - sys.stdout.flush() - - def main(): parser = argparse.ArgumentParser( description="Test DLC-Live installation by downloading and evaluating a demo DLC project!" @@ -46,21 +40,25 @@ def main(): if not display: print("Running without displaying video") - # make temporary directory in $current + # make temporary directory print("\nCreating temporary directory...\n") - tmp_dir = Path().home() / "dlc-live-tmp" + tmp_dir = Path(dlclive.__file__).parent / "check_install" / "dlc-live-tmp" tmp_dir.mkdir(mode=0o775, exist_ok=True) video_file = str(tmp_dir / "dog_clip.avi") model_dir = tmp_dir / "DLC_Dog_resnet_50_iteration-0_shuffle-0" # download dog test video from github: - if not os.path.exists(video_file): + # Use raw.githubusercontent.com for direct file access + if not Path(video_file).exists(): print(f"Downloading Video to {video_file}") - url_link = "https://github.com/DeepLabCut/DeepLabCut-live/blob/main/check_install/dog_clip.avi?raw=True" - urllib.request.urlretrieve(url_link, video_file, reporthook=urllib_pbar) + url_link = "https://raw.githubusercontent.com/DeepLabCut/DeepLabCut-live/master/check_install/dog_clip.avi" + try: + download_file(url_link, video_file) + except (urllib.error.URLError, IOError) as e: + raise RuntimeError(f"Failed to download video file: {e}") from e else: - print(f"Video already exists at {video_file}") + print(f"Video file already exists at {video_file}, skipping download.") # download model from the DeepLabCut Model Zoo if Path(model_dir / SNAPSHOT_NAME).exists(): diff --git a/dlclive/pose_estimation_pytorch/config.py b/dlclive/pose_estimation_pytorch/config.py new file mode 100644 index 0000000..3546de3 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/config.py @@ -0,0 +1,201 @@ +import logging +from dataclasses import dataclass, fields, asdict + +from collections import OrderedDict +from pathlib import Path + +import torch + +def _parse_dataclass_from_dict(cls: type[dataclass], cfg: dict) -> dataclass: + """Parses a dictionary into a dataclass. + + Args: + cls: The dataclass to parse into. + cfg: The dictionary to parse from. + + Returns: + The dataclass parsed from the dictionary. + """ + # If the config is already a dataclass, return it (it was already parsed before) + if isinstance(cfg, cls): + return cfg + + # Otherwise, parse the dictionary into the dataclass + field_names = {f.name for f in fields(cls)} + known = {k: v for k, v in cfg.items() if k in field_names} + extras = {k: v for k, v in cfg.items() if k not in field_names} + obj = cls(**known) + obj.additional_kwargs = extras + return obj + + +@dataclass +class SkipFrames: + """Configuration for skip frames. + + Skip-frames can be used for top-down models running with a detector. If skip > 0, + then the detector will only be run every `skip` frames. Between frames where the + detector is run, bounding boxes will be computed from the pose estimated in the + previous frame. + + Every `N` frames, the detector will be run to detect bounding boxes for individuals. + In the "skipped" frames between the frames where the object detector is run, the + bounding boxes will be computed from the poses estimated in the previous frame (with + some margin added around the poses). + + Attributes: + skip: The number of frames to skip between each run of the detector. + margin: The margin (in pixels) to use when generating bboxes + """ + + skip: int + margin: int + _age: int = 0 + _detections: dict[str, torch.Tensor] | None = None + + def get_detections(self) -> dict[str, torch.Tensor] | None: + return self._detections + + def update(self, pose: torch.Tensor, w: int, h: int) -> None: + """Generates bounding boxes from a pose. + + Args: + pose: The pose from which to generate bounding boxes. + w: The width of the image. + h: The height of the image. + + Returns: + A dictionary containing the bounding boxes and scores for each detection. + """ + if self._age >= self.skip: + self._age = 0 + self._detections = None + return + + num_det, num_kpts = pose.shape[:2] + size = max(w, h) + + bboxes = torch.zeros((num_det, 4)) + bboxes[:, :2] = ( + torch.min(torch.nan_to_num(pose, size)[..., :2], dim=1)[0] - self.margin + ) + bboxes[:, 2:4] = ( + torch.max(torch.nan_to_num(pose, 0)[..., :2], dim=1)[0] + self.margin + ) + bboxes = torch.clip(bboxes, min=torch.zeros(4), max=torch.tensor([w, h, w, h])) + self._detections = dict(boxes=bboxes, scores=torch.ones(num_det)) + self._age += 1 + + +@dataclass +class TopDownConfig: + """Configuration for top-down models. + + Attributes: + bbox_cutoff: The minimum score required for a bounding box to be considered. + max_detections: The maximum number of detections to keep in a frame. If None, + the `max_detections` will be set to the number of individuals in the model + configuration file when `read_config` is called. + skip_frames: If defined, the detector will only be run every + `skip_frames.skip` frames. + """ + + bbox_cutoff: float = 0.6 + max_detections: int | None = 30 + crop_size: tuple[int, int] = (256, 256) + skip_frames: SkipFrames | None = None + + def read_config(self, model_cfg: dict) -> None: + crop = model_cfg.get("data", {}).get("inference", {}).get("top_down_crop") + if crop is not None: + self.crop_size = (crop["width"], crop["height"]) + + if self.max_detections is None: + individuals = model_cfg.get("metadata", {}).get("individuals", []) + self.max_detections = len(individuals) + + +@dataclass +class DataConfig: + inference: dict + bbox_margin: int | None = None + colormode: str | None = None + train: dict | None = None + + @classmethod + def from_dict(cls, cfg: dict) -> "DataConfig": + return _parse_dataclass_from_dict(cls, cfg) + +@dataclass +class DetectorConfig: + data: DataConfig | dict + model: dict + runner: str | None = None + train_settings: dict | None = None + + @classmethod + def from_dict(cls, cfg: dict) -> "DetectorConfig": + return _parse_dataclass_from_dict(cls, cfg) + +@dataclass +class BaseConfig: + """Pytorch model configuration (DeepLabCut format).""" + model: dict + net_type: str + metadata: dict + data: DataConfig + method: str + detector: DetectorConfig | None = None + train_settings: dict | None = None + inference: dict | None = None + + def __post_init__(self) -> None: + self.data = DataConfig.from_dict(self.data) + if self.detector is not None: + self.detector = DetectorConfig.from_dict(self.detector) + + @classmethod + def from_dict(cls, cfg: dict) -> "BaseConfig": + return _parse_dataclass_from_dict(cls, cfg) + + def to_dict(self) -> dict: + return asdict(self) + +StateDict=OrderedDict[str, torch.Tensor] + +def load_exported_model( + path: str | Path, + map_location: str = "cpu", + weights_only: bool = True, +) -> tuple[BaseConfig, StateDict, StateDict | None]: + """ + Loads a DeepLabCut exported model from a file. + + The exported model is a dictionary containing the following keys: + - config: The base configuration of the model. + - pose: The state dict of the model. + - detector: The state dict of the detector. + + Args: + path: The path to the exported model. + map_location: The device to map the model to. + weights_only: Whether to load only the weights of the model. + + Returns: + A tuple containing the base configuration and the state dicts of the + pose and detector models. (The detector state dict is optional.) + + Raises: + ValueError: If the exported model file does not contain a 'config' and 'pose' key. + FileNotFoundError: If the exported model file does not exist. + """ + raw_data = torch.load(path, map_location=map_location, weights_only=weights_only) + if "config" not in raw_data or "pose" not in raw_data: + raise ValueError( + f"Invalid exported model file: {path}. The exported model must contain " + "a 'config' and 'pose' key. For more information on how to export a model, " + "visit https://deeplabcut.github.io/ and search for `export_model`." + ) + + base_config = BaseConfig.from_dict(raw_data["config"]) + return base_config, raw_data["pose"], raw_data["detector"] \ No newline at end of file diff --git a/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py b/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py index e4ec134..022afdf 100644 --- a/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py +++ b/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py @@ -42,21 +42,33 @@ class SimCCPredictor(BasePredictor): def __init__( self, simcc_split_ratio: float = 2.0, - apply_softmax: bool = False, normalize_outputs: bool = False, + apply_softmax: bool = True, + sigma: float | int | tuple[float, ...] = 6.0, + decode_beta: float = 150.0, ) -> None: super().__init__() self.simcc_split_ratio = simcc_split_ratio - self.apply_softmax = apply_softmax self.normalize_outputs = normalize_outputs + self.apply_softmax = apply_softmax + + if isinstance(sigma, (float, int)): + self.sigma = np.array([sigma, sigma]) + else: + self.sigma = np.array(sigma) + self.decode_beta = decode_beta def forward( self, stride: float, outputs: dict[str, torch.Tensor] ) -> dict[str, torch.Tensor]: x, y = outputs["x"].detach(), outputs["y"].detach() + if self.normalize_outputs: x = get_simcc_normalized(x) y = get_simcc_normalized(y) + else: + x = x * (self.sigma[0] * self.decode_beta) + y = y * (self.sigma[1] * self.decode_beta) keypoints, scores = get_simcc_maximum( x.cpu().numpy(), y.cpu().numpy(), self.apply_softmax diff --git a/dlclive/pose_estimation_pytorch/runner.py b/dlclive/pose_estimation_pytorch/runner.py index 11188f1..95fab39 100644 --- a/dlclive/pose_estimation_pytorch/runner.py +++ b/dlclive/pose_estimation_pytorch/runner.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Literal +import warnings import numpy as np import torch @@ -23,92 +24,12 @@ import dlclive.pose_estimation_pytorch.dynamic_cropping as dynamic_cropping from dlclive.core.runner import BaseRunner from dlclive.pose_estimation_pytorch.data.image import AutoPadToDivisor - - -@dataclass -class SkipFrames: - """Configuration for skip frames. - - Skip-frames can be used for top-down models running with a detector. If skip > 0, - then the detector will only be run every `skip` frames. Between frames where the - detector is run, bounding boxes will be computed from the pose estimated in the - previous frame. - - Every `N` frames, the detector will be run to detect bounding boxes for individuals. - In the "skipped" frames between the frames where the object detector is run, the - bounding boxes will be computed from the poses estimated in the previous frame (with - some margin added around the poses). - - Attributes: - skip: The number of frames to skip between each run of the detector. - margin: The margin (in pixels) to use when generating bboxes - """ - - skip: int - margin: int - _age: int = 0 - _detections: dict[str, torch.Tensor] | None = None - - def get_detections(self) -> dict[str, torch.Tensor] | None: - return self._detections - - def update(self, pose: torch.Tensor, w: int, h: int) -> None: - """Generates bounding boxes from a pose. - - Args: - pose: The pose from which to generate bounding boxes. - w: The width of the image. - h: The height of the image. - - Returns: - A dictionary containing the bounding boxes and scores for each detection. - """ - if self._age >= self.skip: - self._age = 0 - self._detections = None - return - - num_det, num_kpts = pose.shape[:2] - size = max(w, h) - - bboxes = torch.zeros((num_det, 4)) - bboxes[:, :2] = ( - torch.min(torch.nan_to_num(pose, size)[..., :2], dim=1)[0] - self.margin - ) - bboxes[:, 2:4] = ( - torch.max(torch.nan_to_num(pose, 0)[..., :2], dim=1)[0] + self.margin - ) - bboxes = torch.clip(bboxes, min=torch.zeros(4), max=torch.tensor([w, h, w, h])) - self._detections = dict(boxes=bboxes, scores=torch.ones(num_det)) - self._age += 1 - - -@dataclass -class TopDownConfig: - """Configuration for top-down models. - - Attributes: - bbox_cutoff: The minimum score required for a bounding box to be considered. - max_detections: The maximum number of detections to keep in a frame. If None, - the `max_detections` will be set to the number of individuals in the model - configuration file when `read_config` is called. - skip_frames: If defined, the detector will only be run every - `skip_frames.skip` frames. - """ - - bbox_cutoff: float = 0.6 - max_detections: int | None = 30 - crop_size: tuple[int, int] = (256, 256) - skip_frames: SkipFrames | None = None - - def read_config(self, model_cfg: dict) -> None: - crop = model_cfg.get("data", {}).get("inference", {}).get("top_down_crop") - if crop is not None: - self.crop_size = (crop["width"], crop["height"]) - - if self.max_detections is None: - individuals = model_cfg.get("metadata", {}).get("individuals", []) - self.max_detections = len(individuals) +from dlclive.pose_estimation_pytorch.config import ( + load_exported_model, + SkipFrames, + TopDownConfig, + BaseConfig +) class PyTorchRunner(BaseRunner): @@ -131,15 +52,25 @@ def __init__( path: str | Path, device: str = "auto", precision: Literal["FP16", "FP32"] = "FP32", - single_animal: bool = True, + single_animal: bool | None = None, dynamic: dict | dynamic_cropping.DynamicCropper | None = None, top_down_config: dict | TopDownConfig | None = None, ) -> None: super().__init__(path) self.device = _parse_device(device) self.precision = precision + if single_animal is not None: + warnings.warn( + "The `single_animal` parameter is deprecated and will be removed " + "in a future version. The number of individuals will be automaticalliy inferred " + "from the model configuration. Remove argument `single_animal` or set " + "`single_animal=None` to accept the inferred value and silence this warning.", + DeprecationWarning, + stacklevel=2, + ) self.single_animal = single_animal - + self.n_individuals = None + self.n_bodyparts = None self.cfg = None self.detector = None self.model = None @@ -191,9 +122,14 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray: frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections) if len(frame_batch) == 0: - offsets_and_scales = [(0, 0), 1] - else: - tensor = frame_batch # still CHW, batched + zero_pose = ( + np.zeros((self.n_bodyparts, 3)) + if self.n_individuals < 2 else + np.zeros((self.n_individuals, self.n_bodyparts, 3)) + ) + return zero_pose + + tensor = frame_batch # still CHW, batched if self.dynamic is not None: tensor = self.dynamic.crop(tensor) @@ -257,11 +193,21 @@ def init_inference(self, frame: np.ndarray, **kwargs) -> np.ndarray: def load_model(self) -> None: """Loads the model from the provided path.""" - raw_data = torch.load(self.path, map_location="cpu", weights_only=True) + # Load the model from the provided path and get the base config + + # state dictionaries. Validation takes place in `runner_config.py`. + base_cfg, pose_state_dict, detector_state_dict = load_exported_model(self.path) + self.cfg = base_cfg.to_dict() + + # Infer single animal mode and n_bodyparts from model configuration + individuals = self.cfg.get("metadata", {}).get("individuals", ['idv1']) + bodyparts = self.cfg.get("metadata", {}).get("bodyparts", []) + self.n_individuals = len(individuals) + self.n_bodyparts = len(bodyparts) + if self.single_animal is None: + self.single_animal = self.n_individuals == 1 - self.cfg = raw_data["config"] self.model = models.PoseModel.build(self.cfg["model"]) - self.model.load_state_dict(raw_data["pose"]) + self.model.load_state_dict(pose_state_dict) self.model = self.model.to(self.device) self.model.eval() @@ -269,10 +215,10 @@ def load_model(self) -> None: self.model = self.model.half() self.detector = None - if self.dynamic is None and raw_data.get("detector") is not None: + if self.dynamic is None and detector_state_dict is not None: self.detector = models.DETECTORS.build(self.cfg["detector"]["model"]) self.detector.to(self.device) - self.detector.load_state_dict(raw_data["detector"]) + self.detector.load_state_dict(detector_state_dict) self.detector.eval() if self.precision == "FP16": self.detector = self.detector.half() diff --git a/dlclive/utils.py b/dlclive/utils.py index 94a3dba..a378bb6 100644 --- a/dlclive/utils.py +++ b/dlclive/utils.py @@ -6,8 +6,10 @@ """ import warnings - +from pathlib import Path import numpy as np +import urllib.request +import urllib.error from dlclive.exceptions import DLCLiveWarning @@ -31,6 +33,13 @@ DLCLiveWarning, ) +try: + from tqdm import tqdm + + has_tqdm = True +except ImportError: + has_tqdm = False + def convert_to_ubyte(frame: np.ndarray) -> np.ndarray: """Converts an image to unsigned 8-bit integer numpy array. @@ -203,3 +212,68 @@ def decode_fourcc(cc): decoded = "" return decoded + + +def download_file(url: str, filepath: str, chunk_size: int = 8192) -> None: + """ + Download a file from a URL with progress bar and error handling. + + Args: + url: URL to download from + filepath: Local path to save the file + chunk_size: Size of chunks to read (default: 8192 bytes) + + Raises: + urllib.error.URLError: If the download fails + IOError: If the file cannot be written + """ + filepath = Path(filepath) + + # Check if file already exists + if filepath.exists(): + print(f"File already exists at {filepath}, skipping download.") + return + + # Ensure parent directory exists + filepath.parent.mkdir(parents=True, exist_ok=True) + + try: + # Open the URL + with urllib.request.urlopen(url) as response: + # Get file size if available + total_size = int(response.headers.get('Content-Length', 0)) + + # Create progress bar if tqdm is available + if has_tqdm and total_size > 0: + pbar = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") + else: + pbar = None + print("Downloading...") + + # Download and write file + downloaded = 0 + with open(filepath, 'wb') as f: + while True: + chunk = response.read(chunk_size) + if not chunk: + break + f.write(chunk) + downloaded += len(chunk) + if pbar: + pbar.update(len(chunk)) + + if pbar: + pbar.close() + + # Verify file was written + if not filepath.exists() or filepath.stat().st_size == 0: + raise IOError(f"Downloaded file is empty or was not written to {filepath}") + + print(f"Successfully downloaded to {filepath}") + + except urllib.error.HTTPError as e: + raise urllib.error.URLError(f"HTTP error {e.code}: {e.reason} when downloading from {url}") + except urllib.error.URLError as e: + raise urllib.error.URLError(f"Failed to download from {url}: {e.reason}") + except IOError as e: + raise IOError(f"Failed to write file to {filepath}: {e}") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7597994..876ffe4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,66 +1,86 @@ -[tool.poetry] +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] name = "deeplabcut-live" version = "3.0.0a0" description = "Class to load exported DeepLabCut networks and perform pose estimation on single frames (from a camera feed)" -authors = ["A. & M. Mathis Labs "] -license = "AGPL-3.0-or-later" readme = "README.md" -homepage = "https://github.com/DeepLabCut/DeepLabCut-live" -repository = "https://github.com/DeepLabCut/DeepLabCut-live" +requires-python = ">=3.10,<3.12" +license = { text = "GNU Affero General Public License v3 or later (AGPLv3+)" } +authors = [ + { name = "A. & M. Mathis Labs", email = "admin@deeplabcut.org" } +] + +keywords = ["deeplabcut", "pose-estimation", "real-time", "deep-learning"] + classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", - "Operating System :: OS Independent" + "Operating System :: OS Independent", +] + +dependencies = [ + "numpy>=1.20,<2", + "ruamel.yaml>=0.17.20", + "colorcet>=3.0.0", + "einops>=0.6.1", + "Pillow>=8.0.0", + "py-cpuinfo>=5.0.0", + "tqdm>=4.62.3", + "pandas>=1.0.1,!=1.5.0", + "tables>=3.8", + "opencv-python-headless>=4.5", + "dlclibrary>=0.0.6", + "scipy>=1.9", +] + +[project.optional-dependencies] +pytorch = [ + "timm>=1.0.7", + "torch>=2.0.0", + "torchvision>=0.15", ] + +tf = [ + "tensorflow>=2.7.0,<2.12; platform_system != 'Darwin' and python_version < '3.11'", + "tensorflow>=2.12.0,<=2.12; platform_system != 'Darwin' and python_version >= '3.11'", + "tensorflow-macos>=2.7.0,<2.12; platform_system == 'Darwin' and python_version < '3.11'", + "tensorflow-macos>=2.12.0,<=2.12; platform_system == 'Darwin' and python_version >= '3.11'", + "tensorflow-io-gcs-filesystem==0.27; platform_system == 'Windows' and python_version < '3.11'", + "tensorflow-io-gcs-filesystem; platform_system != 'Windows'", +] + +[dependency-groups] +dev = [ + "pytest", + "black", + "ruff", +] + +# Keep compatibility with Poetry +# (without this section, Poetry assumes the wrong root directory of the project) +[tool.poetry] packages = [ { include = "dlclive" } ] -include = ["dlclive/check_install/*"] -[tool.poetry.scripts] +[project.scripts] dlc-live-test = "dlclive.check_install.check_install:main" dlc-live-benchmark = "dlclive.benchmark:main" -[tool.poetry.dependencies] -python = ">=3.10,<3.12" -numpy = ">=1.26,<2.0" -"ruamel.yaml" = "^0.17.20" -colorcet = "^3.0.0" -einops = ">=0.6.1" -Pillow = ">=8.0.0" -opencv-python-headless = ">=4.5.0,<5.0.0" -py-cpuinfo = ">=5.0.0" -tqdm = "^4.62.3" -pandas = ">=1.0.1,!=1.5.0" -tables = "^3.8" -pytest = "^8.0" -dlclibrary = ">=0.0.6" +[project.urls] +Homepage = "https://github.com/DeepLabCut/DeepLabCut-live" +Repository = "https://github.com/DeepLabCut/DeepLabCut-live" -# PyTorch models -scipy = ">=1.9" -timm = { version = ">=1.0.7", optional = true } -torch = { version = ">=2.0.0", optional = true } -torchvision = { version = ">=0.15", optional = true } -# TensorFlow models -tensorflow = [ - { version = "^2.7.0,<=2.10", optional = true, platform = "win32" }, - { version = "^2.7.0,<=2.12", optional = true, platform = "linux" }, -] -tensorflow-macos = { version = "^2.7.0,<=2.12", optional = true, markers = "sys_platform == 'darwin'" } -tensorflow-io-gcs-filesystem = [ - { version = "==0.27", optional = true, platform = "win32", python = ">=3.10,<3.11" }, - { version = "*", optional = true, platform = "linux" }, - { version = "*", optional = true, markers = "sys_platform == 'darwin'" } -] +[tool.setuptools] +include-package-data = true -[tool.poetry.extras] -tf = [ "tensorflow", "tensorflow-macos", "tensorflow-io-gcs-filesystem"] -pytorch = ["scipy", "timm", "torch", "torchvision"] +[tool.setuptools.packages.find] +include = ["dlclive*"] -[tool.poetry.group.dev.dependencies] - -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" +[tool.setuptools.package-data] +dlclive = ["check_install/*"] \ No newline at end of file diff --git a/reinstall.sh b/reinstall.sh deleted file mode 100755 index 06d4954..0000000 --- a/reinstall.sh +++ /dev/null @@ -1,5 +0,0 @@ -poetry shell # activating current environment -poetry install # creating and installing current project -poetry build # creating the tarball -poetry publish # uploading to pypi -#poetry publish --username= --password= \ No newline at end of file