tch-rs

Rust bindings for the C++ API of PyTorch. The goal of the tch crate is to provide thin wrappers around the C++ PyTorch API (a.k.a. libtorch). It aims to stay as close as possible to the original C++ API; more idiomatic Rust bindings can then be developed on top of it. The documentation can be found on docs.rs.

The code generation for the C API wrapper on top of libtorch comes from ocaml-torch.

Getting Started

This crate requires the C++ PyTorch library (libtorch) version v1.10.0 to be available on your system. You can either:

  • Use the system-wide libtorch installation (default).
  • Install libtorch manually and let the build script know about it via the LIBTORCH environment variable.
  • When a system-wide libtorch can't be found and LIBTORCH is not set, the build script will download a pre-built binary version of libtorch. By default a CPU version is used; the TORCH_CUDA_VERSION environment variable can be set to cu111 in order to get a pre-built binary using CUDA 11.1.

System-wide Libtorch

The build script will look for a system-wide libtorch library in the following locations:

  • On Linux: /usr/lib/libtorch.so

Libtorch Manual Install

  • Get libtorch from the PyTorch website download section and extract the content of the zip file.

  • For Linux users, add the following to your .bashrc or equivalent, where /path/to/libtorch is the path to the directory that was created when unzipping the file:
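A typical setup looks like this (a sketch; adjust the path to your install):

```bash
# Point the build script at libtorch and let the dynamic linker find
# its shared libraries at run time.
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```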

  • For Windows users, assuming that X:\path\to\libtorch is the unzipped libtorch directory:

    • Navigate to Control Panel -> View advanced system settings -> Environment variables.
    • Create the LIBTORCH variable and set it to X:\path\to\libtorch.
    • Append X:\path\to\libtorch\lib to the Path variable.

    If you prefer to temporarily set environment variables, in PowerShell you can run:
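```powershell
# Only affects the current session (a sketch; adjust the path as needed):
$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
```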

  • You should now be able to run some examples, e.g. cargo run --example basics.

Windows Specific Notes

As per the PyTorch docs, the Windows debug and release builds are not ABI-compatible. This could lead to segfaults if the incorrect version of libtorch is used.

Examples

Basic Tensor Operations

This crate provides a Tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations:
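A minimal sketch along the lines of the basics example shipped with the crate:

```rust
use tch::Tensor;

fn main() {
    // Build a tensor from a slice, scale it, and print the result.
    let t = Tensor::of_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}
```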

Training a Model via Gradient Descent

PyTorch provides automatic differentiation for most tensor operations it supports. This is commonly used to train models using gradient descent. The optimization is performed over variables which are created via a nn::VarStore by defining their shapes and initializations.

In the example below, my_module uses two variables x1 and x2 whose initial values are 0. The forward pass applied to tensor xs returns xs * x1 + exp(xs) * x2.

Once the model has been generated, a nn::Sgd optimizer is created. Then on each step of the training loop:

  • The forward pass is applied to a mini-batch of data.
  • A loss is computed as the mean square error between the model output and the mini-batch ground truth.
  • Finally an optimization step is performed: gradients are computed and variables from the VarStore are modified accordingly.
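A sketch of this setup, close to the gradient descent example shipped with the crate (the squared error is written out explicitly here):

```rust
use tch::{kind, nn, nn::Module, nn::OptimizerConfig, Device, Tensor};

fn my_module(p: nn::Path, dim: i64) -> impl nn::Module {
    // Two trainable variables, both initialized to zero.
    let x1 = p.zeros("x1", &[dim]);
    let x2 = p.zeros("x2", &[dim]);
    nn::func(move |xs| xs * &x1 + xs.exp() * &x2)
}

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let my_module = my_module(vs.root(), 7);
    let mut opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
    for _idx in 1..50 {
        // Dummy mini-batches made of zeros stand in for real data here.
        let xs = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let ys = Tensor::zeros(&[7], kind::FLOAT_CPU);
        // Mean squared error between the model output and the targets.
        let diff = my_module.forward(&xs) - ys;
        let loss = (&diff * &diff).mean(kind::Kind::Float);
        // Compute gradients and update the VarStore variables.
        opt.backward_step(&loss);
    }
}
```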

Writing a Simple Neural Network

The nn API can be used to create neural network architectures, e.g. the following code defines a simple model with one hidden layer and trains it on the MNIST dataset using the Adam optimizer.
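A condensed sketch in the spirit of the repository's mnist example (it assumes the MNIST data files are in a data/ directory):

```rust
use anyhow::Result;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};

const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

fn net(vs: &nn::Path) -> impl Module {
    // One hidden layer with a ReLU non-linearity.
    nn::seq()
        .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

fn main() -> Result<()> {
    // Load the MNIST dataset and train with Adam.
    let m = tch::vision::mnist::load_dir("data")?;
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
    for epoch in 1..200 {
        let loss = net
            .forward(&m.train_images)
            .cross_entropy_for_logits(&m.train_labels);
        opt.backward_step(&loss);
        let test_accuracy = net
            .forward(&m.test_images)
            .accuracy_for_logits(&m.test_labels);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            100. * f64::from(&test_accuracy),
        );
    }
    Ok(())
}
```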

More details on the training loop can be found in the detailed tutorial.

Using some Pre-Trained Model

The pretrained-models example illustrates how to use some pre-trained computer vision model on an image. The weights, which have been extracted from the PyTorch implementation, can be downloaded here: resnet18.ot and resnet34.ot.

The example can then be run via the following command:
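For instance (a sketch; it assumes the weight file and a tiger.jpg input image in the current directory):

```bash
cargo run --example pretrained-models -- resnet18.ot tiger.jpg
```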

This should print the top 5 ImageNet categories for the image. The code for this example is pretty simple.

Further examples include:

  • A simplified version of char-rnn illustrating character-level language modeling using Recurrent Neural Networks.
  • Neural style transfer uses a pre-trained VGG-16 model to compose an image in the style of another image (pre-trained weights: vgg16.ot).
  • Some ResNet examples on CIFAR-10.
  • A tutorial showing how to deploy/run some Python trained models using TorchScript JIT.
  • Some Reinforcement Learning examples using the OpenAI Gym environment. This includes a policy gradient example as well as an A2C implementation that can run on Atari games.
  • A Transfer Learning Tutorial shows how to fine-tune a pre-trained ResNet model on a very small dataset.

External material:

  • A tutorial showing how to use Torch to compute option prices and Greeks.

License

tch-rs is distributed under the terms of both the MIT license and the Apache license (version 2.0), at your option.

See LICENSE-APACHE, LICENSE-MIT for more details.

Issues

A collection of the latest issues:

Calandiel (0 comments)

Hello. I'm trying to compile a project that uses your library as a dependency. I'm currently on Linux and, as per the instructions in the readme about manual installation, I added environment variables for LIBTORCH and LD_LIBRARY_PATH. I used config.toml for that:

As far as I can tell, it does pick up that LIBTORCH is defined (the behavior is the same whether I use export in bash or define environment variables with cargo's config.toml), but in either case it doesn't seem to realize that the include folder exists. The error I encountered is: torch/csrc/autograd/engine.h: No such file or directory, but that file does exist inside models/libtorch/include/torch/csrc/autograd/. The models folder is inside the crate's root folder, next to src, and I call cargo from that root folder with cargo run/build. I did try different combinations of ./models/libtorch, /models/libtorch/, models/libtorch/include, both through [env] and export, but to no avail.

Am I doing something wrong or is this an issue with the library? If I am, could you guide me in the right direction?

With regards, Cal

jafioti (0 comments)

I do a lot of NLP, and after tokenizing and converting to indexes, my tokens are of type usize. Is it possible to impl Element for usize so that every element of a potentially large array doesn't need to be converted? I've seen the supported types, so I was wondering if usize could be treated like a u32.
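For context, a sketch of the per-element conversion this request would avoid (the token values here are hypothetical):

```rust
use tch::Tensor;

fn main() {
    // usize token ids must currently be converted to a supported element
    // type (e.g. i64) before building a tensor.
    let tokens: Vec<usize> = vec![101, 2009, 102];
    let as_i64: Vec<i64> = tokens.iter().map(|&t| t as i64).collect();
    let tensor = Tensor::of_slice(&as_i64);
    tensor.print();
}
```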

AlexanderEkdahl (0 comments)

Hi,

First off, thank you for this amazing library!

I noticed that Tensor has been annotated with #[must_use], but this isn't necessarily correct for in-place operations such as cumsum_.

Could the generating method be updated to only annotate non-in-place operations with #[must_use]? I'd be happy to implement this if you think it is a good idea.

zzeee (0 comments)

1. I have PyTorch 1.10 working natively on my M1 Mac and I've downloaded libtorch from the website. 2. I've added tch = "0.6.1" to Cargo.toml, and when I run cargo build I see:

```
note: ld: warning: ignoring file /Users/andrey/prj/nntest/target/debug/build/torch-sys-3c7217b1a617897c/out/libtorch/libtorch/lib/libtorch.dylib, building for macOS-arm64 but attempting to link with file built for macOS-x86_64
ld: warning: ignoring file /Users/andrey/prj/nntest/target/debug/build/torch-sys-3c7217b1a617897c/out/libtorch/libtorch/lib/libtorch_cpu.dylib, building for macOS-arm64 but attempting to link with file built for macOS-x86_64
ld: warning: ignoring file /Users/andrey/prj/nntest/target/debug/build/torch-sys-3c7217b1a617897c/out/libtorch/libtorch/lib/libc10.dylib, building for macOS-arm64 but attempting to link with file built for macOS-x86_64
Undefined symbols for architecture arm64:
  "c10::impl::ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard()", referenced from:
      at::AutoDispatchBelowADInplaceOrView::~AutoDispatchBelowADInplaceOrView() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet)", referenced from:
      at::AutoDispatchBelowADInplaceOrView::AutoDispatchBelowADInplaceOrView() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "at::_ops::zeros::call(c10::ArrayRef, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional)", referenced from:
      at::zeros(c10::ArrayRef, c10::TensorOptions) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "vtable for torch::autograd::AutogradMeta", referenced from:
      torch::autograd::AutogradMeta::AutogradMeta(c10::TensorImpl*, bool, torch::autograd::Edge) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
   NOTE: a missing vtable usually means the first non-inline virtual member function has no definition.
  "c10::AutogradMetaInterface::~AutogradMetaInterface()", referenced from:
      torch::autograd::AutogradMeta::AutogradMeta(c10::TensorImpl*, bool, torch::autograd::Edge) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "caffe2::TypeMeta::typeMetaDatas()", referenced from:
      caffe2::TypeMeta::data() const in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "c10::TensorImpl::set_autograd_meta(std::__1::unique_ptr<c10::AutogradMetaInterface, std::__1::default_delete<c10::AutogradMetaInterface> >)", referenced from:
      torch::autograd::make_variable(at::Tensor, bool, bool) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "c10::UndefinedTensorImpl::_singleton", referenced from:
      c10::UndefinedTensorImpl::singleton() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator > const&)", referenced from:
      c10::Device::validate() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
      caffe2::TypeMeta::fromScalarType(c10::ScalarType) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "at::_ops::mul_Scalar::call(at::Tensor const&, c10::Scalar const&)", referenced from:
      at::mul(at::Tensor const&, c10::Scalar const&) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*)", referenced from:
      c10::intrusive_ptr<c10::VariableVersion::VersionCounter, c10::detail::intrusive_target_default_null_type<c10::VariableVersion::VersionCounter> >::intrusive_ptr(c10::VariableVersion::VersionCounter*) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
      c10::intrusive_ptr_target::~intrusive_ptr_target() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::retain() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
      c10::ArrayRef::debugCheckNullptrInvariant() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "vtable for c10::AutogradMetaInterface", referenced from:
      c10::AutogradMetaInterface::AutogradMetaInterface() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
   NOTE: a missing vtable usually means the first non-inline virtual member function has no definition.
  "caffe2::TypeMeta::error_unsupported_typemeta(caffe2::TypeMeta)", referenced from:
      caffe2::TypeMeta::toScalarType() in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*)", referenced from:
      torch::autograd::AutogradMeta::AutogradMeta(c10::TensorImpl*, bool, torch::autograd::Edge) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
      c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, c10::detail::CompileTimeEmptyString) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
      c10::TensorImpl::itemsize() const in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
      c10::TensorImpl::data() const in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
  "at::print(std::__1::basic_ostream<char, std::__1::char_traits >&, at::Tensor const&, long long)", referenced from:
      at::operator<<(std::__1::basic_ostream<char, std::__1::char_traits >&, at::Tensor const&) in libtorch_sys-2bfcaab8f2fcc6a6.rlib(torch_api.o)
ld: symbol(s) not found for architecture arm64
clang: error: linker command failed with exit code 1 (use -v to see invocation)
```

rhobro (7 comments)

Hi there,

In the documentation, it says that libtorch must be on the system for the crate to work.

I am creating a program. Does this mean that libtorch must be on every system that I run the program on, or will it be packaged with the binary?

Thanks.

tsubakisakura (4 comments)

Hello!

I have tried restoring a VarStore from in-memory data (a buffer, a reader, etc.), but there was no direct loading API. So I thought it would be convenient if the following APIs were available on VarStore:

```rust
pub trait ReadStream: Read + Seek {}

impl VarStore {
    // ...
    pub fn load_from_stream<S: ReadStream>(&mut self, stream: S) -> Result<(), TchError> {
        // ...
    }
}
```

libtorch has a function that takes an input stream. If I implemented this, it could be connected to the Rust side; I tried it and it worked fine. How about an API like this? If this API design is acceptable, I will make a pull request.

Thanks.

any35 (2 comments)

Hi, all.

I'm trying to rewrite a Python library in Rust, but I'm having trouble implementing a custom backward method.

Here is the sample code:

Is there an equivalent implementation in Rust?

Thank you.

ycat3 (2 comments)

I am using a Rockpro64 SBC with Armbian Focal and I am developing a Rust application. A native Arm64 build of libtorch 1.10.1 from the PyTorch source tree succeeds. I followed the tch-rs instructions and some tch-rs examples work fine (cargo run --example basics, cargo run --example mnist), so I believe my libtorch environment is correct. However, other example builds fail (cargo test, cargo run --example pretrained-models).

Rust 1.57.0 says help: some extern functions couldn't be found; some native libraries may need to be installed or have their path specified

What kind of native libraries should I install?


```
$ cargo run --example pretrained-models
   Compiling tch v0.6.1 (/home/rock64/tch-rs)
error: linking with `cc` failed: exit status: 1
  = note: /usr/bin/ld: /home/rock64/tch-rs/target/debug/deps/libtorch_sys-0f95d54be7ee1814.rlib(torch_api.o): in function `torch::jit::Object::type() const':
          /home/rock64/libtorch/include/torch/csrc/jit/api/object.h:37: undefined reference to `torch::jit::Object::_ivalue() const'
          /usr/bin/ld: /home/rock64/tch-rs/target/debug/deps/libtorch_sys-0f95d54be7ee1814.rlib(torch_api.o): in function `torch::jit::slot_iterator_impl<torch::jit::detail::NamedPolicy<torch::jit::detail::ParameterPolicy> >::cur() const':
          /home/rock64/libtorch/include/torch/csrc/jit/api/module.h:370: undefined reference to `torch::jit::Object::_ivalue() const'
          /usr/bin/ld: /home/rock64/libtorch/include/torch/csrc/jit/api/module.h:370: undefined reference to `torch::jit::Object::_ivalue() const'
          /usr/bin/ld: /home/rock64/tch-rs/target/debug/deps/libtorch_sys-0f95d54be7ee1814.rlib(torch_api.o): in function `torch::jit::slot_iterator_impl<torch::jit::detail::NamedPolicy<torch::jit::detail::ParameterPolicy> >::next()':
          /home/rock64/libtorch/include/torch/csrc/jit/api/module.h:386: undefined reference to `torch::jit::Object::_ivalue() const'
          /usr/bin/ld: /home/rock64/libtorch/include/torch/csrc/jit/api/module.h:396: undefined reference to `torch::jit::Object::_ivalue() const'
          /usr/bin/ld: /home/rock64/tch-rs/target/debug/deps/libtorch_sys-0f95d54be7ee1814.rlib(torch_api.o):/home/rock64/libtorch/include/torch/csrc/jit/api/module.h:407: more undefined references to `torch::jit::Object::_ivalue() const' follow
          collect2: error: ld returned 1 exit status

  = help: some extern functions couldn't be found; some native libraries may need to be installed or have their path specified
  = note: use the -l flag to specify native libraries to link
  = note: use the cargo:rustc-link-lib directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)

error: could not compile `tch` due to previous error
```

xuexl (3 comments)

I got a Mask R-CNN model from detectron2, and it panics when I load the model with tch::CModule. The panic information is:

```
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Torch("
Unknown builtin op: torchvision::nms.
Could not find any similar ops to torchvision::nms. This op may not exist or may not be currently supported in TorchScript.
:
  File "/home/lin/anaconda3/envs/tailings_pond/lib/python3.9/site-packages/torchvision/ops/boxes.py", line 35
    """
    _assert_has_ops()
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
           ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
Serialized File "code/torch/torchvision/ops/boxes.py", line 24
    _6 = torch.torchvision.extension._assert_has_ops
    _7 = _6()
    _8 = ops.torchvision.nms(boxes, scores, iou_threshold)
         ~~~~~~~~~~~~~~~~~~~ <--- HERE
    return _8
'nms' is being compiled since it was called from '_batched_nms_coordinate_trick'
  File "/home/lin/anaconda3/envs/tailings_pond/lib/python3.9/site-packages/torchvision/ops/boxes.py", line 87
    offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
           ~~~ <--- HERE
    return keep
Serialized File "code/torch/torchvision/ops/boxes.py", line 16
    _5 = torch.unsqueeze(torch.slice(offsets), 1)
    boxes_for_nms = torch.add(boxes, _5)
    keep = torch.torchvision.ops.boxes.nms(boxes_for_nms, scores, iou_threshold, )
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _0 = keep
    return _0
'_batched_nms_coordinate_trick' is being compiled since it was called from 'RPN.forward'
Serialized File "code/torch/detectron2/modeling/proposal_generator/rpn.py", line 19
    argument_9: Tensor,
    image_size: Tensor) -> Tensor:
    _0 = torch.torchvision.ops.boxes._batched_nms_coordinate_trick
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <---
```


I have loaded this model in a C++ program, where I needed to link against torchvision. I know that, but how do I use torchvision in tch?

Can anyone help me? Thanks a lot. (waiting online)

robz (1 comment)

I'm considering using tch-rs to preprocess a bunch of data to use in a PyTorch pipeline. Thanks for implementing these bindings!

When saving a tch::Tensor like this:
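A minimal sketch of such a save (assuming the Tensor::save API and a made-up path; the original snippet is not shown):

```rust
use tch::Tensor;

fn main() -> Result<(), tch::TchError> {
    // Write a single tensor to disk; the question below is how to read
    // this back from Python.
    let t = Tensor::of_slice(&[1.0f64, 2.0, 3.0]);
    t.save("tensor.pt")
}
```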

Loading it with torch.load in Python results in an error:

Loading it with torch.jit.load appears to work like this:

Is this the proper way to save and load data between tch-rs and pytorch? Or is there a way to save tch::Tensor so that I can load it with torch.load?

skyne98 (0 comments)

Hi there!

First, awesome project! To give some context, I have never used PyTorch before and I have a question related to Vulkan and AMD GPUs. It seems like PyTorch supports running using Vulkan compute.

Does it mean it should be possible to enable it in the bindings as well to run models on, let's say, AMD GPUs and maybe even Android?

If it's easy enough and doesn't require any specialized knowledge, I can even try to help!

Thanks!

cauliflowew (2 comments)

I'm trying to use the tch crate for a Rust project on an M1 Mac. I've successfully compiled libtorch following this tutorial. However, when I try to run a simple example of a tensor operation, I get the error pasted below. Since the libtorch compilation was successful, I suspect it's simply some misconfiguration on my side, but I'm not sure what exactly is wrong. In the error I see that it tries to link with the x86 version of libtorch even though I've compiled an ARM version.

Error:

dbsxdbsx (2 comments)

While working on a reinforcement learning project, I have a struct like this:

I need to save the whole instance of the struct to disk so that I can load it later the next time I use it. I know Tensor has save and load methods for this, but they only work on a single tensor object, so it is easy to see that this approach is not appropriate here.

Then I came up with concatenating the Vec<Tensor> into a single tensor before calling save. But I think this may be inefficient when the number of tensors is large, say 1,000,000.

Is there some other approach, e.g. serialization (with something like the serde crate)?
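One possible direction, sketched under the assumption that the Tensor::save_multi / Tensor::load_multi APIs (which store several named tensors in one file) fit this use case; the generated names here are made up:

```rust
use tch::Tensor;

// Save a vector of tensors under generated names in a single file.
fn save_tensors(tensors: &[Tensor], path: &str) -> Result<(), tch::TchError> {
    let names: Vec<String> = (0..tensors.len()).map(|i| format!("t{}", i)).collect();
    let named: Vec<(&str, &Tensor)> = names
        .iter()
        .map(String::as_str)
        .zip(tensors.iter())
        .collect();
    Tensor::save_multi(&named, path)
}
```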

drewm1980 (1 comment)

I would like to be able to debug a crate using tch-rs in PyCharm or CLion, but I'm hitting:

"error while loading shared libraries: libtorch_cpu.so"

More details at: https://youtrack.jetbrains.com/issue/CPP-26265

I'm not sure if this is purely an IDE issue, or if tch-rs is doing something ~really non-standard in its build, but as the crate author you might have some insight.

The built binary ~does run fine with "cargo run".

Thanks!

JonathanWoollett-Light (4 comments)

In my view it seems unnecessary to duplicate every function which may return an error, e.g. Tensor::reshape and Tensor::f_reshape.
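For reference, the duplicated pair looks roughly like this (a sketch; signatures paraphrased rather than copied from the generated code):

```rust
use tch::{TchError, Tensor};

fn demo() -> Result<(), TchError> {
    let t = Tensor::of_slice(&[1i64, 2, 3, 4]);
    let _a = t.reshape(&[2, 2]); // panicking variant
    let _b = t.f_reshape(&[2, 2])?; // fallible variant returning Result<Tensor, TchError>
    Ok(())
}
```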

This pattern is not seen in any other popular libraries (the TensorFlow bindings and ndarray being two examples), and I believe for good reason. It adds gigantic bloat to the documentation and the codebase, all to simply abstract an unwrap away (the inverse of propagating an error), so it both adds bloat and makes errors slightly less clear.

What is the reason behind this design choice?

rustrust (5 comments)

I need to check whether a matrix I am working with is singular, so I am placing a call to linalg_matrix_rank. Everything is fine when the matrix is full rank, but tch-rs panics if the matrix is at all singular.

Since the purpose of this function is to report on the rank of the matrix, which is not necessarily full rank, some other behavior is preferred. Most preferably, this function should never panic or error if the matrix is singular, because its purpose is to report on the numeric rank of the matrix.

The result of running this looks like:

NOBLES5E (2 comments)

As datasets and models grow larger, single-GPU training can become a limiting factor in many moderately sized tasks. I am thinking of adding a distributed training example for tch. To achieve this, there are two things to be done:

  1. Distributed communication engine supporting Rust: I can do it with our recently open sourced bagua, which has a Rust backend bagua-core.
  2. Tensor hooks so that we can schedule communication when for example a gradient is ready: we need to wrap the VariableHooksInterface.h in torch-sys, as mentioned in https://github.com/LaurentMazare/tch-rs/issues/218. This seems to be not difficult.

@LaurentMazare I would appreciate it if you have time to comment on this and see if the direction is right. Thanks!

edlanglois (2 comments)

I am interested in the reasons for / against implementing Clone for tensor. The only existing discussion I'm aware of is part of #281.

To start, it would be nice to have Clone implemented for Tensor. I'd like to be able to easily clone structures that store tensors.

There is a C++ clone function that could be used. I see that this function is in Declarations-*.yaml but it does not appear in the generated api (at least for tch-rs, it's in the api for ocaml-torch). Even if the decision is not to implement Clone for Tensor it would be nice to expose this function as g_clone or something. [Edit: Oops I forgot about Tensor::copy. Why is that done in rust with copy_ instead of with the C++ clone()?]

There is possibly an issue with whether Tensor::clone should have a gradient back to the source. From the torch perspective you would expect a clone() method to behave the same as it does in C++ and Python, which is to be differentiable with respect to the source.

From the rust side that might be unexpected: if I clone a parameter tensor then I don't want the new parameter tensor to have gradients back to the original. I'm not sure that detaching is the solution either. If you have something like the following (a reconstructed sketch; the original snippet is inferred from the derivatives discussed below)
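```rust
// Reconstructed sketch: clone() here is the hypothetical differentiable
// Tensor::clone under discussion, not an existing tch method.
let x = Tensor::of_slice(&[1.0f32]).set_requires_grad(true);
let y = &x * 2; // dy/dx = 2
let z = y.clone();
```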

then should z be differentiable with respect to x as though it were y, or simply be detached? From a rust perspective I'd kind of expect z to be exactly like y, which means being differentiable with respect to x.

But the more I think about it though the more I think it's fine for clone() to be differentiable. In the above example, dy/dx = 2 and dy/dy = 1 and z would behave the same when substituting it for y: dz/dx = 2, dz/dy = 1, dz/dz = 1. As for differentiably cloning parameter tensors I think it's unlikely that you would differentiate a function of the clone with respect to the original and if you do, the fact that C++/Pytorch clone is differentiable ought to be enough of a warning to consider that the tch version might be too.

Another risk with implementing Clone is that the new variable isn't tracked by a VarStore but I don't think that should be problem either. If you are calling clone() manually then it should be expected that the clone won't be in a VarStore. The risk would be deriving Clone for a struct that stores Tensors and has those Tensors in a VarStore except you can't do that because VarStore doesn't implement Clone.

Any other reasons for / against? At this point I am in favour of implementing Clone for Tensor using the differentiable C++ clone() function.

edlanglois (2 comments)

I'd like to copy a module to another device. In C++ this is the clone method with an optional device.

While not something I currently need, I think it would also make sense to implement Clone for Module where it does a deep copy to the same device (and the same for implementing Clone for Tensor, but maybe that's a separate discussion).

Anyways, for C++ it looks like they have a separate Cloneable template class so a separate trait for clone_to_device would be an option rather than implementing directly on Module.

I'd be happy to try making a PR for this.

joverwey (0 comments)

Is there a reason why LBFGS is not supported as an optimizer? Would it be hard to add?

For Style Transfer, I've noticed that it converges much faster than AdamW when used from Python.

Also see this comparison with other optimizers: #404

Information - Updated Jan 24, 2022

Stars: 1.4K
Forks: 124
Issues: 72
