TensorBoard for PyTorch
One of my favorite tools in machine learning is TensorBoard, which can visualize a wide range of statistics about your model. I have always found it hard to interpret the behaviour of a neural network.
Looking at the parameters and how they evolve during training is always valuable: the better you know your model, the easier it is to improve its performance. TensorBoard provides convenient and efficient tools for visualizing your model's statistics.
Installation
First, install TensorBoard with pip:
pip install tensorboard
Usage
Start TensorBoard from the command line with the following command (here --logdir is the directory that holds the TensorBoard data):
tensorboard --logdir=runs
If you are in a Jupyter notebook, load the TensorBoard extension first and then use the magic command:
%load_ext tensorboard
%tensorboard --logdir=runs
The dashboard is then available at http://localhost:6006/ (6006 is TensorBoard's default port).
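The runs directory passed to --logdir is also where SummaryWriter writes when no log_dir is given: by default it creates a subdirectory named after the current date, time and hostname. If you want each experiment to be easy to tell apart in the dashboard, you can build a similar name yourself and add a tag. Below is a minimal sketch of that idea; make_run_dir and the "lr0.01" tag are hypothetical names of mine, not part of TensorBoard or PyTorch.

```python
import socket
from datetime import datetime
from pathlib import Path


def make_run_dir(base="runs", tag="", now=None):
    """Build a per-run log directory path: <base>/<timestamp>_<hostname>[_<tag>].

    Hypothetical helper for illustration; it only builds the path and
    does not create anything on disk.
    """
    now = now or datetime.now()
    stamp = now.strftime("%b%d_%H-%M-%S")
    name = f"{stamp}_{socket.gethostname()}"
    if tag:
        name += f"_{tag}"
    return Path(base) / name


run_dir = make_run_dir(tag="lr0.01")
print(run_dir)
# Pass it to the writer, e.g. SummaryWriter(log_dir=str(run_dir)),
# and keep pointing TensorBoard at the parent: tensorboard --logdir=runs
```

Because every run lives in its own subdirectory of runs, TensorBoard shows them as separate runs that can be toggled on and off in the dashboard.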
Remote
If you are working on a remote server, you need to forward port 6006 of the server (127.0.0.1:6006 as seen from the server) to a port on your local machine, for example 16006 (any free local port works). You can do this by adding one option to your usual ssh command:
ssh -L 16006:127.0.0.1:6006 acc_name@server_ip
Here, the “-L” option makes ssh listen on a port of your local machine (16006) and forward every connection to the given address and port on the server (127.0.0.1:6006). You can then open the remote TensorBoard at http://127.0.0.1:16006 from your local device.
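If the page does not load, a quick first check is whether anything is listening on the local end of the tunnel. Here is a small sketch using only the Python standard library; port_is_open is a hypothetical helper of mine, and the defaults simply mirror the 127.0.0.1:16006 address used above.

```python
import socket


def port_is_open(host="127.0.0.1", port=16006, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host unreachable
        return False


if port_is_open():
    print("Local port is open: try http://127.0.0.1:16006")
else:
    print("Nothing is listening on 127.0.0.1:16006; check the ssh -L session")
```

Note that this only tests the local end of the tunnel. ssh accepts local connections even when TensorBoard is not running on the server, so a successful check does not guarantee that the dashboard itself is up.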
Example Usage
With that, you are good to go. Enjoy tracking the important statistics as training progresses. The following code snippet shows how to log the statistics of interest.
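import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()  # writes to ./runs/ by default
r = 5  # scaling constant; not defined in the original snippet, assumed here
for n_iter in range(100):
    # A single scalar curve
    writer.add_scalar('Loss/train', np.random.random(), n_iter)
    # Several related curves grouped under one main tag
    writer.add_scalars('run_14h', {'xsinx': n_iter * np.sin(n_iter / r),
                                   'xcosx': n_iter * np.cos(n_iter / r),
                                   'tanx': np.tan(n_iter / r)}, n_iter)
    # A histogram whose center shifts with the step
    x = np.random.random(1000)
    writer.add_histogram('distribution centers', x + n_iter, n_iter)
writer.close()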
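The snippet above logs random numbers, so here is a sketch of how the same calls might look inside an actual training loop. Everything about the setup is an assumption for illustration: the synthetic regression data, the tiny nn.Linear model, the tag names and the train_and_log helper are mine. Only add_scalar and add_histogram are used, exactly as above, and the logs go to a temporary directory so the example does not touch your runs folder.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter


def train_and_log(log_dir, steps=20):
    """Fit a small linear model on synthetic data and log its progress."""
    torch.manual_seed(0)
    # Synthetic regression problem (illustrative only): y = x @ true_w + noise
    x = torch.randn(256, 8)
    true_w = torch.randn(8, 4)
    y = x @ true_w + 0.1 * torch.randn(256, 4)

    model = nn.Linear(8, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    writer = SummaryWriter(log_dir=log_dir)
    losses = []
    for step in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        # Scalar curve for the training loss
        writer.add_scalar("Loss/train", loss.item(), step)
        # One histogram per parameter tensor, to see how the values move
        for name, param in model.named_parameters():
            writer.add_histogram(f"params/{name}", param.detach(), step)
    writer.close()
    return losses


log_dir = tempfile.mkdtemp(prefix="tb_demo_")
losses = train_and_log(log_dir)
event_files = os.listdir(log_dir)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
print(f"event files in {log_dir}: {event_files}")
```

To look at the result, point TensorBoard at the printed directory with tensorboard --logdir=<that directory>. The loss curve appears under the Scalars tab and the parameter histograms under the Histograms tab.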