Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ docs = [
]

[tool.pyrefly]
search_path = [
"src/"
]
project_includes = ["src/**"]
project_excludes = ["**/.[!/.]*", "**/*venv/**/*", "build/**/*"]

[tool.ruff]
lint.extend-ignore = ["E111", "E114", "E501", "F403"]
Expand Down
9 changes: 3 additions & 6 deletions src/jetplot/chart_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def get_bounds(axis, ax=None):
else:
lower, upper = None, None

# pyrefly: ignore # no-matching-overload, bad-argument-type
for tick, label in zip(ticks(), labels()):
for tick, label in zip(list(ticks()), list(labels())):
if label.get_text() != "":
if lower is None:
lower = tick
Expand Down Expand Up @@ -191,8 +190,7 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs):
y0 = lims[0] if y0 is None else y0
y1 = lims[1] if y1 is None else y1

# pyrefly: ignore # no-matching-overload, bad-argument-type
dt = np.mean(np.diff(ax.get_yticks())) if dt is None else dt
dt = float(np.mean(np.diff(ax.get_yticks()))) if dt is None else float(dt)

new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt)
ax.set_yticks(new_ticks)
Expand All @@ -210,8 +208,7 @@ def xclamp(x0=None, x1=None, dt=None, **kwargs):
x0 = lims[0] if x0 is None else x0
x1 = lims[1] if x1 is None else x1

# pyrefly: ignore # no-matching-overload, bad-argument-type
dt = np.mean(np.diff(ax.get_xticks())) if dt is None else dt
dt = float(np.mean(np.diff(ax.get_xticks()))) if dt is None else float(dt)

new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt)
ax.set_xticks(new_ticks)
Expand Down
5 changes: 1 addition & 4 deletions src/jetplot/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,15 @@ def cubehelix(
x = lambda_**gamma
phi = 2 * np.pi * (start / 3 + rot * lambda_)

# pyrefly: ignore # bad-argument-type, no-matching-overload
alpha = 0.5 * hue * x * (1.0 - x)
A = np.array([[-0.14861, 1.78277], [-0.29227, -0.90649], [1.97294, 0.0]])
b = np.stack([np.cos(phi), np.sin(phi)])

# pyrefly: ignore # no-matching-overload, bad-argument-type
return Palette((x + alpha * (A @ b)).T)


def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0):
# pyrefly: ignore # missing-attribute
return Palette(cm.__getattribute__(cmap)(np.linspace(vmin, vmax, n)))
return Palette(getattr(cm, cmap)(np.linspace(vmin, vmax, n)))


black = "#000000"
Expand Down
23 changes: 18 additions & 5 deletions src/jetplot/demo.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
import numpy as np
from numpy.typing import NDArray


def peaks(n=256):
"""2D peaks function."""
pts = np.linspace(-3, 3, n)
xm, ym = np.meshgrid(pts, pts)
def peaks(n: int = 256) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
"""Generate the MATLAB ``peaks`` surface.

Parameters
----------
n:
Number of points per axis.

Returns
-------
tuple of ``ndarray``
``x`` grid, ``y`` grid and the function value ``z``.
"""

pts = np.linspace(-3.0, 3.0, n)
xm, ym = np.meshgrid(pts, pts, indexing="xy")
zm = (
3 * (1 - xm) ** 2 * np.exp(-(xm**2))
- (ym + 1) ** 2
- 10 * (0.2 * xm - xm**3 - ym**5) * np.exp(-(xm**2) - (ym**2))
- (1 / 3) * np.exp(-((xm + 1) ** 2) - ym**2)
- (1.0 / 3.0) * np.exp(-((xm + 1) ** 2) - ym**2)
)
return xm, ym, zm
18 changes: 6 additions & 12 deletions src/jetplot/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,13 @@ def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs):
xrng = (-1, 1) if xrng is None else xrng
yrng = xrng if yrng is None else yrng

# pyrefly: ignore # missing-argument, no-matching-overload, bad-argument-type
xs = np.linspace(*xrng, n)

# pyrefly: ignore # missing-argument, no-matching-overload, bad-argument-type
ys = np.linspace(*yrng, n)
xs = np.linspace(xrng[0], xrng[1], n)
ys = np.linspace(yrng[0], yrng[1], n)

xm, ym = np.meshgrid(xs, ys)

if nargs == 1:
zz = np.vstack((xm.ravel(), ym.ravel()))
zz = np.vstack([xm.ravel(), ym.ravel()])
args = (zz,)
elif nargs == 2:
args = (xm.ravel(), ym.ravel())
Expand Down Expand Up @@ -128,9 +125,8 @@ def cmat(
ax = kwargs.pop("ax")
cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar)

xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows))
xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy")

# pyrefly: ignore # no-matching-overload, bad-argument-type
for x, y, value in zip(xs.flat, ys.flat, arr.flat):
color = dark_color if (value <= theta) else light_color
annot = f"{{:{fmt}}}".format(value)
Expand All @@ -142,11 +138,9 @@ def cmat(
ax.set_yticks(np.arange(num_rows))
ax.set_yticklabels(labels, fontsize=label_fontsize)

# pyrefly: ignore # bad-argument-type
ax.xaxis.set_minor_locator(FixedLocator(np.arange(num_cols) - 0.5))
ax.xaxis.set_minor_locator(FixedLocator((np.arange(num_cols) - 0.5).tolist()))

# pyrefly: ignore # bad-argument-type
ax.yaxis.set_minor_locator(FixedLocator(np.arange(num_rows) - 0.5))
ax.yaxis.set_minor_locator(FixedLocator((np.arange(num_rows) - 0.5).tolist()))

ax.grid(
visible=True,
Expand Down
17 changes: 9 additions & 8 deletions src/jetplot/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def violinplot(
pc.set_edgecolor(ec)
pc.set_alpha(1.0)

# pyrefly: ignore # no-matching-overload, bad-argument-type
q1, medians, q3 = np.percentile(data, [25, 50, 75], axis=0)

ax.vlines(
Expand All @@ -77,8 +76,7 @@ def violinplot(
if showmeans:
ax.scatter(
xs,
# pyrefly: ignore # no-matching-overload, bad-argument-type
np.mean(data, axis=0),
np.mean(data, axis=0, dtype=float),
marker="s",
color=mc,
s=15,
Expand Down Expand Up @@ -116,15 +114,16 @@ def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs):

# parse inputs
if range is None:
range = np.array([[np.min(x), np.max(x)], [np.min(y), np.max(y)]])
range_ = np.array([[np.min(x), np.max(x)], [np.min(y), np.max(y)]])
else:
range_ = range

if bins is None:
bins = 25

# compute the histogram

# pyrefly: ignore # no-matching-overload, unexpected-keyword, bad-argument-type
cnt, xe, ye = np.histogram2d(x, y, bins=bins, normed=True, range=range)
cnt, xe, ye = np.histogram2d(x, y, bins=bins, range=range_, density=True)

# generate the plot
ax = kwargs["ax"]
Expand Down Expand Up @@ -366,8 +365,10 @@ def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs):
mean_y = np.mean(y)

transform = (
# pyrefly: ignore # bad-argument-type
Affine2D().rotate_deg(45).scale(scale_x, scale_y).translate(mean_x, mean_y)
Affine2D()
.rotate_deg(45)
.scale(float(scale_x), float(scale_y))
.translate(float(mean_x), float(mean_y))
)

ellipse.set_transform(transform + ax.transData)
Expand Down
27 changes: 16 additions & 11 deletions src/jetplot/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from typing import Callable

import numpy as np
from numpy.typing import ArrayLike
from numpy.typing import ArrayLike, NDArray

# pyrefly: ignore # missing-module-attribute
from scipy.ndimage import gaussian_filter1d
from scipy.ndimage import gaussian_filter1d # pyrefly: ignore[missing-module-attribute]

__all__ = ["smooth", "canoncorr", "participation_ratio", "stable_rank", "normalize"]

Expand All @@ -32,14 +31,18 @@ def stable_rank(X):
return svals_sq.sum() / svals_sq.max()


def participation_ratio(C):
"""Computes the participation ratio of a square matrix."""
def participation_ratio(C: NDArray[np.floating]) -> float:
"""Compute the participation ratio of a square matrix."""

assert C.ndim == 2, "C must be a matrix"
assert C.shape[0] == C.shape[1], "C must be a square matrix"
return np.trace(C) ** 2 / np.trace(np.linalg.matrix_power(C, 2))

diag_sum = float(np.trace(C))
diag_sq_sum = float(np.trace(C @ C))
return diag_sum**2 / diag_sq_sum


def canoncorr(X: ArrayLike, Y: ArrayLike) -> ArrayLike:
def canoncorr(X: ArrayLike, Y: ArrayLike) -> NDArray[np.floating]:
"""Canonical correlation between two subspaces.

Args:
Expand All @@ -62,14 +65,16 @@ def canoncorr(X: ArrayLike, Y: ArrayLike) -> ArrayLike:
"""
# Orthogonalize each subspace

# pyrefly: ignore # no-matching-overload, bad-argument-type
qu, qv = np.linalg.qr(X)[0], np.linalg.qr(Y)[0]
qu = np.linalg.qr(np.asarray(X))[0]
qv = np.linalg.qr(np.asarray(Y))[0]

# singular values of the inner product between the orthogonalized spaces
return np.linalg.svd(qu.T.dot(qv), compute_uv=False, full_matrices=False)


def normalize(X: ArrayLike, axis: int = -1, norm: Callable = np.linalg.norm):
def normalize(
X: ArrayLike, axis: int = -1, norm: Callable[[ArrayLike], float] = np.linalg.norm
) -> NDArray[np.floating]:
"""Normalizes elements of an array or matrix.

Args:
Expand All @@ -80,4 +85,4 @@ def normalize(X: ArrayLike, axis: int = -1, norm: Callable = np.linalg.norm):
Returns:
Xn: Arrays that have been normalized using to the given function.
"""
return X / norm(X, axis=axis, keepdims=True)
return np.asarray(X) / norm(X, axis=axis, keepdims=True)
24 changes: 13 additions & 11 deletions src/jetplot/timepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,19 @@ def wrapper(*args, **kwargs):
calls.append(tstop - tstart)
return results

# pyrefly: ignore # missing-attribute
wrapper.calls = calls
# pyrefly: ignore # missing-attribute
wrapper.mean = lambda: np.mean(calls)
# pyrefly: ignore # missing-attribute
wrapper.serr = lambda: np.std(calls) / np.sqrt(len(calls))
# pyrefly: ignore # missing-attribute
wrapper.summary = lambda: print(
"Runtimes: {} {} {}".format(
hrtime(wrapper.mean()), "\u00b1", hrtime(wrapper.serr())
)
)

def mean() -> float:
return float(np.mean(calls))

def serr() -> float:
return float(np.std(calls) / np.sqrt(len(calls)))

def summary() -> None:
print(f"Runtimes: {hrtime(mean())} \u00b1 {hrtime(serr())}")

wrapper.mean = mean
wrapper.serr = serr
wrapper.summary = summary

return wrapper
Loading