Ask yourself the following:
- Are you using `matplotlib.pyplot` to plot pytorch tensors?
- Do you forget to call `.cpu().detach().numpy()` every time you want to plot a tensor?
Then torchplot may be something for you. torchplot is a simple drop-in replacement
for plotting pytorch tensors. We override every `matplotlib.pyplot` function such
that pytorch tensors passed to it are automatically converted to NumPy arrays.
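The idea behind the override can be sketched in a few lines. This is a minimal illustration under stated assumptions, not torchplot's actual implementation: the helper names `to_numpy`, `convert_tensors` and `wrap_module` are hypothetical, and tensors are detected by duck typing (any object with `detach`, `cpu` and `numpy` methods, as `torch.Tensor` has) so the sketch runs without PyTorch installed.

```python
import functools
import types


def to_numpy(obj):
    """Hypothetical helper: convert tensor-like objects to NumPy arrays.

    Duck-typed: anything exposing .detach(), .cpu() and .numpy() (as
    torch.Tensor does) is treated as a tensor. Lists and tuples are
    converted element-wise; every other object is returned unchanged.
    """
    if all(hasattr(obj, name) for name in ("detach", "cpu", "numpy")):
        # Drop autograd history, move to host memory, then convert to NumPy.
        return obj.detach().cpu().numpy()
    if isinstance(obj, list):
        return [to_numpy(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(to_numpy(item) for item in obj)
    return obj


def convert_tensors(fn):
    """Hypothetical decorator: convert all arguments before calling fn."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = [to_numpy(a) for a in args]
        kwargs = {key: to_numpy(value) for key, value in kwargs.items()}
        return fn(*args, **kwargs)
    return wrapper


def wrap_module(module):
    """Hypothetical: mirror a module's public names, wrapping its functions.

    Only plain functions are wrapped; classes and constants are copied
    as-is so that e.g. isinstance checks against them keep working.
    """
    namespace = types.SimpleNamespace()
    for name in dir(module):
        if name.startswith("_"):
            continue
        attr = getattr(module, name)
        if isinstance(attr, (types.FunctionType, types.BuiltinFunctionType)):
            attr = convert_tensors(attr)
        setattr(namespace, name, attr)
    return namespace
```

With PyTorch and matplotlib installed, `plt = wrap_module(matplotlib.pyplot)` should then let `plt.plot(x, y, '.')` accept tensors directly. Wrapping at the function boundary keeps the original matplotlib functions untouched, which is what makes this style of replacement "drop-in".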
Simply change your default matplotlib import statement. Instead of

```python
from matplotlib.pyplot import *
```

use

```python
from torchplot import *
```

and instead of

```python
import matplotlib.pyplot as plt
```

use

```python
import torchplot as plt
```

Hereafter, you can remove every `.cpu().detach().numpy()` (or variations thereof) from
your code and everything should just work. If you do not want to mix implementations,
we recommend importing torchplot as a separate package:

```python
import torchplot as tp
```

Installation is as simple as:

```
pip install torchplot
```
```python
# let's make a scatter plot of two pytorch tensors that are stored on the GPU
import torch
import torchplot as plt

x = torch.randn(100, requires_grad=True, device='cuda')
y = torch.randn(100, requires_grad=True, device='cuda')
plt.plot(x, y, '.')  # easy and simple
```

Tested with torch>=1.6 and matplotlib>=3.3.3, but it should also work with
both earlier and later versions.
Please observe the Apache 2.0 license that is listed in this repository.
If you want to cite the framework, feel free to use this (but only if you loved it 😊):
```bibtex
@article{detlefsen2021torchplot,
  title={TorchPlot},
  author={Detlefsen, Nicki S. and Hauberg, Søren},
  journal={GitHub. Note: https://github.com/MachineLearningLifeScience/torchplot},
  year={2021}
}
```