Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ __pycache__/
*.egg-info/

# tests
coverage.json
htmlcov/
.tox/
.coverage
Expand Down
19 changes: 19 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
default: test

build:
uv build

upload:
uv upload

docs:
mkdocs build --strict

format:
uv run ruff format

typecheck:
uv run pyrefly check

test:
uv run pytest --cov=jetplot --cov-report=term
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "jetplot"
dynamic = ["version"]
requires-python = ">=3.7"
requires-python = ">=3.10"
dependencies = [
"numpy>=1.19",
"scipy",
Expand All @@ -24,10 +24,12 @@ Homepage = "https://github.com/nirum/jetplot"

[project.optional-dependencies]
dev = [
"matplotlib-stubs>=0.1.0",
"pyrefly>=0.14.0",
"pytest>=7.4.4",
"pytest-cov>=4.1.0",
"ruff>=0.11.10",
"scipy-stubs>=1.15.3.0",
]
docs = [
"mkdocs>=1.5.3",
Expand All @@ -36,6 +38,8 @@ docs = [
]

[tool.pyrefly]
python_version = "3.12"
search_path = ["src"]
project_includes = ["src/**"]
project_excludes = ["**/.[!/.]*", "**/*venv/**/*", "build/**/*"]

Expand Down
7 changes: 3 additions & 4 deletions src/jetplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""Jetplot is a set of useful utility functions for scientific python."""

__version__ = "0.6.3"
__version__ = "0.6.4"

from . import colors as c, typing # noqa: F401
from . import colors as c # noqa: F401
from .chart_utils import *
from .colors import *
from .images import *
from .plots import *
from .style import *

from .signals import *
from .style import *
from .timepiece import *
14 changes: 10 additions & 4 deletions src/jetplot/chart_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Plotting utils."""

from collections.abc import Callable
from functools import partial, wraps

import numpy as np
Expand Down Expand Up @@ -117,7 +118,10 @@ def get_bounds(axis, ax=None):
if ax is None:
ax = plt.gca()

axis_map = {

Result = tuple[Callable[[], list[float]], Callable[[], list[str]], Callable[[], tuple[float, float]], str]

axis_map: dict[str, Result] = {
"x": (ax.get_xticks, ax.get_xticklabels, ax.get_xlim, "bottom"),
"y": (ax.get_yticks, ax.get_yticklabels, ax.get_ylim, "left"),
}
Expand All @@ -130,7 +134,7 @@ def get_bounds(axis, ax=None):
else:
lower, upper = None, None

for tick, label in zip(list(ticks()), list(labels())):
for tick, label in zip(list(ticks()), list(labels()), strict=True):
if label.get_text() != "":
if lower is None:
lower = tick
Expand Down Expand Up @@ -190,7 +194,8 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs):
y0 = lims[0] if y0 is None else y0
y1 = lims[1] if y1 is None else y1

dt = float(np.mean(np.diff(ax.get_yticks()))) if dt is None else float(dt)
ticks: list[float] = ax.get_yticks() # pyrefly: ignore
dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt)

new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt)
ax.set_yticks(new_ticks)
Expand All @@ -208,7 +213,8 @@ def xclamp(x0=None, x1=None, dt=None, **kwargs):
x0 = lims[0] if x0 is None else x0
x1 = lims[1] if x1 is None else x1

dt = float(np.mean(np.diff(ax.get_xticks()))) if dt is None else float(dt)
ticks: list[float] = ax.get_xticks() # pyrefly: ignore
dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt)

new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt)
ax.set_xticks(new_ticks)
Expand Down
23 changes: 16 additions & 7 deletions src/jetplot/colors.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""Colorschemes"""

import numpy as np
from matplotlib import cm, pyplot as plt
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, to_hex
from matplotlib.typing import ColorType

from .chart_utils import noticks

__all__ = ["Palette", "cubehelix", "cmap_colors"]


class Palette(list):
"""Color palette."""
class Palette(list[ColorType]):
"""Color palette based on a list of values."""

@property
def hex(self):
Expand All @@ -22,7 +24,7 @@ def cmap(self):

def plot(self, figsize=(5, 1)):
fig, axs = plt.subplots(1, len(self), figsize=figsize)
for c, ax in zip(self, axs):
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
ax.set_facecolor(c)
ax.set_aspect("equal")
noticks(ax=ax)
Expand All @@ -31,18 +33,25 @@ def plot(self, figsize=(5, 1)):


def cubehelix(
n: int, vmin=0.85, vmax=0.15, gamma: float = 1.0, start=0.0, rot=0.4, hue=0.8
n: int,
vmin: float = 0.85,
vmax: float = 0.15,
gamma: float = 1.0,
start: float = 0.0,
rot: float = 0.4,
hue: float = 0.8,
):
"""Cubehelix parameterized colormap."""
lambda_ = np.linspace(vmin, vmax, n)
x = lambda_**gamma
phi = 2 * np.pi * (start / 3 + rot * lambda_)

alpha = 0.5 * hue * x * (1.0 - x)
alpha = 0.5 * hue * x * (1.0 - x) # pyrefly: ignore
A = np.array([[-0.14861, 1.78277], [-0.29227, -0.90649], [1.97294, 0.0]])
b = np.stack([np.cos(phi), np.sin(phi)])

return Palette((x + alpha * (A @ b)).T)
colors: list[tuple[float, float, float]] = (x + alpha * (A @ b)).T.tolist()
return Palette(colors)


def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0):
Expand Down
27 changes: 0 additions & 27 deletions src/jetplot/demo.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/jetplot/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def cmat(

xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy")

for x, y, value in zip(xs.flat, ys.flat, arr.flat):
for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
color = dark_color if (value <= theta) else light_color
annot = f"{{:{fmt}}}".format(value)
ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize)
Expand Down
19 changes: 11 additions & 8 deletions src/jetplot/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import numpy as np
from matplotlib.patches import Ellipse
from matplotlib.transforms import Affine2D
from matplotlib.typing import ColorType
from numpy.typing import NDArray
from scipy.stats import gaussian_kde
from sklearn.covariance import EmpiricalCovariance, MinCovDet

from .chart_utils import figwrapper, nospines, plotwrapper
from .colors import cmap_colors, neutral
from .typing import Color

__all__ = [
"hist",
Expand All @@ -25,7 +26,7 @@

@plotwrapper
def violinplot(
data,
data: NDArray[np.floating],
xs,
fc=neutral[3],
ec=neutral[9],
Expand Down Expand Up @@ -54,6 +55,7 @@ def violinplot(
pc.set_edgecolor(ec)
pc.set_alpha(1.0)

# pyrefly: ignore # no-matching-overload, bad-argument-type
q1, medians, q3 = np.percentile(data, [25, 50, 75], axis=0)

ax.vlines(
Expand All @@ -76,7 +78,8 @@ def violinplot(
if showmeans:
ax.scatter(
xs,
np.mean(data, axis=0, dtype=float),
# pyrefly: ignore # no-matching-overload, bad-argument-type
np.mean(data, axis=0),
marker="s",
color=mc,
s=15,
Expand Down Expand Up @@ -122,7 +125,7 @@ def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs):
bins = 25

# compute the histogram

# pyrefly: ignore # no-matching-overload, bad-argument-type
cnt, xe, ye = np.histogram2d(x, y, bins=bins, range=range_, density=True)

# generate the plot
Expand All @@ -139,10 +142,10 @@ def errorplot(
y,
yerr,
method="patch",
color: Color = "#222222",
color: ColorType = "#222222",
xscale="linear",
fmt="-",
err_color: Color = "#cccccc",
err_color: ColorType = "#cccccc",
alpha_fill=1.0,
clip_on=True,
**kwargs,
Expand Down Expand Up @@ -257,7 +260,7 @@ def lines(x, lines=None, cmap="viridis", **kwargs):
lines = list(lines)

colors = cmap_colors(cmap, len(lines))
for line, color in zip(lines, colors):
for line, color in zip(lines, colors, strict=False):
ax.plot(x, line, color=color)


Expand All @@ -282,7 +285,7 @@ def ridgeline(t, xs, colors, edgecolor="#ffffff", ymax=0.6, **kwargs):
fig = kwargs["fig"]
axs = []

for k, (x, c) in enumerate(zip(xs, colors)):
for k, (x, c) in enumerate(zip(xs, colors, strict=False)):
ax = fig.add_subplot(len(xs), 1, k + 1)
y = gaussian_kde(x).evaluate(t)
ax.fill_between(t, y, color=c, clip_on=False)
Expand Down
38 changes: 25 additions & 13 deletions src/jetplot/signals.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Tools for signal processing."""

from __future__ import annotations

from typing import Callable
from typing import Protocol, SupportsIndex

import numpy as np
from numpy.typing import ArrayLike, NDArray

from scipy.ndimage import gaussian_filter1d # pyrefly: ignore[missing-module-attribute]
from scipy.ndimage import gaussian_filter1d

__all__ = ["smooth", "canoncorr", "participation_ratio", "stable_rank", "normalize"]

FloatArray = NDArray[np.floating]


def smooth(x, sigma=1.0, axis=0):
"""Smooths a 1D signal with a gaussian filter.
Expand All @@ -31,18 +33,21 @@ def stable_rank(X):
return svals_sq.sum() / svals_sq.max()


def participation_ratio(C: NDArray[np.floating]) -> float:
def participation_ratio(C: np.ndarray) -> float:
"""Compute the participation ratio of a square matrix."""

assert C.ndim == 2, "C must be a matrix"
assert C.shape[0] == C.shape[1], "C must be a square matrix"
if C.ndim != 2:
raise ValueError("C must be a matrix")

if C.shape[0] != C.shape[1]:
raise ValueError("C must be a square matrix")

diag_sum = float(np.trace(C))
diag_sq_sum = float(np.trace(C @ C))
return diag_sum**2 / diag_sq_sum


def canoncorr(X: ArrayLike, Y: ArrayLike) -> NDArray[np.floating]:
def canoncorr(X: FloatArray, Y: FloatArray) -> FloatArray:
"""Canonical correlation between two subspaces.

Args:
Expand All @@ -58,22 +63,29 @@ def canoncorr(X: ArrayLike, Y: ArrayLike) -> NDArray[np.floating]:
the principal vectors and angles via the QR decomposition [2]_.

References:
.. [1] Angles between flats. (2016, August 4). In Wikipedia, The Free Encyclopedia.
.. [1] Angles between flats. (2016, August 4). In Wikipedia, The Free Encyclopedia
https://en.wikipedia.org/w/index.php?title=Angles_between_flats
.. [2] Björck, Ȧke, and Gene H. Golub. "Numerical methods for computing angles
between linear subspaces." Mathematics of computation 27.123 (1973): 579-594.
"""
# Orthogonalize each subspace

qu = np.linalg.qr(np.asarray(X))[0]
qv = np.linalg.qr(np.asarray(Y))[0]
# pyrefly: ignore # no-matching-overload, bad-argument-type
Qx, _ = np.linalg.qr(X, mode="reduced")
# pyrefly: ignore # no-matching-overload, bad-argument-type
Qy, _ = np.linalg.qr(Y, mode="reduced")

# singular values of the inner product between the orthogonalized spaces
return np.linalg.svd(qu.T.dot(qv), compute_uv=False, full_matrices=False)
return np.linalg.svd(Qx.T @ Qy, compute_uv=False)


class NormFunction(Protocol):
def __call__(
self, x: ArrayLike, *, axis: SupportsIndex, keepdims: bool
) -> NDArray[np.floating]: ...


def normalize(
X: ArrayLike, axis: int = -1, norm: Callable[[ArrayLike], float] = np.linalg.norm
X: ArrayLike, axis: int = -1, norm: NormFunction = np.linalg.norm
) -> NDArray[np.floating]:
"""Normalizes elements of an array or matrix.

Expand Down
Loading