C++调用pytorch模型(vs2015+libtorch+pytorch)

it2022-05-08  7

C++调用pytorch模型(vs2015+libtorch+pytorch)

参考网站 另一个网站:知乎平台

1.开发环境

vs2015+windows10

2.转换模型

pytorch的C++版本用的是Torch Script,官方给了两种将pytorch模型转成Torch Script的方法。这里主要介绍第一种Tracing。

3.准备工作

(1)在Pytorch官网下载所需要的Libtorch版本,有GPU版本和CPU版本,而且有DEBUG和RELEASE区别。 Eg:我所需要的版本是这样的,第一项选择Stable(1.1)就好,第五项(9.0,10.0与None)说的是自己的CUDA版本,请参考我的第一篇博客。 (2)解压到一个新的文件夹。 (3)接下来,和之前配置OpenCV是类似的。新建一个VS项目工程(新建-win32控制台程序),然后添加一个源文件进行使用;视图-其他窗口-属性管理器(这里以Release版本为例) (4)右击Release|x64下的Microsoft.Cpp.x64.user,选择属性,VC++目录,将相应的include和lib加到包含目录和库目录。 (5)在连接器中,加入如下: c10.lib caffe2.lib caffe2_detectron_ops.lib caffe2_module_test_dynamic.lib clog.lib cpuinfo.lib foxi_dummy.lib foxi_loader.lib libprotobuf.lib libprotobuf-lite.lib libprotoc.lib onnx.lib onnx_proto.lib onnxifi_dummy.lib onnxifi_loader.lib torch.lib (6)还有一个地方需要修改:属性->C/C++ ->常规->SDL检查->否。 (7)将最初Libtorch解压缩到的文件中的lib文件夹,将其中所有的.dll文件复制。 (8)在第三步中新建的ConsoleApplication4,在这里选择x64-Release文件夹,将(1)中的复制文件粘贴至此。

4.保存网络模型,生成.pt文件

(说明:这里训练的网络是resnet18,CIFAR10数据集)

import torch import torchvision # An instance of your model. model = torchvision.models.resnet18()#这里保存的模型是自己训练好的网络模型 # An example input you would normally provide to your model's forward() method. example = torch.rand(1, 3, 224, 224) # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. traced_script_module = torch.jit.trace(model, example) traced_script_module.save("model.pt")

5.在VS2015中加载使用网络模型

#include <torch/script.h> // One-stop header. #include <iostream> #include <memory> int main() { // Deserialize the ScriptModule from a file using torch::jit::load(). std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("F:/project/尝试1/model.pt"); assert(module != nullptr); std::cout << "ok\n"; std::cout << module <<'\n' ; // Create a vector of inputs. std::vector<torch::jit::IValue> inputs; inputs.push_back(torch::ones({ 1, 3, 224, 224 })); // Execute the model and turn its output into a tensor. at::Tensor output = module->forward(inputs).toTensor(); std::cout << output.slice(1, 0, 5) << '\n'; while (1); }

(2)结果如下


最新回复(0)