diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index 6fe157d..2bf1da2 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -1,10 +1,11 @@ """Plotting utils.""" -from collections.abc import Callable from functools import partial, wraps +from typing import Any, Literal import numpy as np from matplotlib import pyplot as plt +from matplotlib.axes import Axes __all__ = [ "noticks", @@ -114,14 +115,25 @@ def nospines(left=False, bottom=False, top=True, right=True, **kwargs): return ax -def get_bounds(axis, ax=None): - if ax is None: - ax = plt.gca() +def get_bounds(axis: Literal["x", "y"], ax: Axes | None = None) -> tuple[float, float]: + """Return the axis spine bounds for the given axis. + Parameters + ---------- + axis : str + Axis to inspect, either ``"x"`` or ``"y"``. + ax : matplotlib.axes.Axes | None, optional + Axes object to inspect. If ``None``, the current axes are used. - Result = tuple[Callable[[], list[float]], Callable[[], list[str]], Callable[[], tuple[float, float]], str] + Returns + ------- + tuple[float, float] + Lower and upper bounds of the axis spine. + """ + if ax is None: + ax = plt.gca() - axis_map: dict[str, Result] = { + axis_map: dict[str, Any] = { "x": (ax.get_xticks, ax.get_xticklabels, ax.get_xlim, "bottom"), "y": (ax.get_yticks, ax.get_yticklabels, ax.get_ylim, "left"), } @@ -187,14 +199,20 @@ def identity(x): @axwrapper -def yclamp(y0=None, y1=None, dt=None, **kwargs): +def yclamp( + y0: float | None = None, + y1: float | None = None, + dt: float | None = None, + **kwargs, +) -> Axes: + """Clamp the y-axis to evenly spaced tick marks.""" ax = kwargs["ax"] lims = ax.get_ylim() y0 = lims[0] if y0 is None else y0 y1 = lims[1] if y1 is None else y1 - ticks: list[float] = ax.get_yticks() # pyrefly: ignore + ticks: list[float] = ax.get_yticks() # pyrefly: ignore dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt) new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt) @@ -206,14 +224,20 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs): @axwrapper -def xclamp(x0=None, x1=None, dt=None, **kwargs): +def xclamp( + x0: float | None = None, + x1: float | None = None, + dt: float | None = None, + **kwargs, +) -> Axes: + """Clamp the x-axis to evenly spaced tick marks.""" ax = kwargs["ax"] lims = ax.get_xlim() x0 = lims[0] if x0 is None else x0 x1 = lims[1] if x1 is None else x1 - ticks: list[float] = ax.get_xticks() # pyrefly: ignore + ticks: list[float] = ax.get_xticks() # pyrefly: ignore dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt) new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt) diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 040e2ad..088d91f 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -3,8 +3,11 @@ import numpy as np from matplotlib import cm from matplotlib import pyplot as plt +from matplotlib.axes import Axes from matplotlib.colors import LinearSegmentedColormap, to_hex +from matplotlib.figure import Figure from matplotlib.typing import ColorType +from numpy.typing import NDArray from .chart_utils import noticks @@ -16,15 +19,18 @@ class Palette(list[ColorType]): @property def hex(self): + """Return the palette colors as hexadecimal strings.""" return Palette([to_hex(rgb) for rgb in self]) @property - def cmap(self): + def cmap(self) -> LinearSegmentedColormap: + """Return the palette as a Matplotlib colormap.""" return LinearSegmentedColormap.from_list("", self) - def plot(self, figsize=(5, 1)): + def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]: + """Visualize the colors in the palette.""" fig, axs = plt.subplots(1, len(self), figsize=figsize) - for c, ax in zip(self, axs, strict=True): # pyrefly: ignore + for c, ax in zip(self, axs, strict=True): # pyrefly: ignore ax.set_facecolor(c) ax.set_aspect("equal") noticks(ax=ax) @@ -54,7 +60,13 @@ def cubehelix( return Palette(colors) -def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0): +def cmap_colors( + cmap: str, + n: int, + vmin: float = 0.0, + vmax: float = 1.0, +) -> Palette: + """Extract ``n`` colors from a Matplotlib colormap.""" return Palette(getattr(cm, cmap)(np.linspace(vmin, vmax, n))) @@ -371,6 +383,8 @@ def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0): def rainbow(k: int) -> Palette: + """Return a palette of distinct colors from several base palettes.""" + _colors = ( blue, orange, diff --git a/src/jetplot/images.py b/src/jetplot/images.py index c464786..afb71a7 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -1,5 +1,6 @@ """Image visualization tools.""" +from collections.abc import Callable from functools import partial import numpy as np @@ -79,7 +80,15 @@ def img( @plotwrapper -def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs): +def fsurface( + func: Callable[..., np.ndarray], + xrng: tuple[float, float] | None = None, + yrng: tuple[float, float] | None = None, + n: int = 100, + nargs: int = 2, + **kwargs, +) -> None: + """Plot a 2‑D function as a filled surface.""" xrng = (-1, 1) if xrng is None else xrng yrng = xrng if yrng is None else yrng @@ -127,7 +136,7 @@ def cmat( xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy") - for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore + for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore color = dark_color if (value <= theta) else light_color annot = f"{{:{fmt}}}".format(value) ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index e57f18b..6b17e77 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -4,6 +4,9 @@ from matplotlib.patches import Ellipse from matplotlib.transforms import Affine2D from matplotlib.typing import ColorType +from matplotlib.figure import Figure +from matplotlib.axes import Axes +from collections.abc import Sequence from numpy.typing import NDArray from scipy.stats import gaussian_kde from sklearn.covariance import EmpiricalCovariance, MinCovDet @@ -35,7 +38,8 @@ def violinplot( showmeans=False, showquartiles=True, **kwargs, -): +) -> Axes: + """Violin plot with customizable elements.""" _ = kwargs.pop("fig") ax = kwargs.pop("ax") @@ -86,6 +90,8 @@ def violinplot( zorder=20, ) + return ax + @plotwrapper def hist(*args, **kwargs): @@ -249,11 +255,17 @@ def bar( @plotwrapper -def lines(x, lines=None, cmap="viridis", **kwargs): +def lines( + x: NDArray[np.floating] | NDArray[np.integer], + lines: list[NDArray[np.floating]] | None = None, + cmap: str = "viridis", + **kwargs, +) -> Axes: + """Plot multiple lines using a color map.""" ax = kwargs["ax"] if lines is None: - lines = list(x) + lines = list(x) # pyrefly: ignore x = np.arange(len(lines[0])) else: @@ -263,6 +275,8 @@ def lines(x, lines=None, cmap="viridis", **kwargs): for line, color in zip(lines, colors, strict=False): ax.plot(x, line, color=color) + return ax + @plotwrapper def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **kwargs): @@ -281,7 +295,15 @@ def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **k @figwrapper -def ridgeline(t, xs, colors, edgecolor="#ffffff", ymax=0.6, **kwargs): +def ridgeline( + t: NDArray[np.floating], + xs: Sequence[NDArray[np.floating]], + colors: Sequence[ColorType], + edgecolor: ColorType = "#ffffff", + ymax: float = 0.6, + **kwargs, +) -> tuple[Figure, list[Axes]]: + """Stacked density plots reminiscent of a ridgeline plot.""" fig = kwargs["fig"] axs = [] diff --git a/src/jetplot/style.py b/src/jetplot/style.py index 4cd5eba..930535a 100644 --- a/src/jetplot/style.py +++ b/src/jetplot/style.py @@ -140,7 +140,7 @@ def set_defaults( def available_fonts() -> list[str]: """Returns a list of available fonts.""" - return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore + return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore def install_fonts(filepath: str): @@ -150,7 +150,7 @@ def install_fonts(filepath: str): font_files = fm.findSystemFonts(fontpaths=[filepath]) for font_file in font_files: - fm.fontManager.addfont(font_file) # pyrefly: ignore + fm.fontManager.addfont(font_file) # pyrefly: ignore new_fonts = set(available_fonts()) - original_fonts if new_fonts: diff --git a/src/jetplot/timepiece.py b/src/jetplot/timepiece.py index dd86946..210d536 100644 --- a/src/jetplot/timepiece.py +++ b/src/jetplot/timepiece.py @@ -9,28 +9,30 @@ class Stopwatch: - def __init__(self, name=""): + """Simple timer utility for measuring code execution time.""" + + def __init__(self, name: str = "") -> None: self.name = name self.start = time.perf_counter() self.absolute_start = time.perf_counter() - def __str__(self): + def __str__(self) -> str: return "\u231a Stopwatch for: " + self.name @property - def elapsed(self): + def elapsed(self) -> float: current = time.perf_counter() elapsed = current - self.start self.start = time.perf_counter() return elapsed - def checkpoint(self, name=""): + def checkpoint(self, name: str = "") -> None: print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip()) - def __enter__(self): + def __enter__(self) -> "Stopwatch": return self - def __exit__(self, *_): + def __exit__(self, *_: object) -> None: total = hrtime(time.perf_counter() - self.absolute_start) print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}")