Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions src/jetplot/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,23 @@ def violinplot(


@plotwrapper
def hist(*args, **kwargs):
def hist(*args, histtype="stepfilled", alpha=0.85, density=True, **kwargs):
"""Wrapper for matplotlib.hist function."""

# remove kwargs that are filled in manually
kwargs.pop("alpha", None)
kwargs.pop("histtype", None)
kwargs.pop("normed", None)

# get the axis and figure handles
ax = kwargs.pop("ax")
kwargs.pop("fig")

return ax.hist(*args, histtype="stepfilled", alpha=0.85, normed=True, **kwargs)
return ax.hist(*args, histtype=histtype, alpha=alpha, density=density, **kwargs)


@plotwrapper
def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs):
def hist2d(
x: np.ndarray,
y: np.ndarray,
bins: int | None = None,
limits: np.ndarray | None = None,
cmap: str = "hot",
**kwargs,
):
"""
Visualizes a 2D histogram by binning data.

Expand All @@ -116,17 +116,15 @@ def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs):
"""

# parse inputs
if range is None:
range_ = np.array([[np.min(x), np.max(x)], [np.min(y), np.max(y)]])
else:
range_ = range
if limits is None:
limits = np.array([[np.min(x), np.max(x)], [np.min(y), np.max(y)]])

if bins is None:
bins = 25

# compute the histogram
# pyrefly: ignore # no-matching-overload, bad-argument-type
cnt, xe, ye = np.histogram2d(x, y, bins=bins, range=range_, density=True)
cnt, xe, ye = np.histogram2d(x, y, bins=bins, range=limits, density=True)

# generate the plot
ax = kwargs["ax"]
Expand Down
84 changes: 84 additions & 0 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np
from matplotlib import pyplot as plt

from jetplot import plots


def test_hist():
bins = 11
x = np.arange(10)
fig, ax = plt.subplots()
values, bin_edges, patches = plots.hist(x, bins=bins, fig=fig, ax=ax)
assert len(values) == bins
assert len(bin_edges) == bins + 1
plt.close(fig)


def test_hist2d():
x = np.random.randn(100)
y = np.random.randn(100)
fig, ax = plt.subplots()
plots.hist2d(x, y, fig=fig, ax=ax)
assert ax.get_aspect() == 1.0
plt.close(fig)


def test_errorplot_methods():
x = np.arange(5)
y = np.arange(5)
yerr = np.ones_like(x)

fig, ax = plt.subplots()
plots.errorplot(x, y, yerr, method="patch", fig=fig, ax=ax)
assert len(ax.lines) == 1
plt.close(fig)

fig, ax = plt.subplots()
plots.errorplot(x, y, yerr, method="line", fig=fig, ax=ax)
assert len(ax.lines) > 1
plt.close(fig)


def test_circle():
fig, ax = plt.subplots()
plots.circle(fig=fig, ax=ax)
line = ax.lines[0]
assert line.get_xdata()[0] == 1.0
assert len(line.get_xdata()) == 1001
plt.close(fig)


def test_bar_and_lines():
labels = ["A", "B", "C"]
data = [1.0, 2.0, 3.0]
err = [0.1, 0.1, 0.1]

fig, ax = plt.subplots()
plots.bar(labels, data, err=err, fig=fig, ax=ax)
assert len(ax.patches) >= len(labels)
plt.close(fig)

fig, ax = plt.subplots()
lines = [np.array(data), np.array(data) + 1]
plots.lines(np.arange(3), lines=lines, fig=fig, ax=ax)
assert len(ax.lines) == len(lines)
plt.close(fig)


def test_waterfall():
x = np.arange(5)
ys = [np.linspace(0, 1, 5) for _ in range(3)]
fig, ax = plt.subplots()
plots.waterfall(x, ys, fig=fig, ax=ax)
# waterfall uses fill_between which adds PolyCollections
assert len(ax.collections) >= len(ys)
plt.close(fig)


def test_violinplot():
data = np.random.randn(100)
fig, ax = plt.subplots()
plots.violinplot(data, xs=1, fig=fig, ax=ax)
# Expect at least one polygon from violin body
assert len(ax.collections) > 0
plt.close(fig)