Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/ci.yaml → .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI
name: Tests

on:
push:
Expand All @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- name: Check out the code
Expand All @@ -27,14 +27,14 @@ jobs:
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH

- name: Create virtual environment
run: uv venv

- name: Install project with dev dependencies
run: |
uv pip install --system .[dev]
run: uv pip install -e .[dev]

- name: Run ruff
run: |
ruff check .
run: uv run ruff check .

- name: Run tests with pytest
run: |
pytest --cov --cov-report=term-missing
run: uv run pytest --cov --cov-report=term-missing
37 changes: 37 additions & 0 deletions .github/workflows/typecheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: Typecheck

on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- name: Check out the code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH

- name: Create virtual environment
run: uv venv

- name: Install project with dev dependencies
run: uv pip install -e .[dev]

- name: Run Pyrefly Type Checker
run: uv run pyrefly check
21 changes: 10 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,21 @@ Homepage = "https://github.com/nirum/jetplot"

[project.optional-dependencies]
dev = [
"pre-commit>=2.21.0",
"pyrefly>=0.14.0",
"pytest>=7.4.4",
"pytest-cov>=4.1.0",
"ruff>=0.11.10",
]
docs = [
"mkdocs>=1.5.3",
"mkdocs-material>=9.2.7",
"mkdocstrings[python]>=0.22.0",
]

[tool.pyrefly]
search_path = [
"src/"
]

[tool.ruff]
lint.extend-ignore = ["E111", "E114", "E501", "F403"]
Expand All @@ -39,13 +48,3 @@ package-dir = {"" = "src"}

[tool.setuptools.dynamic]
version = {attr = "jetplot.__version__"}

[tool.uv]
default-groups = ["dev", "docs"]

[dependency-groups]
docs = [
"mkdocs>=1.5.3",
"mkdocs-material>=9.2.7",
"mkdocstrings[python]>=0.22.0",
]
4 changes: 0 additions & 4 deletions requirements.txt

This file was deleted.

6 changes: 6 additions & 0 deletions src/jetplot/chart_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def get_bounds(axis, ax=None):
return ax.spines[spine_key].get_bounds()
else:
lower, upper = None, None

# pyrefly: ignore # no-matching-overload, bad-argument-type
for tick, label in zip(ticks(), labels()):
if label.get_text() != "":
if lower is None:
Expand Down Expand Up @@ -188,6 +190,8 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs):
lims = ax.get_ylim()
y0 = lims[0] if y0 is None else y0
y1 = lims[1] if y1 is None else y1

# pyrefly: ignore # no-matching-overload, bad-argument-type
dt = np.mean(np.diff(ax.get_yticks())) if dt is None else dt

new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt)
Expand All @@ -205,6 +209,8 @@ def xclamp(x0=None, x1=None, dt=None, **kwargs):
lims = ax.get_xlim()
x0 = lims[0] if x0 is None else x0
x1 = lims[1] if x1 is None else x1

# pyrefly: ignore # no-matching-overload, bad-argument-type
dt = np.mean(np.diff(ax.get_xticks())) if dt is None else dt

new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt)
Expand Down
5 changes: 5 additions & 0 deletions src/jetplot/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,18 @@ def cubehelix(
lambda_ = np.linspace(vmin, vmax, n)
x = lambda_**gamma
phi = 2 * np.pi * (start / 3 + rot * lambda_)

# pyrefly: ignore # bad-argument-type, no-matching-overload
alpha = 0.5 * hue * x * (1.0 - x)
A = np.array([[-0.14861, 1.78277], [-0.29227, -0.90649], [1.97294, 0.0]])
b = np.stack([np.cos(phi), np.sin(phi)])

# pyrefly: ignore # no-matching-overload, bad-argument-type
return Palette((x + alpha * (A @ b)).T)


def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0):
# pyrefly: ignore # missing-attribute
return Palette(cm.__getattribute__(cmap)(np.linspace(vmin, vmax, n)))


Expand Down
8 changes: 8 additions & 0 deletions src/jetplot/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs):
xrng = (-1, 1) if xrng is None else xrng
yrng = xrng if yrng is None else yrng

# pyrefly: ignore # missing-argument, no-matching-overload, bad-argument-type
xs = np.linspace(*xrng, n)

# pyrefly: ignore # missing-argument, no-matching-overload, bad-argument-type
ys = np.linspace(*yrng, n)

xm, ym = np.meshgrid(xs, ys)
Expand Down Expand Up @@ -126,6 +129,8 @@ def cmat(
cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar)

xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows))

# pyrefly: ignore # no-matching-overload, bad-argument-type
for x, y, value in zip(xs.flat, ys.flat, arr.flat):
color = dark_color if (value <= theta) else light_color
annot = f"{{:{fmt}}}".format(value)
Expand All @@ -137,7 +142,10 @@ def cmat(
ax.set_yticks(np.arange(num_rows))
ax.set_yticklabels(labels, fontsize=label_fontsize)

# pyrefly: ignore # bad-argument-type
ax.xaxis.set_minor_locator(FixedLocator(np.arange(num_cols) - 0.5))

# pyrefly: ignore # bad-argument-type
ax.yaxis.set_minor_locator(FixedLocator(np.arange(num_rows) - 0.5))

ax.grid(
Expand Down
5 changes: 5 additions & 0 deletions src/jetplot/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def violinplot(
pc.set_edgecolor(ec)
pc.set_alpha(1.0)

# pyrefly: ignore # no-matching-overload, bad-argument-type
q1, medians, q3 = np.percentile(data, [25, 50, 75], axis=0)

ax.vlines(
Expand All @@ -76,6 +77,7 @@ def violinplot(
if showmeans:
ax.scatter(
xs,
# pyrefly: ignore # no-matching-overload, bad-argument-type
np.mean(data, axis=0),
marker="s",
color=mc,
Expand Down Expand Up @@ -120,6 +122,8 @@ def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs):
bins = 25

# compute the histogram

# pyrefly: ignore # no-matching-overload, unexpected-keyword, bad-argument-type
cnt, xe, ye = np.histogram2d(x, y, bins=bins, normed=True, range=range)

# generate the plot
Expand Down Expand Up @@ -362,6 +366,7 @@ def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs):
mean_y = np.mean(y)

transform = (
# pyrefly: ignore # bad-argument-type
Affine2D().rotate_deg(45).scale(scale_x, scale_y).translate(mean_x, mean_y)
)

Expand Down
5 changes: 5 additions & 0 deletions src/jetplot/signals.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Tools for signal processing."""


from typing import Callable

import numpy as np
from numpy.typing import ArrayLike

# pyrefly: ignore # missing-module-attribute
from scipy.ndimage import gaussian_filter1d

__all__ = ["smooth", "canoncorr", "participation_ratio", "stable_rank", "normalize"]
Expand Down Expand Up @@ -58,6 +61,8 @@ def canoncorr(X: ArrayLike, Y: ArrayLike) -> ArrayLike:
between linear subspaces." Mathematics of computation 27.123 (1973): 579-594.
"""
# Orthogonalize each subspace

# pyrefly: ignore # no-matching-overload, bad-argument-type
qu, qv = np.linalg.qr(X)[0], np.linalg.qr(Y)[0]

# singular values of the inner product between the orthogonalized spaces
Expand Down
2 changes: 2 additions & 0 deletions src/jetplot/style.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Opinionated matplotlib style defaults."""

from functools import partial


from typing import Mapping, Any

from cycler import cycler
Expand Down
4 changes: 4 additions & 0 deletions src/jetplot/timepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,13 @@ def wrapper(*args, **kwargs):
calls.append(tstop - tstart)
return results

# pyrefly: ignore # missing-attribute
wrapper.calls = calls
# pyrefly: ignore # missing-attribute
wrapper.mean = lambda: np.mean(calls)
# pyrefly: ignore # missing-attribute
wrapper.serr = lambda: np.std(calls) / np.sqrt(len(calls))
# pyrefly: ignore # missing-attribute
wrapper.summary = lambda: print(
"Runtimes: {} {} {}".format(
hrtime(wrapper.mean()), "\u00b1", hrtime(wrapper.serr())
Expand Down
1 change: 1 addition & 0 deletions src/jetplot/typing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from typing import Sequence, Union

__all__ = ["Color", "Palette"]
Expand Down
1 change: 1 addition & 0 deletions tests/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ def test_peaks():
n = 256
xm, ym, zm = demo.peaks(n=n)


assert xm.shape == ym.shape == zm.shape == (n, n)
11 changes: 11 additions & 0 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@


def test_stable_rank():
# pyrefly: ignore # no-matching-overload, bad-argument-type
U, _ = np.linalg.qr(np.random.randn(32, 32))

# pyrefly: ignore # no-matching-overload, bad-argument-type
V, _ = np.linalg.qr(np.random.randn(32, 32))
S = np.random.randn(32)

Expand All @@ -21,6 +24,8 @@ def test_stable_rank():
def test_participation_ratio():
def _random_matrix(evals):
dim = evals.size

# pyrefly: ignore # no-matching-overload, bad-argument-type
Q, _ = np.linalg.qr(np.random.randn(dim, dim))
return Q @ np.diag(evals) @ Q.T

Expand Down Expand Up @@ -57,6 +62,8 @@ def test_cca():

X = rs.randn(n, k)
Y = rs.randn(n, k)

# pyrefly: ignore # no-matching-overload, bad-argument-type
Z = X @ np.linalg.qr(rs.randn(k, k))[0]

# Correlation with itself should be all ones.
Expand All @@ -65,7 +72,11 @@ def test_cca():

# Correlation with a different random subspace.
xy = signals.canoncorr(X, Y)

# pyrefly: ignore # bad-argument-type
assert np.all(xy <= 1.0)

# pyrefly: ignore # bad-argument-type
assert np.all(0.0 <= xy)
assert 0 < np.sum(xy) < k

Expand Down
6 changes: 6 additions & 0 deletions tests/test_timepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from jetplot.timepiece import hrtime, profile
import numpy as np

# pyrefly: ignore # import-error
import pytest
import time

Expand Down Expand Up @@ -37,8 +39,12 @@ def test_profile():
for _ in range(K):
wrapper(T)

# pyrefly: ignore # missing-attribute
assert isinstance(wrapper.calls, list)
assert len(wrapper.calls) == K
# pyrefly: ignore # missing-attribute
assert np.allclose(wrapper.mean(), T, atol=0.01)
# pyrefly: ignore # missing-attribute
assert np.allclose(wrapper.serr(), 0.0, atol=0.01)
# pyrefly: ignore # missing-attribute
assert wrapper.summary() is None
Loading
Loading