This post is a bit of a detour from the web-services posts: it’s about using PyTorch APIs through Rust libraries.

But why?

A couple of reasons why:

  • Type safety
  • Better dependency management with Cargo
  • Faster web servers for serving models (e.g. tokio-based servers vs Flask)
  • Lighter applications

If you’re not overly concerned about any of those, let’s just say Rust is cool and we want to train/serve models in Rust✨

The How

There are a couple of steps here worth understanding, because version mismatches are a pain, as I discovered while running the examples.

Concepts

  1. PyTorch ships LibTorch, its C++ API.
  2. Rust’s FFI (foreign function interface) is used to generate bindings to the C/C++ API; in tch-rs this layer is the torch-sys crate.
  3. tch-rs wraps those bindings in idiomatic Rust functions.
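
To make layers 2 and 3 concrete, here is a minimal sketch in plain Rust. It uses the C standard library’s abs as an illustrative stand-in for libtorch (this is not torch-sys code): the extern block is the raw, unsafe binding, and safe_abs is a hypothetical idiomatic wrapper that enforces a precondition before making the unsafe call, which is the kind of job tch-rs does on top of torch-sys.

```rust
use std::os::raw::c_int;

// Layer 2: a raw FFI declaration for C's `abs`, which lives in the
// platform C library that Rust programs already link against.
extern "C" {
    fn abs(input: c_int) -> c_int;
}

// Layer 3: a safe wrapper. Callers never write `unsafe`, and the wrapper
// rules out the one input for which C's `abs` is undefined behavior.
fn safe_abs(x: i32) -> Option<i32> {
    if x == i32::MIN {
        // abs(INT_MIN) overflows in C, so refuse instead of calling it.
        return None;
    }
    // SAFETY: `abs` takes a plain integer, has no pointer arguments,
    // and the overflowing input was excluded above.
    Some(unsafe { abs(x) })
}

fn main() {
    println!("{:?}", safe_abs(-42)); // prints "Some(42)"
    println!("{:?}", safe_abs(i32::MIN)); // prints "None"
}
```

The real bindings are much larger, but the split is the same: an unsafe declaration layer, and a safe layer that owns the invariants.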

Setup

Library versions:

  • LibTorch v1.7.0
  • tch-rs v0.3.0
  • torch-sys v0.3.0

  1. Download LibTorch

The PyTorch "Get Started" (Start Locally) page offers LibTorch, the C++ distribution, as a download. Download the zip file and extract it to a path of your choice.

  2. Set environment variables

Linux:

export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH

Windows:

$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
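
Before building, it can save time to check that these variables point where you think they do. The sketch below is a hypothetical helper, not part of tch-rs or torch-sys: it reads LIBTORCH, derives the lib directory from it, and checks whether that directory is on LD_LIBRARY_PATH (or PATH on Windows).

```rust
use std::env;
use std::path::{Path, PathBuf};

/// Hypothetical helper: returns `<root>/lib` for a LibTorch root,
/// or an error if the root is unset or empty.
fn libtorch_lib_dir(root: Option<&str>) -> Result<PathBuf, String> {
    match root {
        Some(r) if !r.is_empty() => Ok(Path::new(r).join("lib")),
        _ => Err("LIBTORCH is not set".to_string()),
    }
}

/// True if `dir` is one of the entries in a PATH-style list
/// (':'-separated on Unix, ';'-separated on Windows).
fn on_search_path(search_path: &str, dir: &Path) -> bool {
    env::split_paths(search_path).any(|p| p.as_path() == dir)
}

fn main() {
    let root = env::var("LIBTORCH").ok();
    match libtorch_lib_dir(root.as_deref()) {
        Ok(lib) => {
            println!("{} is a directory: {}", lib.display(), lib.is_dir());
            // The loader search path variable differs per platform.
            let var = if cfg!(windows) { "PATH" } else { "LD_LIBRARY_PATH" };
            let search = env::var(var).unwrap_or_default();
            println!("on {}: {}", var, on_search_path(&search, &lib));
        }
        Err(e) => eprintln!("{}", e),
    }
}
```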
  3. Clone tch-rs

The tch crate on crates.io hasn’t been updated to include the latest changes: at the time of writing, it’s still on v0.2.1, which targets PyTorch v1.6. So we’ll clone the tch-rs repository to use the latest APIs.

Coming from Go, I organize my repos like so: (language)/src/github.com/(github_id)/(repo_name)

cd rust/src/github.com/${github_id}/
git clone https://github.com/LaurentMazare/tch-rs

Reference: The tch-rs crate README.

Model Inference in Rust ✨

Create a new Rust project

Now, let’s create a new project with Cargo:

cd rust/src/github.com/${github_id}/
cargo new --bin tch-inference-rs
cd tch-inference-rs

Update Cargo dependencies with tch-rs

Let’s edit the Cargo.toml file

[dependencies]
tch = {version="0.3.0", path="../tch-rs"}
anyhow = "1.0.33"

[dev-dependencies]
torch-sys = {version="0.3.0", path="../tch-rs/torch-sys"}

Notice that I’m using path in the dependency entries to point at the local clones of the packages. This is because I wanted the PyTorch v1.7 APIs, and v0.3.0 (master at the time of writing) is the compatible version.

Not-quite-pro-tip: Use cargo search <package_name> to search packages and their versions.

Let’s rumble!

Now for model inference, we need two things:

  • A trained model
  • An input image and the inference code

Tracing Pre-trained models into TorchScript

mkdir -p pretrained
touch pretrained/resnet.py

Then add the following to pretrained/resnet.py:

import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

Here we fetch the weights of a pretrained ResNet-18 model and switch the model to its evaluation mode. Then torch.jit.trace runs the example input through the model and records the operations it executes into a TorchScript module. Finally, we save the traced module to the filesystem.
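
If tracing is new to you, the toy below illustrates the idea in plain Rust. It is not how TorchScript is implemented; Op, Tracer, model, and replay are hypothetical names for illustration. The function runs once on an example input, every operation it performs is recorded, and the recording is then replayed on new inputs. Note the data-dependent if: a trace only captures the branch the example input took, which is the same caveat the torch.jit.trace documentation gives for control flow.

```rust
/// The operations our toy "model" can perform.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Op {
    Scale(f32),
    AddBias(f32),
    Relu,
}

/// Executes a single operation on a scalar input.
fn run_op(op: Op, x: f32) -> f32 {
    match op {
        Op::Scale(s) => x * s,
        Op::AddBias(b) => x + b,
        Op::Relu => x.max(0.0),
    }
}

/// Records each operation as it runs on the example input.
struct Tracer {
    ops: Vec<Op>,
}

impl Tracer {
    fn apply(&mut self, op: Op, x: f32) -> f32 {
        self.ops.push(op);
        run_op(op, x)
    }
}

/// A toy model with a data-dependent branch.
fn model(t: &mut Tracer, x: f32) -> f32 {
    let h = t.apply(Op::Scale(2.0), x);
    if h > 0.0 {
        t.apply(Op::AddBias(1.0), h)
    } else {
        t.apply(Op::Relu, h)
    }
}

/// Replays a recorded trace on a new input; no branching happens here.
fn replay(ops: &[Op], x: f32) -> f32 {
    ops.iter().fold(x, |acc, &op| run_op(op, acc))
}

fn main() {
    let mut tracer = Tracer { ops: Vec::new() };
    model(&mut tracer, 3.0); // the example input takes the `h > 0` branch
    println!("recorded: {:?}", tracer.ops);
    println!("replay(5.0) = {}", replay(&tracer.ops, 5.0)); // prints "replay(5.0) = 11"
    println!("replay(-5.0) = {}", replay(&tracer.ops, -5.0)); // prints "replay(-5.0) = -9"
}
```

Replaying on -5.0 gives -9, while running the model eagerly on -5.0 would return 0, because the branch was frozen at trace time. ResNet-18’s forward pass has no data-dependent control flow, so tracing it with a single example input is fine.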

Run the python script.

python pretrained/resnet.py

You should see the traced model saved as model.pt in the directory you ran the script from.

Running the traced TorchScript model from Rust

Edit src/main.rs:

use std::env;
use anyhow::{bail, Result};
use tch::vision::imagenet;
use tch::Kind::Float;

fn main() -> Result<()> {
    // Get CLI arguments at runtime
    let args: Vec<String> = env::args().collect();
    
    // Parse arguments
    // The first argument is the model path
    // The second argument is the input image to be classified 
    let (model_file, image_file) = match args.as_slice() {
        [_, m, i] => (m.to_owned(), i.to_owned()),
        _ => bail!("usage: main model.pt tiger.jpg"),
    };

    // Load the image and resize it to ResNet's standard 224 x 224 input
    let image = imagenet::load_image_and_resize(image_file, 224, 224)?;

    // Load the TorchScript traced model
    let model = tch::CModule::load(model_file)?;

    // Add a batch dimension ([3, 224, 224] -> [1, 3, 224, 224]),
    // run the network, and apply a softmax over the class dimension
    // to turn the raw scores into class probabilities
    let output = image
        .unsqueeze(0)
        .apply(&model)
        .softmax(-1, Float);

    // Iterate through the top-5 results,
    // print the class and its probability for each
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability);
    }
    Ok(())
}
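
The last two steps of main, softmax(-1, Float) and imagenet::top(&output, 5), are handled by tch. As a dependency-free sketch of what they compute, here is the same idea on a plain slice of scores; the logits and labels are made up, and softmax and top_k are hypothetical helpers, not tch’s implementation.

```rust
/// Numerically stable softmax: subtract the largest logit before
/// exponentiating so that `exp` cannot overflow on large scores.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Returns the `k` (class index, probability) pairs with the highest
/// probability, sorted in descending order. Panics if a value is NaN.
fn top_k(probs: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut indexed: Vec<(usize, f32)> = probs.iter().cloned().enumerate().collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    indexed.truncate(k);
    indexed
}

fn main() {
    // Made-up scores for a 4-class model.
    let labels = ["tiger", "tiger cat", "zebra", "jaguar"];
    let logits = [5.0_f32, 1.7, -1.0, -3.0];
    let probs = softmax(&logits);
    for (idx, p) in top_k(&probs, 2) {
        println!("{:50} {:5.2}%", labels[idx], 100.0 * p);
    }
}
```

Subtracting the maximum doesn’t change the probabilities, since it cancels out between numerator and denominator, but it keeps the intermediate values in a safe range.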

Run it

Download the image and run inference.

wget -O image.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/3/3b/Royal_Bengal_Tiger_at_Kanha_National_Park.jpg/800px-Royal_Bengal_Tiger_at_Kanha_National_Park.jpg
cargo run -- model.pt image.jpg

Classification results:

tiger, Panthera tigris                             96.33%
tiger cat                                           3.56%
zebra                                               0.09%
jaguar, panther, Panthera onca, Felis onca          0.01%
tabby, tabby cat                                    0.01%

Next post

Next up, we’ll look at serving a 🤗 transformers-based model like BERT behind a gRPC API with Tonic.