pytorch 官方介面 自定義運算元註冊

2021-10-09 06:21:56 字數 1681 閱讀 8302

pytorch官方提供了註冊自己的自定義運算元的介面,不需要像native_functions.yaml那樣每次改原始碼,再編譯

在pytorch倉庫下找乙個地方,建立新的資料夾

我建議在/pytorch/aten/src/aten/core/op_registration下建立

建立的資料夾假設叫做myrelu

在myrelu下繼續建立myrelu.cpp cmakelists.txt 和 build資料夾

myrelu.cpp:

#include

torch::tensor myrelu

(torch::tensor self)

torch_library

(myop, m)

cmakelists.txt:

cmake_minimum_required

(version 3.1 fatal_error)

project

(myrelu)

find_package

(torch required)

# define our library target

add_library

(myrelu shared myrelu.cpp)

# enable c++

14target_compile_features

(myrelu private cxx_std_14)

# link against libtorch

target_link_libraries

(myrelu "$"

)

在命令列開啟

cd build

cmake -dcmake_build_type=debug -dcmake_prefix_path="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..

make -j

編譯好之後,會在build目錄下出現libmyrelu.so檔案

用以下**測試

import torch

torch.ops.load_library(

"/users/admin//pytorch/aten/src/aten/core/op_registration/myrelu/build/libmyrelu.so"

)print

(torch.ops.myop.myrelu(torch.randn(5,

3,requires_grad=

true))

)

輸出

tensor(

[[0.1644, 1.2747,

-0.0000],[

-0.0000, 0.8493,

-0.0000]

,[1.1179,

-0.0000, 2.0053]

,[0.1727,

-0.0000,

-0.0000]

,[0.5056, 0.3437,

-0.0000]

], grad_fn=)

註冊成功,並且根據pytorch c++自己提供的autograd機制,已經註冊上了反向傳播函式

Pytorch自定義引數

如果想要靈活地使用模型,可能需要自定義引數,比如 class net nn.module def init self super net,self init self.a torch.randn 2 3 requires grad true self.b nn.linear 2,2 defforwa...

PyTorch 自定義層

與使用module類構造模型類似。下面的centeredlayer類通過繼承module類自定義了乙個將輸入減掉均值後輸出的層,並將層的計算定義在了forward函式裡。這個層裡不含模型引數。class mydense nn.module def init self super mydense,se...

自定義介面

好久沒寫介面了 好像以前也沒怎麼寫過.已經忘記怎麼寫了 就自己動手寫乙個熟悉一下 demo很簡單 就三個類 名字就隨便起了 public inte ce mylistener2 public class a catch interruptedexception e listener.setliste...