From 687a3ffcec21f7510d7584c6b4c7c89e0b4e29fe Mon Sep 17 00:00:00 2001 From: Niru Maheswaranathan Date: Tue, 20 May 2025 04:48:57 -0700 Subject: [PATCH 1/2] Add basic tests for plots module --- tests/test_plots.py | 82 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 tests/test_plots.py diff --git a/tests/test_plots.py b/tests/test_plots.py new file mode 100644 index 0000000..b7f4f76 --- /dev/null +++ b/tests/test_plots.py @@ -0,0 +1,82 @@ +import numpy as np +from matplotlib import pyplot as plt + +from jetplot import plots + + +def test_hist(): + x = np.arange(10) + fig, ax = plt.subplots() + n, bins, patches = plots.hist(x, fig=fig, ax=ax) + assert n.sum() == len(x) + plt.close(fig) + + +def test_hist2d(): + x = np.random.randn(100) + y = np.random.randn(100) + fig, ax = plt.subplots() + plots.hist2d(x, y, fig=fig, ax=ax) + assert ax.get_aspect() == "equal" + plt.close(fig) + + +def test_errorplot_methods(): + x = np.arange(5) + y = np.arange(5) + yerr = np.ones_like(x) + + fig, ax = plt.subplots() + plots.errorplot(x, y, yerr, method="patch", fig=fig, ax=ax) + assert len(ax.lines) == 1 + plt.close(fig) + + fig, ax = plt.subplots() + plots.errorplot(x, y, yerr, method="line", fig=fig, ax=ax) + assert len(ax.lines) > 1 + plt.close(fig) + + +def test_circle(): + fig, ax = plt.subplots() + plots.circle(fig=fig, ax=ax) + line = ax.lines[0] + assert line.get_xdata()[0] == 1.0 + assert len(line.get_xdata()) == 1001 + plt.close(fig) + + +def test_bar_and_lines(): + labels = ["A", "B", "C"] + data = [1.0, 2.0, 3.0] + err = [0.1, 0.1, 0.1] + + fig, ax = plt.subplots() + plots.bar(labels, data, err=err, fig=fig, ax=ax) + assert len(ax.patches) >= len(labels) + plt.close(fig) + + fig, ax = plt.subplots() + lines = [np.array(data), np.array(data) + 1] + plots.lines(np.arange(3), lines=lines, fig=fig, ax=ax) + assert len(ax.lines) == len(lines) + plt.close(fig) + + +def test_waterfall(): + x = np.arange(5) + ys = [np.linspace(0, 1, 5) for _ in range(3)] + fig, ax = plt.subplots() + plots.waterfall(x, ys, fig=fig, ax=ax) + # waterfall uses fill_between which adds PolyCollections + assert len(ax.collections) >= len(ys) + plt.close(fig) + + +def test_violinplot(): + data = np.random.randn(100) + fig, ax = plt.subplots() + plots.violinplot(data, xs=1, fig=fig, ax=ax) + # Expect at least one polygon from violin body + assert len(ax.collections) > 0 + plt.close(fig) From d64dee256b40903f38e866b65ca2cc49a4c589af Mon Sep 17 00:00:00 2001 From: Niru Nahesh Date: Tue, 20 May 2025 10:53:22 -0700 Subject: [PATCH 2/2] Slight updates to make tests pass. --- src/jetplot/plots.py | 28 +++++++++++++--------------- tests/test_plots.py | 8 +++++--- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/jetplot/plots.py b/src/jetplot/plots.py index e57f18b..217a9cb 100644 --- a/src/jetplot/plots.py +++ b/src/jetplot/plots.py @@ -88,23 +88,23 @@ def violinplot( @plotwrapper -def hist(*args, **kwargs): +def hist(*args, histtype="stepfilled", alpha=0.85, density=True, **kwargs): """Wrapper for matplotlib.hist function.""" - - # remove kwargs that are filled in manually - kwargs.pop("alpha", None) - kwargs.pop("histtype", None) - kwargs.pop("normed", None) - - # get the axis and figure handles ax = kwargs.pop("ax") kwargs.pop("fig") - return ax.hist(*args, histtype="stepfilled", alpha=0.85, normed=True, **kwargs) + return ax.hist(*args, histtype=histtype, alpha=alpha, density=density, **kwargs) @plotwrapper -def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs): +def hist2d( + x: np.ndarray, + y: np.ndarray, + bins: int | None = None, + limits: np.ndarray | None = None, + cmap: str = "hot", + **kwargs, +): """ Visualizes a 2D histogram by binning data. @@ -116,17 +116,15 @@ def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs): """ # parse inputs - if range is None: - range_ = np.array([[np.min(x), np.max(x)], [np.min(y), np.max(y)]]) - else: - range_ = range + if limits is None: + limits = np.array([[np.min(x), np.max(x)], [np.min(y), np.max(y)]]) if bins is None: bins = 25 # compute the histogram # pyrefly: ignore # no-matching-overload, bad-argument-type - cnt, xe, ye = np.histogram2d(x, y, bins=bins, range=range_, density=True) + cnt, xe, ye = np.histogram2d(x, y, bins=bins, range=limits, density=True) # generate the plot ax = kwargs["ax"] diff --git a/tests/test_plots.py b/tests/test_plots.py index b7f4f76..5b0ed21 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -5,10 +5,12 @@ def test_hist(): + bins = 11 x = np.arange(10) fig, ax = plt.subplots() - n, bins, patches = plots.hist(x, fig=fig, ax=ax) - assert n.sum() == len(x) + values, bin_edges, patches = plots.hist(x, bins=bins, fig=fig, ax=ax) + assert len(values) == bins + assert len(bin_edges) == bins + 1 plt.close(fig) @@ -17,7 +19,7 @@ def test_hist2d(): y = np.random.randn(100) fig, ax = plt.subplots() plots.hist2d(x, y, fig=fig, ax=ax) - assert ax.get_aspect() == "equal" + assert ax.get_aspect() == 1.0 plt.close(fig)