Rust bindings for the C++ API of PyTorch. The goal of the `tch` crate is to provide thin wrappers around the C++ PyTorch API (a.k.a. libtorch). It aims at staying as close as possible to the original C++ API; more idiomatic Rust bindings can then be developed on top of it. The documentation can be found on docs.rs.
The code generation part for the C API on top of libtorch comes from ocaml-torch.
This crate requires the C++ PyTorch library (libtorch) in version v1.10.0 to be available on your system. You can either:
- Use the system-wide libtorch installation (default).
- Install libtorch manually and let the build script know about it via the `LIBTORCH` environment variable.
- When a system-wide libtorch can't be found and `LIBTORCH` is not set, the build script will download a pre-built binary version of libtorch. By default a CPU version is used. The `TORCH_CUDA_VERSION` environment variable can be set to `cu111` in order to get a pre-built binary using CUDA 11.1.
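For instance, opting into the CUDA 11.1 pre-built binary could look as follows (a sketch; only the variable names above are taken from this crate):

```shell
# Ask the build script to download a CUDA 11.1 pre-built libtorch
# instead of the default CPU version.
export TORCH_CUDA_VERSION=cu111
cargo build
```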
The build script will look for a system-wide libtorch library in the following locations:
- In Linux:
## Libtorch Manual Install
Download `libtorch` from the PyTorch website download section and extract the content of the zip file.
For Linux users, add the following to your `.bashrc` or equivalent, where `/path/to/libtorch` is the path to the directory that was created when unzipping the file.
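A minimal sketch of those `.bashrc` entries; the `LIBTORCH` variable is the one the build script reads, and extending `LD_LIBRARY_PATH` (a standard dynamic-linker mechanism, assumed here) lets the loader find the libtorch shared libraries at run time:

```shell
# Tell the tch build script where the unzipped libtorch lives.
export LIBTORCH=/path/to/libtorch
# Make the libtorch shared libraries visible to the dynamic linker.
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```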
For Windows users, assuming that `X:\path\to\libtorch` is the unzipped libtorch directory:
- Navigate to Control Panel -> View advanced system settings -> Environment variables.
- Create the `LIBTORCH` variable and set it to `X:\path\to\libtorch`.
If you prefer to temporarily set environment variables, you can set them from PowerShell instead.
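A sketch of session-only PowerShell assignments; the paths are placeholders, and appending the `lib` directory to `Path` is an assumption so the libtorch DLLs can be located:

```powershell
# Set the variables for the current PowerShell session only.
$Env:LIBTORCH = "X:\path\to\libtorch"
# Assumed: the lib directory must be on Path for DLL lookup.
$Env:Path += ";X:\path\to\libtorch\lib"
```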
You should now be able to run some examples, e.g. `cargo run --example basics`.
## Windows Specific Notes
As per the PyTorch docs, the Windows debug and release builds are not ABI-compatible. This could lead to segfaults if the incorrect version of libtorch is used.
## Basic Tensor Operations
This crate provides a tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations.
## Training a Model via Gradient Descent
PyTorch provides automatic differentiation for most tensor operations it supports. This is commonly used to train models using gradient descent. The optimization is performed over variables which are created in an `nn::VarStore` by defining their shapes and initializations.
In the example below, `my_module` uses two variables `x1` and `x2` whose initial values are 0. The forward pass applied to tensor `xs` returns `xs * x1 + exp(xs) * x2`.
Once the model has been generated, an `nn::Sgd` optimizer is created. Then on each step of the training loop:
- The forward pass is applied to a mini-batch of data.
- A loss is computed as the mean square error between the model output and the mini-batch ground truth.
- Finally an optimization step is performed: gradients are computed and variables from the `VarStore` are modified accordingly.
## Writing a Simple Neural Network
The `nn` API can be used to create neural network architectures, e.g. the following code defines a simple model with one hidden layer and trains it on the MNIST dataset using the Adam optimizer.
More details on the training loop can be found in the detailed tutorial.
## Using some Pre-Trained Model
The pretrained-models example illustrates how to use a pre-trained computer vision model on an image. The weights - which have been extracted from the PyTorch implementation - can be downloaded here resnet18.ot and here resnet34.ot.
The example can then be run via the following command:
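A sketch of the invocation; the example name comes from the text above, while the argument order and the image file name are assumptions:

```shell
# Run the pretrained-models example on downloaded weights and an image
# (file names here are placeholders).
cargo run --example pretrained-models -- resnet18.ot image.jpg
```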
This should print the top 5 ImageNet categories for the image. The code for this example is pretty simple.
Further examples include:
- A simplified version of char-rnn illustrating character level language modeling using Recurrent Neural Networks.
- Neural style transfer uses a pre-trained VGG-16 model to compose an image in the style of another image (pre-trained weights: vgg16.ot).
- Some ResNet examples on CIFAR-10.
- A tutorial showing how to deploy/run some Python-trained models using the TorchScript JIT.
- Some Reinforcement Learning examples using the OpenAI Gym environment. This includes a policy gradient example as well as an A2C implementation that can run on Atari games.
- A Transfer Learning Tutorial shows how to fine-tune a pre-trained ResNet model on a very small dataset.
- A tutorial showing how to use Torch to compute option prices and greeks.
tch-rs is distributed under the terms of both the MIT license
and the Apache license (version 2.0), at your option.