Rust 🤝🏾 PyTorch - Using LibTorch, tch-rs
This is kind of a deviation from the web-services posts: this one is about using the PyTorch APIs with Rust-based libraries.
But why?
A few reasons why:
- Type safety
- Better dependency management, Cargo
- Faster web servers (tokio vs Flask) to serve models
- Lighter applications
If you’re not overly concerned about any of those, let’s just say Rust is cool and we want to train/serve models in Rust✨
The How
There are a couple of steps here you need to understand, because version mismatches are a pain, as I discovered while running the examples.
Concepts
- PyTorch has LibTorch, which is a C++ API.
- A C/C++ -> Rust FFI layer generates the raw bindings (the `torch-sys` crate).
- `tch-rs` provides wrapper functions over those bindings for idiomatic Rust.
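To make the layering concrete, here's a minimal sketch of what the idiomatic `tch` surface looks like (a throwaway snippet, separate from the project we build below):

```rust
use tch::Tensor;

fn main() {
    // Build a tensor from a Rust slice and run a basic op through the
    // tch wrappers; LibTorch does the actual work underneath.
    let t = Tensor::of_slice(&[1.0f32, 2.0, 3.0]);
    let doubled = &t * 2.0;
    doubled.print();
}
```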
Setup
Library | Version |
---|---|
LibTorch | v1.7.0 |
tch-rs | v0.3.0 |
torch-sys | v0.3.0 |
- Download LibTorch
The Get Started with PyTorch locally page has LibTorch (C++) builds. Download and extract the zip file to a path.
- Set environment variables
Linux:
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
Windows:
$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
- Clone `tch-rs`
The `tch-rs` cargo package hasn't been updated to include the latest changes. At the time of writing, it's still on v0.2.1, which targets PyTorch v1.6, so we'll clone `tch-rs` to use the latest APIs.
Coming from Go, I try to organize my repos like so -> (language)/src/github.com/(github_id)/(repo_name)
cd rust/src/github.com/${github_id}/
git clone https://github.com/LaurentMazare/tch-rs
Reference: the `tch-rs` crate README.
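Before moving on, it's worth verifying that the environment variables actually let a build find and link LibTorch. A minimal sanity-check sketch (a throwaway binary in any project that depends on the cloned `tch-rs`):

```rust
use tch::Device;

fn main() {
    // Falls back to CPU when no CUDA device is visible; either way,
    // printing anything here means LibTorch was found and linked.
    println!("tch will use: {:?}", Device::cuda_if_available());
}
```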
Model Inference in Rust ✨
Create a new Rust project
Now, let’s create a new project with Cargo
cd rust/src/github.com/${github_id}/
cargo new --bin tch-inference-rs
cd tch-inference-rs
Update Cargo dependencies with tch-rs
Let's edit the `Cargo.toml` file:
[dependencies]
tch = {version="0.3.0", path="../tch-rs"}
anyhow = "1.0.33"
[dev-dependencies]
torch-sys = {version="0.3.0", path="../tch-rs/torch-sys"}
Notice that I'm using `path` in the dependency entries to point at the local clones of the packages. This is because I want to use the PyTorch v1.7 APIs, and v0.3.0 (master) is the corresponding compatible version.
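If you'd rather not keep a local clone around, Cargo can also pull a dependency straight from git; a sketch of that alternative (unpinned here, so consider adding a `rev` to fix the revision):

```toml
[dependencies]
tch = { git = "https://github.com/LaurentMazare/tch-rs" }
anyhow = "1.0.33"
```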
Not-quite-pro-tip: use `cargo search <package_name>` to look up packages and their versions.
Let’s rumble!
Now for model inference, we need two things:
- Trained model
- Input image and inference code
Tracing Pre-trained models into TorchScript
mkdir -p pretrained
touch pretrained/resnet.py
import torch
import torchvision

# Fetch the pretrained ResNet-18 weights and switch to evaluation mode
model = torchvision.models.resnet18(pretrained=True)
model.eval()

# Trace the model with an example input to compile it into TorchScript
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
What we're doing here: we fetch the weights for a pretrained ResNet-18 model, set it to its evaluation mode, then trace it with an example input through the JIT module, which compiles it into a TorchScript module. Finally, we save the traced model to the filesystem.
Run the Python script.
python pretrained/resnet.py
You should see a `model.pt` file in the current directory.
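If you want a quick check that the trace round-trips into Rust before writing the full program, a throwaway snippet like this (run from the directory containing `model.pt`) is enough:

```rust
use tch::CModule;

fn main() -> anyhow::Result<()> {
    // If this load succeeds, the TorchScript file is readable from tch.
    let _model = CModule::load("model.pt")?;
    println!("model.pt loaded successfully");
    Ok(())
}
```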
TorchScript traced model-inference from Rust
Edit `main.rs`:
use std::env;
use anyhow::{bail, Result};
use tch::vision::imagenet;
use tch::Kind::Float;
fn main() -> Result<()> {
// Get CLI arguments at runtime
let args: Vec<String> = env::args().collect();
// Parse arguments
// The first argument is the model path
// The second argument is the input image to be classified
let (model_file, image_file) = match args.as_slice() {
[_, m, i] => (m.to_owned(), i.to_owned()),
_ => bail!("usage: main model.pt tiger.jpg"),
};
// Load, resize the image to fit the classifier's tensors
// ResNet's standard 224 x 224
let image = imagenet::load_image_and_resize(image_file, 224, 224)?;
// Load the TorchScript traced model
let model = tch::CModule::load(model_file)?;
// Pass the image through the network and apply a softmax layer
// to extract the learned classes
let output = image
.unsqueeze(0)
.apply(&model)
.softmax(-1, Float);
// Iterate through the top-5 results,
// print the probability and class for each
for (probability, class) in imagenet::top(&output, 5).iter() {
println!("{:50} {:5.2}%", class, 100.0 * probability)
}
Ok(())
}
Run it
Download the image and run inference.
wget -O image.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/3/3b/Royal_Bengal_Tiger_at_Kanha_National_Park.jpg/800px-Royal_Bengal_Tiger_at_Kanha_National_Park.jpg
cargo run -- model.pt image.jpg
Classification results:
tiger, Panthera tigris 96.33%
tiger cat 3.56%
zebra 0.09%
jaguar, panther, Panthera onca, Felis onca 0.01%
tabby, tabby cat 0.01%
Next post
Next up, we'll look at serving a 🤗 transformers-based model like BERT behind a gRPC API with Tonic.