.. index:: pair: example; cnn_inference_int8.cpp .. _doxid-cnn_inference_int8_8cpp-example: cnn_inference_int8.cpp ====================== This C++ API example demonstrates how to run AlexNet's conv3 and relu3 with int8 data type. Annotated version: :ref:`CNN int8 inference example <cnn_inference_int8_cpp>` This C++ API example demonstrates how to run AlexNet's conv3 and relu3 with int8 data type. Annotated version: :ref:`CNN int8 inference example <cnn_inference_int8_cpp>` .. ref-code-block:: cpp /******************************************************************************* * Copyright 2018-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*******************************************************************************/ #include <vector> #include "oneapi/dnnl/dnnl.hpp" #include "example_utils.hpp" using namespace :ref:`dnnl `; void simple_net_int8(:ref:`engine::kind ` engine_kind) { using :ref:`tag ` = :ref:`memory::format_tag `; using :ref:`dt ` = :ref:`memory::data_type `; auto eng = :ref:`engine `(engine_kind, 0); :ref:`stream ` s(eng); const int batch = 8; //[Configure tensor shapes] // AlexNet: conv3 // {batch, 256, 13, 13} (x) {384, 256, 3, 3}; -> {batch, 384, 13, 13} // strides: {1, 1} :ref:`memory::dims ` conv_src_tz = {batch, 256, 13, 13}; :ref:`memory::dims ` conv_weights_tz = {384, 256, 3, 3}; :ref:`memory::dims ` conv_bias_tz = {384}; :ref:`memory::dims ` conv_dst_tz = {batch, 384, 13, 13}; :ref:`memory::dims ` conv_strides = {1, 1}; :ref:`memory::dims ` conv_padding = {1, 1}; //[Configure tensor shapes] //[Choose scaling factors] // Choose scaling factors for input, weight and output std::vector<float> src_scales = {1.8f}; std::vector<float> weight_scales = {2.0f}; std::vector<float> dst_scales = {0.55f}; //[Choose scaling factors] //[Set scaling mask] const int src_mask = 0; const int weight_mask = 0; const int dst_mask = 0; //[Set scaling mask] // Allocate input and output buffers for user data std::vector<float> user_src(batch * 256 * 13 * 13); std::vector<float> user_dst(batch * 384 * 13 * 13); // Allocate and fill buffers for weights and bias std::vector<float> conv_weights(product(conv_weights_tz)); std::vector<float> conv_bias(product(conv_bias_tz)); //[Allocate buffers] auto user_src_memory = :ref:`memory `({{conv_src_tz}, :ref:`dt::f32 `, tag::nchw}, eng); write_to_dnnl_memory(user_src.data(), user_src_memory); auto user_weights_memory = :ref:`memory `({{conv_weights_tz}, :ref:`dt::f32 `, tag::oihw}, eng); write_to_dnnl_memory(conv_weights.data(), user_weights_memory); auto user_bias_memory = :ref:`memory `({{conv_bias_tz}, :ref:`dt::f32 `, tag::x}, eng); write_to_dnnl_memory(conv_bias.data(), user_bias_memory); //[Allocate buffers] 
//[Create convolution memory descriptors] auto conv_src_md = :ref:`memory::desc `({conv_src_tz}, dt::u8, :ref:`tag::any `); auto conv_bias_md = :ref:`memory::desc `({conv_bias_tz}, dt::s8, :ref:`tag::any `); auto conv_weights_md = :ref:`memory::desc `({conv_weights_tz}, dt::s8, :ref:`tag::any `); auto conv_dst_md = :ref:`memory::desc `({conv_dst_tz}, dt::u8, :ref:`tag::any `); //[Create convolution memory descriptors] //[Configure scaling] :ref:`primitive_attr ` conv_attr; conv_attr.:ref:`set_scales_mask `(:ref:`DNNL_ARG_SRC `, src_mask); conv_attr.set_scales_mask(:ref:`DNNL_ARG_WEIGHTS `, weight_mask); conv_attr.set_scales_mask(:ref:`DNNL_ARG_DST `, dst_mask); // Prepare dst scales auto dst_scale_md = :ref:`memory::desc `({1}, :ref:`dt::f32 `, tag::x); auto dst_scale_memory = :ref:`memory `(dst_scale_md, eng); write_to_dnnl_memory(dst_scales.data(), dst_scale_memory); //[Configure scaling] //[Configure post-ops] const float ops_alpha = 0.f; // relu negative slope const float ops_beta = 0.f; :ref:`post_ops ` ops; ops.:ref:`append_eltwise `(:ref:`algorithm::eltwise_relu `, ops_alpha, ops_beta); conv_attr.set_post_ops(ops); //[Configure post-ops] // check if int8 convolution is supported try { :ref:`convolution_forward::primitive_desc `(eng, :ref:`prop_kind::forward `, :ref:`algorithm::convolution_direct `, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_padding, conv_padding, conv_attr); } catch (:ref:`error ` &e) { if (e.status == :ref:`dnnl_unimplemented `) throw example_allows_unimplemented { "No int8 convolution implementation is available for this " "platform.\n" "Please refer to the developer guide for details."}; // on any other error just re-throw throw; } //[Create convolution primitive descriptor] auto conv_prim_desc = :ref:`convolution_forward::primitive_desc `(eng, :ref:`prop_kind::forward `, :ref:`algorithm::convolution_direct `, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_padding, 
conv_padding, conv_attr); //[Create convolution primitive descriptor] //[Quantize data and weights] auto conv_src_memory = :ref:`memory `(conv_prim_desc.src_desc(), eng); :ref:`primitive_attr ` src_attr; src_attr.:ref:`set_scales_mask `(:ref:`DNNL_ARG_DST `, src_mask); auto src_scale_md = :ref:`memory::desc `({1}, :ref:`dt::f32 `, tag::x); auto src_scale_memory = :ref:`memory `(src_scale_md, eng); write_to_dnnl_memory(src_scales.data(), src_scale_memory); auto src_reorder_pd = :ref:`reorder::primitive_desc `(eng, user_src_memory.get_desc(), eng, conv_src_memory.get_desc(), src_attr); auto src_reorder = :ref:`reorder `(src_reorder_pd); src_reorder.execute(s, {{:ref:`DNNL_ARG_FROM `, user_src_memory}, {:ref:`DNNL_ARG_TO `, conv_src_memory}, {:ref:`DNNL_ARG_ATTR_SCALES ` | :ref:`DNNL_ARG_DST `, src_scale_memory}}); auto conv_weights_memory = :ref:`memory `(conv_prim_desc.weights_desc(), eng); :ref:`primitive_attr ` weight_attr; weight_attr.:ref:`set_scales_mask `(:ref:`DNNL_ARG_DST `, weight_mask); auto wei_scale_md = :ref:`memory::desc `({1}, :ref:`dt::f32 `, tag::x); auto wei_scale_memory = :ref:`memory `(wei_scale_md, eng); write_to_dnnl_memory(weight_scales.data(), wei_scale_memory); auto weight_reorder_pd = :ref:`reorder::primitive_desc `(eng, user_weights_memory.get_desc(), eng, conv_weights_memory.get_desc(), weight_attr); auto weight_reorder = :ref:`reorder `(weight_reorder_pd); weight_reorder.execute(s, {{:ref:`DNNL_ARG_FROM `, user_weights_memory}, {:ref:`DNNL_ARG_TO `, conv_weights_memory}, {:ref:`DNNL_ARG_ATTR_SCALES ` | :ref:`DNNL_ARG_DST `, wei_scale_memory}}); auto conv_bias_memory = :ref:`memory `(conv_prim_desc.bias_desc(), eng); write_to_dnnl_memory(conv_bias.data(), conv_bias_memory); //[Quantize data and weights] auto conv_dst_memory = :ref:`memory `(conv_prim_desc.dst_desc(), eng); //[Create convolution primitive] auto conv = :ref:`convolution_forward `(conv_prim_desc); conv.execute(s, {{:ref:`DNNL_ARG_SRC `, conv_src_memory}, 
{:ref:`DNNL_ARG_WEIGHTS `, conv_weights_memory}, {:ref:`DNNL_ARG_BIAS `, conv_bias_memory}, {:ref:`DNNL_ARG_DST `, conv_dst_memory}, {:ref:`DNNL_ARG_ATTR_SCALES ` | :ref:`DNNL_ARG_SRC `, src_scale_memory}, {:ref:`DNNL_ARG_ATTR_SCALES ` | :ref:`DNNL_ARG_WEIGHTS `, wei_scale_memory}, {:ref:`DNNL_ARG_ATTR_SCALES ` | :ref:`DNNL_ARG_DST `, dst_scale_memory}}); //[Create convolution primitive] auto user_dst_memory = :ref:`memory `({{conv_dst_tz}, :ref:`dt::f32 `, tag::nchw}, eng); write_to_dnnl_memory(user_dst.data(), user_dst_memory); :ref:`primitive_attr ` dst_attr; dst_attr.:ref:`set_scales_mask `(:ref:`DNNL_ARG_SRC `, dst_mask); auto dst_reorder_pd = :ref:`reorder::primitive_desc `(eng, conv_dst_memory.get_desc(), eng, user_dst_memory.get_desc(), dst_attr); auto dst_reorder = :ref:`reorder `(dst_reorder_pd); dst_reorder.execute(s, {{:ref:`DNNL_ARG_FROM `, conv_dst_memory}, {:ref:`DNNL_ARG_TO `, user_dst_memory}, {:ref:`DNNL_ARG_ATTR_SCALES ` | :ref:`DNNL_ARG_SRC `, dst_scale_memory}}); //[Dequantize the result] s.wait(); } int main(int argc, char **argv) { return handle_example_errors( simple_net_int8, parse_engine_kind(argc, argv)); }