diff --git a/pyproject.toml b/pyproject.toml index f67f3b1..26d7b0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,9 +36,8 @@ docs = [ ] [tool.pyrefly] -search_path = [ - "src/" -] +project_includes = ["src/**"] +project_excludes = ["**/.[!/.]*", "**/*venv/**/*", "build/**/*"] [tool.ruff] lint.extend-ignore = ["E111", "E114", "E501", "F403"] diff --git a/src/jetplot/chart_utils.py b/src/jetplot/chart_utils.py index a1f8d60..5cf3f29 100644 --- a/src/jetplot/chart_utils.py +++ b/src/jetplot/chart_utils.py @@ -130,8 +130,7 @@ def get_bounds(axis, ax=None): else: lower, upper = None, None - # pyrefly: ignore # no-matching-overload, bad-argument-type - for tick, label in zip(ticks(), labels()): + for tick, label in zip(list(ticks()), list(labels())): if label.get_text() != "": if lower is None: lower = tick @@ -191,8 +190,7 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs): y0 = lims[0] if y0 is None else y0 y1 = lims[1] if y1 is None else y1 - # pyrefly: ignore # no-matching-overload, bad-argument-type - dt = np.mean(np.diff(ax.get_yticks())) if dt is None else dt + dt = float(np.mean(np.diff(ax.get_yticks()))) if dt is None else float(dt) new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt) ax.set_yticks(new_ticks) @@ -210,8 +208,7 @@ def xclamp(x0=None, x1=None, dt=None, **kwargs): x0 = lims[0] if x0 is None else x0 x1 = lims[1] if x1 is None else x1 - # pyrefly: ignore # no-matching-overload, bad-argument-type - dt = np.mean(np.diff(ax.get_xticks())) if dt is None else dt + dt = float(np.mean(np.diff(ax.get_xticks()))) if dt is None else float(dt) new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt) ax.set_xticks(new_ticks) diff --git a/src/jetplot/colors.py b/src/jetplot/colors.py index 2ff7f67..86b7b63 100644 --- a/src/jetplot/colors.py +++ b/src/jetplot/colors.py @@ -38,18 +38,15 @@ def cubehelix( x = lambda_**gamma phi = 2 * np.pi * (start / 3 + rot * lambda_) - # pyrefly: ignore # bad-argument-type, no-matching-overload alpha = 0.5 * hue * x * (1.0 - x) A = np.array([[-0.14861, 1.78277], [-0.29227, -0.90649], [1.97294, 0.0]]) b = np.stack([np.cos(phi), np.sin(phi)]) - # pyrefly: ignore # no-matching-overload, bad-argument-type return Palette((x + alpha * (A @ b)).T) def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0): - # pyrefly: ignore # missing-attribute - return Palette(cm.__getattribute__(cmap)(np.linspace(vmin, vmax, n))) + return Palette(getattr(cm, cmap)(np.linspace(vmin, vmax, n))) black = "#000000" diff --git a/src/jetplot/demo.py b/src/jetplot/demo.py index 6ec7fc6..c910e29 100644 --- a/src/jetplot/demo.py +++ b/src/jetplot/demo.py @@ -1,14 +1,27 @@ import numpy as np +from numpy.typing import NDArray -def peaks(n=256): - """2D peaks function.""" - pts = np.linspace(-3, 3, n) - xm, ym = np.meshgrid(pts, pts) +def peaks(n: int = 256) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]: + """Generate the MATLAB ``peaks`` surface. + + Parameters + ---------- + n: + Number of points per axis. + + Returns + ------- + tuple of ``ndarray`` + ``x`` grid, ``y`` grid and the function value ``z``. + """ + + pts = np.linspace(-3.0, 3.0, n) + xm, ym = np.meshgrid(pts, pts, indexing="xy") zm = ( 3 * (1 - xm) ** 2 * np.exp(-(xm**2)) - (ym + 1) ** 2 - 10 * (0.2 * xm - xm**3 - ym**5) * np.exp(-(xm**2) - (ym**2)) - - (1 / 3) * np.exp(-((xm + 1) ** 2) - ym**2) + - (1.0 / 3.0) * np.exp(-((xm + 1) ** 2) - ym**2) ) return xm, ym, zm diff --git a/src/jetplot/images.py b/src/jetplot/images.py index 7b4d48a..698bd5c 100644 --- a/src/jetplot/images.py +++ b/src/jetplot/images.py @@ -83,16 +83,13 @@ def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs): xrng = (-1, 1) if xrng is None else xrng yrng = xrng if yrng is None else yrng - # pyrefly: ignore # missing-argument, no-matching-overload, bad-argument-type - xs = np.linspace(*xrng, n) - - # pyrefly: ignore # missing-argument, no-matching-overload, bad-argument-type - ys = np.linspace(*yrng, n) + xs = np.linspace(xrng[0], xrng[1], n) + ys = np.linspace(yrng[0], yrng[1], n) xm, ym = np.meshgrid(xs, ys) if nargs == 1: - zz = np.vstack((xm.ravel(), ym.ravel())) + zz = np.vstack([xm.ravel(), ym.ravel()]) args = (zz,) elif nargs == 2: args = (xm.ravel(), ym.ravel()) @@ -128,9 +125,8 @@ def cmat( ax = kwargs.pop("ax") cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar) - xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows)) + xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy") - # pyrefly: ignore # no-matching-overload, bad-argument-type for x, y, value in zip(xs.flat, ys.flat, arr.flat): color = dark_color if (value <= theta) else light_color annot = f"{{:{fmt}}}".format(value) @@ -142,11 +138,9 @@ def cmat( ax.set_yticks(np.arange(num_rows)) ax.set_yticklabels(labels, fontsize=label_fontsize) - # pyrefly: ignore # bad-argument-type - ax.xaxis.set_minor_locator(FixedLocator(np.arange(num_cols) - 0.5)) + ax.xaxis.set_minor_locator(FixedLocator((np.arange(num_cols) - 0.5).tolist())) - # pyrefly: ignore # bad-argument-type - ax.yaxis.set_minor_locator(FixedLocator(np.arange(num_rows) - 0.5)) + ax.yaxis.set_minor_locator(FixedLocator((np.arange(num_rows) - 0.5).tolist())) ax.grid( visible=True, diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index 5e8ed53..de2fd6b 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -54,7 +54,6 @@ def violinplot( pc.set_edgecolor(ec) pc.set_alpha(1.0) - # pyrefly: ignore # no-matching-overload, bad-argument-type q1, medians, q3 = np.percentile(data, [25, 50, 75], axis=0) ax.vlines( @@ -77,8 +76,7 @@ def violinplot( if showmeans: ax.scatter( xs, - # pyrefly: ignore # no-matching-overload, bad-argument-type - np.mean(data, axis=0), + np.mean(data, axis=0, dtype=float), marker="s", color=mc, s=15, @@ -116,15 +114,16 @@ def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs): # parse inputs if range is None: - range = np.array([[np.min(x), np.max(x)], [np.min(y), np.max(y)]]) + range_ = np.array([[np.min(x), np.max(x)], [np.min(y), np.max(y)]]) + else: + range_ = range if bins is None: bins = 25 # compute the histogram - # pyrefly: ignore # no-matching-overload, unexpected-keyword, bad-argument-type - cnt, xe, ye = np.histogram2d(x, y, bins=bins, normed=True, range=range) + cnt, xe, ye = np.histogram2d(x, y, bins=bins, range=range_, density=True) # generate the plot ax = kwargs["ax"] @@ -366,8 +365,10 @@ def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs): mean_y = np.mean(y) transform = ( - # pyrefly: ignore # bad-argument-type - Affine2D().rotate_deg(45).scale(scale_x, scale_y).translate(mean_x, mean_y) + Affine2D() + .rotate_deg(45) + .scale(float(scale_x), float(scale_y)) + .translate(float(mean_x), float(mean_y)) ) ellipse.set_transform(transform + ax.transData) diff --git a/src/jetplot/signals.py b/src/jetplot/signals.py index 0f466dc..b9072d4 100644 --- a/src/jetplot/signals.py +++ b/src/jetplot/signals.py @@ -4,10 +4,9 @@ from typing import Callable import numpy as np -from numpy.typing import ArrayLike +from numpy.typing import ArrayLike, NDArray -# pyrefly: ignore # missing-module-attribute -from scipy.ndimage import gaussian_filter1d +from scipy.ndimage import gaussian_filter1d # pyrefly: ignore[missing-module-attribute] __all__ = ["smooth", "canoncorr", "participation_ratio", "stable_rank", "normalize"] @@ -32,14 +31,18 @@ def stable_rank(X): return svals_sq.sum() / svals_sq.max() -def participation_ratio(C): - """Computes the participation ratio of a square matrix.""" +def participation_ratio(C: NDArray[np.floating]) -> float: + """Compute the participation ratio of a square matrix.""" + assert C.ndim == 2, "C must be a matrix" assert C.shape[0] == C.shape[1], "C must be a square matrix" - return np.trace(C) ** 2 / np.trace(np.linalg.matrix_power(C, 2)) + + diag_sum = float(np.trace(C)) + diag_sq_sum = float(np.trace(C @ C)) + return diag_sum**2 / diag_sq_sum -def canoncorr(X: ArrayLike, Y: ArrayLike) -> ArrayLike: +def canoncorr(X: ArrayLike, Y: ArrayLike) -> NDArray[np.floating]: """Canonical correlation between two subspaces. Args: @@ -62,14 +65,16 @@ def canoncorr(X: ArrayLike, Y: ArrayLike) -> ArrayLike: """ # Orthogonalize each subspace - # pyrefly: ignore # no-matching-overload, bad-argument-type - qu, qv = np.linalg.qr(X)[0], np.linalg.qr(Y)[0] + qu = np.linalg.qr(np.asarray(X))[0] + qv = np.linalg.qr(np.asarray(Y))[0] # singular values of the inner product between the orthogonalized spaces return np.linalg.svd(qu.T.dot(qv), compute_uv=False, full_matrices=False) -def normalize(X: ArrayLike, axis: int = -1, norm: Callable = np.linalg.norm): +def normalize( + X: ArrayLike, axis: int = -1, norm: Callable[[ArrayLike], float] = np.linalg.norm +) -> NDArray[np.floating]: """Normalizes elements of an array or matrix. Args: @@ -80,4 +85,4 @@ def normalize(X: ArrayLike, axis: int = -1, norm: Callable = np.linalg.norm): Returns: Xn: Arrays that have been normalized using to the given function. """ - return X / norm(X, axis=axis, keepdims=True) + return np.asarray(X) / norm(X, axis=axis, keepdims=True) diff --git a/src/jetplot/timepiece.py b/src/jetplot/timepiece.py index e64ad5d..d7a0031 100644 --- a/src/jetplot/timepiece.py +++ b/src/jetplot/timepiece.py @@ -110,17 +110,19 @@ def wrapper(*args, **kwargs): calls.append(tstop - tstart) return results - # pyrefly: ignore # missing-attribute wrapper.calls = calls - # pyrefly: ignore # missing-attribute - wrapper.mean = lambda: np.mean(calls) - # pyrefly: ignore # missing-attribute - wrapper.serr = lambda: np.std(calls) / np.sqrt(len(calls)) - # pyrefly: ignore # missing-attribute - wrapper.summary = lambda: print( - "Runtimes: {} {} {}".format( - hrtime(wrapper.mean()), "\u00b1", hrtime(wrapper.serr()) - ) - ) + + def mean() -> float: + return float(np.mean(calls)) + + def serr() -> float: + return float(np.std(calls) / np.sqrt(len(calls))) + + def summary() -> None: + print(f"Runtimes: {hrtime(mean())} \u00b1 {hrtime(serr())}") + + wrapper.mean = mean + wrapper.serr = serr + wrapper.summary = summary return wrapper