这篇文章主要介绍了如何使用TensorRT实现自定义算子。 Note: 其实自定义算子写多了发现其实还挺好写的,格式都差不多,主要区别是enqueue的前向计算逻辑可能写起来复杂些。 以Upsample为例,TensorRT不支持Caffe的Upsample层,所以这里实现了一个自定义层类型,即plugin。需要实现: 需要实现的函数详见如下代码段。 Upsample类的实现: 下面是对应的Creator类的实现: 下面是对应的plugin factory类的实现: 如有问题可加公众号交流:AI算法爱好者
整个实现过程基本上是:
class Upsample : public nvinfer1::IPluginV2IOExt { public: // 直接解析网络时候需要用到 Upsample(); // 反序列化时候需要用到 Upsample(const void *data, size_t length); ~Upsample(); // 直接return输出节点数, int getNbOutputs() override; // return输出的维度信息,如:return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]); Dims getOutputDimensions(int index, const Dims *inputs, int num_input_dims) override; // pos索引到的input/output的数据格式(format)和数据类型(datatype)如果都支持则返回true bool supportsFormatCombination(int pos, const PluginTensorDesc* in_out, int num_inputs, int num_outputs) const override; // 这个函数可以获取到数据类型和输入的维度信息,如果有需要用到的可以在这里将相关信息取出来 configurePlugin(const PluginTensorDesc* in, int num_inputs, const PluginTensorDesc* out, int num_outputs) override; // 在这里返回正确的序列化数据的长度,如我要序列化数据类型和数据维度:return sizeof(data_type) + sizeof(chw); size_t getSerializationSize() const override; // 序列化函数,在这里把反序列化时需要用到的参数或数据序列化 void serialize(void *buffer) const override; // 设置工作空间,不需要直接 return 0; size_t getWorkspaceSize(int max_batch_size) const override; // 前向计算的核心函数,计算逻辑在这里实现,可以使用cublas实现或者自己写cuda核函数实现 int enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override; // 调用enqueue的时候需要用到的资源先在这里Initialize,这个函数是在engine创建之后enqueue调用之前调用的,不需要Initialize则直接 return 0; int initialize() override; // 释放Initialize申请的资源,在enqueue调用之后且engine销毁之后调用 void terminate() override; // 返回输出的数据类型,如何输入相同,可以直接 return input_types[0]; nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* input_types, int num_inputs) const override; // 返回自定义类型,如这里是:return Upsample const char* getPluginType() const override; // 返回plugin version,没啥说的 const char* getPluginVersion() const override; // 销毁对象 void destroy() override { delete this; } // 在这里new一个该自定义类型并返回 nvinfer1::IPluginV2Ext* clone() const override; // 设置命名空间,用来在网络中查找和创建plugin void setPluginNamespace(const char* lib_namespace) override; // 返回plugin对象的命名空间 const char* getPluginNamespace() const override; bool isOutputBroadcastAcrossBatch(int output_index, const bool* input_is_broadcasted, int num_inputs) const override; bool canBroadcastInputAcrossBatch(int input_index) const override; }
class UpsampleCreator : public nvinfer1::IPluginCreator { public: const char* getPluginName() const override; const char* getPluginVersion() const override; const PluginFieldCollection* getFieldNames() override; // 创建自定义层pluin的对象并返回 nvinfer1::IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override; // 创建自定义层pluin的对象并返回,反序列化用到 nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serial_data, size_t serial_length) override; void setPluginNamespace(const char* lib_namespace) override; const char* getPluginNamespace() const override; }
class CaffePluginFactory : public nvcaffeparser1::IPluginFactoryV2 { public: // 在这里判断一个层是否为自定义层类型 bool isPluginV2(const char* name) override; // 在这里创建自定义层类型的对象并返回 nvinfer1::IPluginV2* createPlugin(const char* layer_name, const nvinfer1::Weights* weights, int num_weights, const char* libNamespace="") override; }
本网页所有视频内容由 imoviebox边看边下-网页视频下载, iurlBox网页地址收藏管理器 下载并得到。
ImovieBox网页视频下载器 下载地址: ImovieBox网页视频下载器-最新版本下载
本文章由: imapbox邮箱云存储,邮箱网盘,ImageBox 图片批量下载器,网页图片批量下载专家,网页图片批量下载器,获取到文章图片,imoviebox网页视频批量下载器,下载视频内容,为您提供.
阅读和此文章类似的: 全球云计算