.. index:: pair: example; cpu_cnn_training_f32.c .. _doxid-cpu_cnn_training_f32_8c-example: cpu_cnn_training_f32.c ====================== This C API example demonstrates how to build the training pass of an AlexNet model. It chains the first AlexNet layers (convolution, ReLU, LRN, and max pooling) and runs their forward and backward passes on a CPU engine. .. ref-code-block:: cpp /******************************************************************************* * Copyright 2016-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ // Required for posix_memalign #define _POSIX_C_SOURCE 200112L #include <stdio.h> #include <stdlib.h> #include <string.h> #include "oneapi/dnnl/dnnl.h" #include "example_utils.h" #define BATCH 8 #define IC 3 #define OC 96 #define CONV_IH 227 #define CONV_IW 227 #define CONV_OH 55 #define CONV_OW 55 #define CONV_STRIDE 4 #define CONV_PAD 0 #define POOL_OH 27 #define POOL_OW 27 #define POOL_STRIDE 2 #define POOL_PAD 0 static size_t product(:ref:`dnnl_dim_t ` *arr, size_t size) { size_t prod = 1; for (size_t i = 0; i < size; ++i) prod *= arr[i]; return prod; } static void init_net_data(float *data, uint32_t dim, const :ref:`dnnl_dim_t ` *dims) { if (dim == 1) { for (:ref:`dnnl_dim_t ` i = 0; i < dims[0]; ++i) { data[i] = (float)(i % 1637); } } else if (dim == 4) { for (:ref:`dnnl_dim_t ` in = 0; in < dims[0]; ++in) for (:ref:`dnnl_dim_t ` ic = 0; ic < dims[1]; ++ic) for (:ref:`dnnl_dim_t ` ih = 0; ih < dims[2]; ++ih) for (:ref:`dnnl_dim_t ` iw = 0; iw < dims[3]; ++iw) {
:ref:`dnnl_dim_t ` indx = in * dims[1] * dims[2] * dims[3] + ic * dims[2] * dims[3] + ih * dims[3] + iw; data[indx] = (float)(indx % 1637); } } } typedef struct { int nargs; :ref:`dnnl_exec_arg_t ` *args; } args_t; static void prepare_arg_node(args_t *node, int nargs) { node->args = (:ref:`dnnl_exec_arg_t ` *)malloc(sizeof(:ref:`dnnl_exec_arg_t `) * nargs); node->nargs = nargs; } static void free_arg_node(args_t *node) { free(node->args); } static void set_arg(:ref:`dnnl_exec_arg_t ` *arg, int arg_idx, :ref:`dnnl_memory_t ` memory) { arg->:ref:`arg ` = arg_idx; arg->:ref:`memory ` = memory; } static void init_data_memory(uint32_t dim, const :ref:`dnnl_dim_t ` *dims, :ref:`dnnl_format_tag_t ` user_tag, :ref:`dnnl_engine_t ` engine, float *data, :ref:`dnnl_memory_t ` *memory) { :ref:`dnnl_memory_desc_t ` user_md; CHECK(:ref:`dnnl_memory_desc_create_with_tag `( &user_md, dim, dims, :ref:`dnnl_f32 `, user_tag)); CHECK(:ref:`dnnl_memory_create `(memory, user_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); CHECK(:ref:`dnnl_memory_desc_destroy `(user_md)); write_to_dnnl_memory(data, *memory); } :ref:`dnnl_status_t ` prepare_reorder(:ref:`dnnl_memory_t ` *user_memory, // in :ref:`const_dnnl_memory_desc_t ` prim_memory_md, // in :ref:`dnnl_engine_t ` prim_engine, // in: primitive's engine int dir_is_user_to_prim, // in: user -> prim or prim -> user :ref:`dnnl_memory_t ` *prim_memory, // out: primitive's memory created :ref:`dnnl_primitive_t ` *reorder, // out: reorder primitive created uint32_t *net_index, // primitive index in net (inc if reorder created) :ref:`dnnl_primitive_t ` *net, args_t *net_args) { // net params :ref:`const_dnnl_memory_desc_t ` user_memory_md; :ref:`dnnl_memory_get_memory_desc `(*user_memory, &user_memory_md); :ref:`dnnl_engine_t ` user_mem_engine; :ref:`dnnl_memory_get_engine `(*user_memory, &user_mem_engine); if (!:ref:`dnnl_memory_desc_equal `(user_memory_md, prim_memory_md)) { CHECK(:ref:`dnnl_memory_create `(prim_memory, prim_memory_md, 
prim_engine, :ref:`DNNL_MEMORY_ALLOCATE `)); :ref:`dnnl_primitive_desc_t ` reorder_pd; if (dir_is_user_to_prim) { CHECK(:ref:`dnnl_reorder_primitive_desc_create `(&reorder_pd, user_memory_md, user_mem_engine, prim_memory_md, prim_engine, NULL)); } else { CHECK(:ref:`dnnl_reorder_primitive_desc_create `(&reorder_pd, prim_memory_md, prim_engine, user_memory_md, user_mem_engine, NULL)); } CHECK(:ref:`dnnl_primitive_create `(reorder, reorder_pd)); CHECK(:ref:`dnnl_primitive_desc_destroy `(reorder_pd)); net[*net_index] = *reorder; prepare_arg_node(&net_args[*net_index], 2); set_arg(&net_args[*net_index].args[0], :ref:`DNNL_ARG_FROM `, dir_is_user_to_prim ? *user_memory : *prim_memory); set_arg(&net_args[*net_index].args[1], :ref:`DNNL_ARG_TO `, dir_is_user_to_prim ? *prim_memory : *user_memory); (*net_index)++; } else { *prim_memory = NULL; *reorder = NULL; } return :ref:`dnnl_success `; } void simple_net() { :ref:`dnnl_engine_t ` :ref:`engine `; CHECK(:ref:`dnnl_engine_create `(&engine, :ref:`dnnl_cpu `, 0)); // idx // build a simple net uint32_t n_fwd = 0, n_bwd = 0; :ref:`dnnl_primitive_t ` net_fwd[10], net_bwd[10]; args_t net_fwd_args[10], net_bwd_args[10]; const int ndims = 4; :ref:`dnnl_dims_t ` net_src_sizes = {BATCH, IC, CONV_IH, CONV_IW}; :ref:`dnnl_dims_t ` net_dst_sizes = {BATCH, OC, POOL_OH, POOL_OW}; float *net_src = (float *)malloc(product(net_src_sizes, ndims) * sizeof(float)); float *net_dst = (float *)malloc(product(net_dst_sizes, ndims) * sizeof(float)); init_net_data(net_src, ndims, net_src_sizes); memset(net_dst, 0, product(net_dst_sizes, ndims) * sizeof(float)); //---------------------------------------------------------------------- //----------------- Forward Stream ------------------------------------- // AlexNet: conv // {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, 11, 11} -> // {BATCH, OC, CONV_OH, CONV_OW} // strides: {CONV_STRIDE, CONV_STRIDE} :ref:`dnnl_dims_t ` conv_user_src_sizes; for (int i = 0; i < ndims; i++) conv_user_src_sizes[i] = 
net_src_sizes[i]; :ref:`dnnl_dims_t ` conv_user_weights_sizes = {OC, IC, 11, 11}; :ref:`dnnl_dims_t ` conv_bias_sizes = {OC}; :ref:`dnnl_dims_t ` conv_user_dst_sizes = {BATCH, OC, CONV_OH, CONV_OW}; :ref:`dnnl_dims_t ` conv_strides = {CONV_STRIDE, CONV_STRIDE}; :ref:`dnnl_dims_t ` conv_dilation = {0, 0}; :ref:`dnnl_dims_t ` conv_padding = {CONV_PAD, CONV_PAD}; float *conv_src = net_src; float *conv_weights = (float *)malloc( product(conv_user_weights_sizes, ndims) * sizeof(float)); float *conv_bias = (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float)); init_net_data(conv_weights, ndims, conv_user_weights_sizes); init_net_data(conv_bias, 1, conv_bias_sizes); // create memory for user data :ref:`dnnl_memory_t ` conv_user_src_memory, conv_user_weights_memory, conv_user_bias_memory; init_data_memory(ndims, conv_user_src_sizes, :ref:`dnnl_nchw `, engine, conv_src, &conv_user_src_memory); init_data_memory(ndims, conv_user_weights_sizes, :ref:`dnnl_oihw `, engine, conv_weights, &conv_user_weights_memory); init_data_memory(1, conv_bias_sizes, :ref:`dnnl_x `, engine, conv_bias, &conv_user_bias_memory); // create a convolution :ref:`dnnl_primitive_desc_t ` conv_pd; { // create data descriptors for convolution w/ no specified format :ref:`dnnl_memory_desc_t ` conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md; CHECK(:ref:`dnnl_memory_desc_create_with_tag `(&conv_src_md, ndims, conv_user_src_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_format_tag_any `)); CHECK(:ref:`dnnl_memory_desc_create_with_tag `(&conv_weights_md, ndims, conv_user_weights_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_format_tag_any `)); CHECK(:ref:`dnnl_memory_desc_create_with_tag `( &conv_bias_md, 1, conv_bias_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_x `)); CHECK(:ref:`dnnl_memory_desc_create_with_tag `(&conv_dst_md, ndims, conv_user_dst_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_format_tag_any `)); CHECK(:ref:`dnnl_convolution_forward_primitive_desc_create `(&conv_pd, engine, :ref:`dnnl_forward `, 
:ref:`dnnl_convolution_direct `, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding, NULL)); CHECK(:ref:`dnnl_memory_desc_destroy `(conv_src_md)); CHECK(:ref:`dnnl_memory_desc_destroy `(conv_weights_md)); CHECK(:ref:`dnnl_memory_desc_destroy `(conv_bias_md)); CHECK(:ref:`dnnl_memory_desc_destroy `(conv_dst_md)); } :ref:`dnnl_memory_t ` conv_internal_src_memory, conv_internal_weights_memory, conv_internal_dst_memory; // create memory for dst data, we don't need to reorder it to user data :ref:`const_dnnl_memory_desc_t ` conv_dst_md = :ref:`dnnl_primitive_desc_query_md `(conv_pd, :ref:`dnnl_query_dst_md `, 0); CHECK(:ref:`dnnl_memory_create `(&conv_internal_dst_memory, conv_dst_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); // create reorder primitives between user data and convolution srcs // if required :ref:`dnnl_primitive_t ` conv_reorder_src, conv_reorder_weights; :ref:`const_dnnl_memory_desc_t ` conv_src_md = :ref:`dnnl_primitive_desc_query_md `(conv_pd, :ref:`dnnl_query_src_md `, 0); CHECK(prepare_reorder(&conv_user_src_memory, conv_src_md, engine, 1, &conv_internal_src_memory, &conv_reorder_src, &n_fwd, net_fwd, net_fwd_args)); :ref:`const_dnnl_memory_desc_t ` conv_weights_md = :ref:`dnnl_primitive_desc_query_md `(conv_pd, :ref:`dnnl_query_weights_md `, 0); CHECK(prepare_reorder(&conv_user_weights_memory, conv_weights_md, engine, 1, &conv_internal_weights_memory, &conv_reorder_weights, &n_fwd, net_fwd, net_fwd_args)); :ref:`dnnl_memory_t ` conv_src_memory = conv_internal_src_memory ? conv_internal_src_memory : conv_user_src_memory; :ref:`dnnl_memory_t ` conv_weights_memory = conv_internal_weights_memory ? 
conv_internal_weights_memory : conv_user_weights_memory; // finally create a convolution primitive :ref:`dnnl_primitive_t ` conv; CHECK(:ref:`dnnl_primitive_create `(&conv, conv_pd)); net_fwd[n_fwd] = conv; prepare_arg_node(&net_fwd_args[n_fwd], 4); set_arg(&net_fwd_args[n_fwd].args[0], :ref:`DNNL_ARG_SRC `, conv_src_memory); set_arg(&net_fwd_args[n_fwd].args[1], :ref:`DNNL_ARG_WEIGHTS `, conv_weights_memory); set_arg(&net_fwd_args[n_fwd].args[2], :ref:`DNNL_ARG_BIAS `, conv_user_bias_memory); set_arg(&net_fwd_args[n_fwd].args[3], :ref:`DNNL_ARG_DST `, conv_internal_dst_memory); n_fwd++; // AlexNet: relu // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW} float negative_slope = 0.0f; // keep memory format of source same as the format of convolution // output in order to avoid reorder :ref:`const_dnnl_memory_desc_t ` relu_src_md = conv_dst_md; :ref:`const_dnnl_memory_desc_t ` relu_dst_md = relu_src_md; // create a relu primitive descriptor :ref:`dnnl_primitive_desc_t ` relu_pd; CHECK(:ref:`dnnl_eltwise_forward_primitive_desc_create `(&relu_pd, engine, :ref:`dnnl_forward `, :ref:`dnnl_eltwise_relu `, relu_src_md, relu_dst_md, negative_slope, 0, NULL)); // create relu dst memory :ref:`dnnl_memory_t ` relu_dst_memory; CHECK(:ref:`dnnl_memory_create `( &relu_dst_memory, relu_dst_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); // finally create a relu primitive :ref:`dnnl_primitive_t ` relu; CHECK(:ref:`dnnl_primitive_create `(&relu, relu_pd)); net_fwd[n_fwd] = relu; prepare_arg_node(&net_fwd_args[n_fwd], 2); set_arg(&net_fwd_args[n_fwd].args[0], :ref:`DNNL_ARG_SRC `, conv_internal_dst_memory); set_arg(&net_fwd_args[n_fwd].args[1], :ref:`DNNL_ARG_DST `, relu_dst_memory); n_fwd++; // AlexNet: lrn // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW} // local size: 5 // alpha: 0.0001 // beta: 0.75 // k: 1.0 uint32_t local_size = 5; float alpha = 0.0001f; float beta = 0.75f; float k = 1.0f; // create lrn src memory descriptor using dst memory 
descriptor // from previous primitive :ref:`const_dnnl_memory_desc_t ` lrn_src_md = relu_dst_md; :ref:`const_dnnl_memory_desc_t ` lrn_dst_md = lrn_src_md; // create a lrn primitive descriptor :ref:`dnnl_primitive_desc_t ` lrn_pd; CHECK(:ref:`dnnl_lrn_forward_primitive_desc_create `(&lrn_pd, engine, :ref:`dnnl_forward `, :ref:`dnnl_lrn_across_channels `, lrn_src_md, lrn_dst_md, local_size, alpha, beta, k, NULL)); // create memory for lrn dst and workspace :ref:`dnnl_memory_t ` lrn_dst_memory, lrn_ws_memory; CHECK(:ref:`dnnl_memory_create `( &lrn_dst_memory, lrn_dst_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); // create workspace only in training and only for forward primitive // query lrn_pd for workspace, this memory will be shared with backward lrn :ref:`const_dnnl_memory_desc_t ` lrn_ws_md = :ref:`dnnl_primitive_desc_query_md `(lrn_pd, :ref:`dnnl_query_workspace_md `, 0); CHECK(:ref:`dnnl_memory_create `( &lrn_ws_memory, lrn_ws_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); // finally create a lrn primitive :ref:`dnnl_primitive_t ` lrn; CHECK(:ref:`dnnl_primitive_create `(&lrn, lrn_pd)); net_fwd[n_fwd] = lrn; prepare_arg_node(&net_fwd_args[n_fwd], 3); set_arg(&net_fwd_args[n_fwd].args[0], :ref:`DNNL_ARG_SRC `, relu_dst_memory); set_arg(&net_fwd_args[n_fwd].args[1], :ref:`DNNL_ARG_DST `, lrn_dst_memory); set_arg(&net_fwd_args[n_fwd].args[2], :ref:`DNNL_ARG_WORKSPACE `, lrn_ws_memory); n_fwd++; // AlexNet: pool // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW} // kernel: {3, 3} // strides: {POOL_STRIDE, POOL_STRIDE} // dilation: {0, 0} :ref:`dnnl_dims_t ` pool_dst_sizes; for (int i = 0; i < ndims; i++) pool_dst_sizes[i] = net_dst_sizes[i]; :ref:`dnnl_dims_t ` pool_kernel = {3, 3}; :ref:`dnnl_dims_t ` pool_strides = {POOL_STRIDE, POOL_STRIDE}; :ref:`dnnl_dims_t ` pool_padding = {POOL_PAD, POOL_PAD}; :ref:`dnnl_dims_t ` pool_dilation = {0, 0}; // create memory for user dst data :ref:`dnnl_memory_t ` pool_user_dst_memory; init_data_memory(4,
pool_dst_sizes, :ref:`dnnl_nchw `, engine, net_dst, &pool_user_dst_memory); // create a pooling primitive descriptor :ref:`dnnl_primitive_desc_t ` pool_pd; { // create pooling src memory descriptor using dst descriptor // from previous primitive :ref:`const_dnnl_memory_desc_t ` pool_src_md = lrn_dst_md; // create descriptors for dst pooling data :ref:`dnnl_memory_desc_t ` pool_dst_md; CHECK(:ref:`dnnl_memory_desc_create_with_tag `(&pool_dst_md, 4, pool_dst_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_format_tag_any `)); CHECK(:ref:`dnnl_pooling_forward_primitive_desc_create `(&pool_pd, engine, :ref:`dnnl_forward `, :ref:`dnnl_pooling_max `, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_dilation, pool_padding, pool_padding, NULL)); CHECK(:ref:`dnnl_memory_desc_destroy `(pool_dst_md)); } // create memory for workspace :ref:`dnnl_memory_t ` pool_ws_memory; :ref:`const_dnnl_memory_desc_t ` pool_ws_md = :ref:`dnnl_primitive_desc_query_md `(pool_pd, :ref:`dnnl_query_workspace_md `, 0); CHECK(:ref:`dnnl_memory_create `( &pool_ws_memory, pool_ws_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); // create reorder primitives between pooling dsts and user format dst // if required :ref:`dnnl_primitive_t ` pool_reorder_dst; :ref:`dnnl_memory_t ` pool_internal_dst_memory; :ref:`const_dnnl_memory_desc_t ` pool_dst_md = :ref:`dnnl_primitive_desc_query_md `(pool_pd, :ref:`dnnl_query_dst_md `, 0); n_fwd += 1; // tentative workaround: preserve space for pooling that should // happen before the reorder CHECK(prepare_reorder(&pool_user_dst_memory, pool_dst_md, engine, 0, &pool_internal_dst_memory, &pool_reorder_dst, &n_fwd, net_fwd, net_fwd_args)); n_fwd -= pool_reorder_dst ? 2 : 1; :ref:`dnnl_memory_t ` pool_dst_memory = pool_internal_dst_memory ? 
pool_internal_dst_memory : pool_user_dst_memory; // finally create a pooling primitive :ref:`dnnl_primitive_t ` pool; CHECK(:ref:`dnnl_primitive_create `(&pool, pool_pd)); net_fwd[n_fwd] = pool; prepare_arg_node(&net_fwd_args[n_fwd], 3); set_arg(&net_fwd_args[n_fwd].args[0], :ref:`DNNL_ARG_SRC `, lrn_dst_memory); set_arg(&net_fwd_args[n_fwd].args[1], :ref:`DNNL_ARG_DST `, pool_dst_memory); set_arg(&net_fwd_args[n_fwd].args[2], :ref:`DNNL_ARG_WORKSPACE `, pool_ws_memory); n_fwd++; if (pool_reorder_dst) n_fwd += 1; //----------------------------------------------------------------------- //----------------- Backward Stream ------------------------------------- //----------------------------------------------------------------------- // ... user diff_data ... float *net_diff_dst = (float *)malloc(product(pool_dst_sizes, 4) * sizeof(float)); init_net_data(net_diff_dst, 4, pool_dst_sizes); // create memory for user diff dst data :ref:`dnnl_memory_t ` pool_user_diff_dst_memory; init_data_memory(4, pool_dst_sizes, :ref:`dnnl_nchw `, engine, net_diff_dst, &pool_user_diff_dst_memory); // Pooling Backward // pooling diff src memory descriptor :ref:`const_dnnl_memory_desc_t ` pool_diff_src_md = lrn_dst_md; // pooling diff dst memory descriptor :ref:`const_dnnl_memory_desc_t ` pool_diff_dst_md = pool_dst_md; // backward primitive descriptor needs to hint forward descriptor :ref:`dnnl_primitive_desc_t ` pool_bwd_pd; CHECK(:ref:`dnnl_pooling_backward_primitive_desc_create `(&pool_bwd_pd, engine, :ref:`dnnl_pooling_max `, pool_diff_src_md, pool_diff_dst_md, pool_strides, pool_kernel, pool_dilation, pool_padding, pool_padding, pool_pd, NULL)); // create reorder primitive between user diff dst and pool diff dst // if required :ref:`dnnl_memory_t ` pool_diff_dst_memory, pool_internal_diff_dst_memory; :ref:`dnnl_primitive_t ` pool_reorder_diff_dst; CHECK(prepare_reorder(&pool_user_diff_dst_memory, pool_diff_dst_md, engine, 1, &pool_internal_diff_dst_memory, &pool_reorder_diff_dst, 
&n_bwd, net_bwd, net_bwd_args)); pool_diff_dst_memory = pool_internal_diff_dst_memory ? pool_internal_diff_dst_memory : pool_user_diff_dst_memory; // create memory for pool diff src data :ref:`dnnl_memory_t ` pool_diff_src_memory; CHECK(:ref:`dnnl_memory_create `(&pool_diff_src_memory, pool_diff_src_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); // finally create backward pooling primitive :ref:`dnnl_primitive_t ` pool_bwd; CHECK(:ref:`dnnl_primitive_create `(&pool_bwd, pool_bwd_pd)); net_bwd[n_bwd] = pool_bwd; prepare_arg_node(&net_bwd_args[n_bwd], 3); set_arg(&net_bwd_args[n_bwd].args[0], :ref:`DNNL_ARG_DIFF_DST `, pool_diff_dst_memory); set_arg(&net_bwd_args[n_bwd].args[1], :ref:`DNNL_ARG_WORKSPACE `, pool_ws_memory); set_arg(&net_bwd_args[n_bwd].args[2], :ref:`DNNL_ARG_DIFF_SRC `, pool_diff_src_memory); n_bwd++; // Backward lrn :ref:`const_dnnl_memory_desc_t ` lrn_diff_dst_md = pool_diff_src_md; :ref:`const_dnnl_memory_desc_t ` lrn_diff_src_md = lrn_diff_dst_md; // create backward lrn descriptor :ref:`dnnl_primitive_desc_t ` lrn_bwd_pd; CHECK(:ref:`dnnl_lrn_backward_primitive_desc_create `(&lrn_bwd_pd, engine, :ref:`dnnl_lrn_across_channels `, lrn_diff_src_md, lrn_diff_dst_md, lrn_src_md, local_size, alpha, beta, k, lrn_pd, NULL)); // create memory for lrn diff src :ref:`dnnl_memory_t ` lrn_diff_src_memory; CHECK(:ref:`dnnl_memory_create `(&lrn_diff_src_memory, lrn_diff_src_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); // finally create backward lrn primitive :ref:`dnnl_primitive_t ` lrn_bwd; CHECK(:ref:`dnnl_primitive_create `(&lrn_bwd, lrn_bwd_pd)); net_bwd[n_bwd] = lrn_bwd; prepare_arg_node(&net_bwd_args[n_bwd], 4); set_arg(&net_bwd_args[n_bwd].args[0], :ref:`DNNL_ARG_SRC `, relu_dst_memory); set_arg(&net_bwd_args[n_bwd].args[1], :ref:`DNNL_ARG_DIFF_DST `, pool_diff_src_memory); set_arg(&net_bwd_args[n_bwd].args[2], :ref:`DNNL_ARG_WORKSPACE `, lrn_ws_memory); set_arg(&net_bwd_args[n_bwd].args[3], :ref:`DNNL_ARG_DIFF_SRC `, lrn_diff_src_memory); n_bwd++; // 
Backward relu :ref:`const_dnnl_memory_desc_t ` relu_diff_src_md = lrn_diff_src_md; :ref:`const_dnnl_memory_desc_t ` relu_diff_dst_md = lrn_diff_src_md; // create backward relu descriptor :ref:`dnnl_primitive_desc_t ` relu_bwd_pd; CHECK(:ref:`dnnl_eltwise_backward_primitive_desc_create `(&relu_bwd_pd, engine, :ref:`dnnl_eltwise_relu `, relu_diff_src_md, relu_diff_dst_md, relu_src_md, negative_slope, 0, relu_pd, NULL)); // create memory for relu diff src :ref:`dnnl_memory_t ` relu_diff_src_memory; CHECK(:ref:`dnnl_memory_create `(&relu_diff_src_memory, relu_diff_src_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); // finally create backward relu primitive :ref:`dnnl_primitive_t ` relu_bwd; CHECK(:ref:`dnnl_primitive_create `(&relu_bwd, relu_bwd_pd)); net_bwd[n_bwd] = relu_bwd; prepare_arg_node(&net_bwd_args[n_bwd], 3); set_arg(&net_bwd_args[n_bwd].args[0], :ref:`DNNL_ARG_SRC `, conv_internal_dst_memory); set_arg(&net_bwd_args[n_bwd].args[1], :ref:`DNNL_ARG_DIFF_DST `, lrn_diff_src_memory); set_arg(&net_bwd_args[n_bwd].args[2], :ref:`DNNL_ARG_DIFF_SRC `, relu_diff_src_memory); n_bwd++; // Backward convolution with respect to weights float *conv_diff_bias_buffer = (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float)); float *conv_user_diff_weights_buffer = (float *)malloc( product(conv_user_weights_sizes, 4) * sizeof(float)); // initialize memory for diff weights in user format :ref:`dnnl_memory_t ` conv_user_diff_weights_memory; init_data_memory(4, conv_user_weights_sizes, :ref:`dnnl_oihw `, engine, conv_user_diff_weights_buffer, &conv_user_diff_weights_memory); // create backward convolution primitive descriptor :ref:`dnnl_primitive_desc_t ` conv_bwd_weights_pd; { // memory descriptors should be in format `any` to allow backward // convolution for // weights to choose the format it prefers for best performance :ref:`dnnl_memory_desc_t ` conv_diff_src_md, conv_diff_weights_md, conv_diff_bias_md, conv_diff_dst_md; CHECK(:ref:`dnnl_memory_desc_create_with_tag
`(&conv_diff_src_md, 4, conv_user_src_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_format_tag_any `)); CHECK(:ref:`dnnl_memory_desc_create_with_tag `(&conv_diff_weights_md, 4, conv_user_weights_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_format_tag_any `)); CHECK(:ref:`dnnl_memory_desc_create_with_tag `( &conv_diff_bias_md, 1, conv_bias_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_x `)); CHECK(:ref:`dnnl_memory_desc_create_with_tag `(&conv_diff_dst_md, 4, conv_user_dst_sizes, :ref:`dnnl_f32 `, :ref:`dnnl_format_tag_any `)); // create backward convolution descriptor CHECK(:ref:`dnnl_convolution_backward_weights_primitive_desc_create `( &conv_bwd_weights_pd, engine, :ref:`dnnl_convolution_direct `, conv_diff_src_md, conv_diff_weights_md, conv_diff_bias_md, conv_diff_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding, conv_pd, NULL)); CHECK(:ref:`dnnl_memory_desc_destroy `(conv_diff_src_md)); CHECK(:ref:`dnnl_memory_desc_destroy `(conv_diff_weights_md)); CHECK(:ref:`dnnl_memory_desc_destroy `(conv_diff_bias_md)); CHECK(:ref:`dnnl_memory_desc_destroy `(conv_diff_dst_md)); } // for best performance convolution backward might choose // different memory format for src and diff_dst // than the memory formats preferred by forward convolution // for src and dst respectively // create reorder primitives for src from forward convolution to the // format chosen by backward convolution :ref:`dnnl_primitive_t ` conv_bwd_reorder_src; :ref:`dnnl_memory_t ` conv_bwd_internal_src_memory; :ref:`const_dnnl_memory_desc_t ` conv_diff_src_md = :ref:`dnnl_primitive_desc_query_md `( conv_bwd_weights_pd, :ref:`dnnl_query_src_md `, 0); CHECK(prepare_reorder(&conv_src_memory, conv_diff_src_md, engine, 1, &conv_bwd_internal_src_memory, &conv_bwd_reorder_src, &n_bwd, net_bwd, net_bwd_args)); :ref:`dnnl_memory_t ` conv_bwd_weights_src_memory = conv_bwd_internal_src_memory ?
conv_bwd_internal_src_memory : conv_src_memory; // create reorder primitives for diff_dst between diff_src from relu_bwd // and format preferred by conv_diff_weights :ref:`dnnl_primitive_t ` conv_reorder_diff_dst; :ref:`dnnl_memory_t ` conv_internal_diff_dst_memory; :ref:`const_dnnl_memory_desc_t ` conv_diff_dst_md = :ref:`dnnl_primitive_desc_query_md `( conv_bwd_weights_pd, :ref:`dnnl_query_diff_dst_md `, 0); CHECK(prepare_reorder(&relu_diff_src_memory, conv_diff_dst_md, engine, 1, &conv_internal_diff_dst_memory, &conv_reorder_diff_dst, &n_bwd, net_bwd, net_bwd_args)); :ref:`dnnl_memory_t ` conv_diff_dst_memory = conv_internal_diff_dst_memory ? conv_internal_diff_dst_memory : relu_diff_src_memory; // create reorder primitives for conv diff weights memory :ref:`dnnl_primitive_t ` conv_reorder_diff_weights; :ref:`dnnl_memory_t ` conv_internal_diff_weights_memory; :ref:`const_dnnl_memory_desc_t ` conv_diff_weights_md = :ref:`dnnl_primitive_desc_query_md `( conv_bwd_weights_pd, :ref:`dnnl_query_diff_weights_md `, 0); n_bwd += 1; // tentative workaround: preserve space for conv_bwd_weights // that should happen before the reorder CHECK(prepare_reorder(&conv_user_diff_weights_memory, conv_diff_weights_md, engine, 0, &conv_internal_diff_weights_memory, &conv_reorder_diff_weights, &n_bwd, net_bwd, net_bwd_args)); n_bwd -= conv_reorder_diff_weights ? 2 : 1; :ref:`dnnl_memory_t ` conv_diff_weights_memory = conv_internal_diff_weights_memory ? 
conv_internal_diff_weights_memory : conv_user_diff_weights_memory; // create memory for diff bias :ref:`dnnl_memory_t ` conv_diff_bias_memory; :ref:`const_dnnl_memory_desc_t ` conv_diff_bias_md = :ref:`dnnl_primitive_desc_query_md `( conv_bwd_weights_pd, :ref:`dnnl_query_diff_weights_md `, 1); CHECK(:ref:`dnnl_memory_create `(&conv_diff_bias_memory, conv_diff_bias_md, engine, :ref:`DNNL_MEMORY_ALLOCATE `)); CHECK(:ref:`dnnl_memory_set_data_handle `( conv_diff_bias_memory, conv_diff_bias_buffer)); // finally create backward convolution weights primitive :ref:`dnnl_primitive_t ` conv_bwd_weights; CHECK(:ref:`dnnl_primitive_create `(&conv_bwd_weights, conv_bwd_weights_pd)); net_bwd[n_bwd] = conv_bwd_weights; prepare_arg_node(&net_bwd_args[n_bwd], 4); set_arg(&net_bwd_args[n_bwd].args[0], :ref:`DNNL_ARG_SRC `, conv_bwd_weights_src_memory); set_arg(&net_bwd_args[n_bwd].args[1], :ref:`DNNL_ARG_DIFF_DST `, conv_diff_dst_memory); set_arg(&net_bwd_args[n_bwd].args[2], :ref:`DNNL_ARG_DIFF_WEIGHTS `, conv_diff_weights_memory); set_arg(&net_bwd_args[n_bwd].args[3], :ref:`DNNL_ARG_DIFF_BIAS `, conv_diff_bias_memory); n_bwd++; if (conv_reorder_diff_weights) n_bwd += 1; // output from backward stream void *net_diff_weights = NULL; void *net_diff_bias = NULL; int n_iter = 10; // number of iterations for training. :ref:`dnnl_stream_t ` stream; CHECK(:ref:`dnnl_stream_create `(&stream, engine, :ref:`dnnl_stream_default_flags `)); // Execute the net for (int i = 0; i < n_iter; i++) { for (uint32_t i = 0; i < n_fwd; ++i) CHECK(:ref:`dnnl_primitive_execute `(net_fwd[i], stream, net_fwd_args[i].nargs, net_fwd_args[i].args)); // Update net_diff_dst void *net_output = NULL; // output from forward stream: CHECK(:ref:`dnnl_memory_get_data_handle `(pool_user_dst_memory, &net_output)); // ...user updates net_diff_dst using net_output...
// some user defined func update_diff_dst(net_diff_dst, net_output) // Backward pass for (uint32_t i = 0; i < n_bwd; ++i) CHECK(:ref:`dnnl_primitive_execute `(net_bwd[i], stream, net_bwd_args[i].nargs, net_bwd_args[i].args)); // ... update weights ... CHECK(:ref:`dnnl_memory_get_data_handle `( conv_user_diff_weights_memory, &net_diff_weights)); CHECK(:ref:`dnnl_memory_get_data_handle `( conv_diff_bias_memory, &net_diff_bias)); // ...user updates weights and bias using diff weights and bias... // some user defined func update_weights(conv_user_weights_memory, // conv_bias_memory, // net_diff_weights, net_diff_bias); } CHECK(:ref:`dnnl_stream_wait `(stream)); :ref:`dnnl_stream_destroy `(stream); // clean up nets for (uint32_t i = 0; i < n_fwd; ++i) free_arg_node(&net_fwd_args[i]); for (uint32_t i = 0; i < n_bwd; ++i) free_arg_node(&net_bwd_args[i]); // Cleanup forward CHECK(:ref:`dnnl_primitive_desc_destroy `(pool_pd)); CHECK(:ref:`dnnl_primitive_desc_destroy `(lrn_pd)); CHECK(:ref:`dnnl_primitive_desc_destroy `(relu_pd)); CHECK(:ref:`dnnl_primitive_desc_destroy `(conv_pd)); free(net_src); free(net_dst); :ref:`dnnl_memory_destroy `(conv_user_src_memory); :ref:`dnnl_memory_destroy `(conv_user_weights_memory); :ref:`dnnl_memory_destroy `(conv_user_bias_memory); :ref:`dnnl_memory_destroy `(conv_internal_src_memory); :ref:`dnnl_memory_destroy `(conv_internal_weights_memory); :ref:`dnnl_memory_destroy `(conv_internal_dst_memory); :ref:`dnnl_primitive_destroy `(conv_reorder_src); :ref:`dnnl_primitive_destroy `(conv_reorder_weights); :ref:`dnnl_primitive_destroy `(conv); free(conv_weights); free(conv_bias); :ref:`dnnl_memory_destroy `(relu_dst_memory); :ref:`dnnl_primitive_destroy `(relu); :ref:`dnnl_memory_destroy `(lrn_ws_memory); :ref:`dnnl_memory_destroy `(lrn_dst_memory); :ref:`dnnl_primitive_destroy `(lrn); :ref:`dnnl_memory_destroy `(pool_user_dst_memory); :ref:`dnnl_memory_destroy `(pool_internal_dst_memory); :ref:`dnnl_memory_destroy `(pool_ws_memory); 
:ref:`dnnl_primitive_destroy `(pool_reorder_dst); :ref:`dnnl_primitive_destroy `(pool); // Cleanup backward CHECK(:ref:`dnnl_primitive_desc_destroy `(pool_bwd_pd)); CHECK(:ref:`dnnl_primitive_desc_destroy `(lrn_bwd_pd)); CHECK(:ref:`dnnl_primitive_desc_destroy `(relu_bwd_pd)); CHECK(:ref:`dnnl_primitive_desc_destroy `(conv_bwd_weights_pd)); :ref:`dnnl_memory_destroy `(pool_user_diff_dst_memory); :ref:`dnnl_memory_destroy `(pool_diff_src_memory); :ref:`dnnl_memory_destroy `(pool_internal_diff_dst_memory); :ref:`dnnl_primitive_destroy `(pool_reorder_diff_dst); :ref:`dnnl_primitive_destroy `(pool_bwd); free(net_diff_dst); :ref:`dnnl_memory_destroy `(lrn_diff_src_memory); :ref:`dnnl_primitive_destroy `(lrn_bwd); :ref:`dnnl_memory_destroy `(relu_diff_src_memory); :ref:`dnnl_primitive_destroy `(relu_bwd); :ref:`dnnl_memory_destroy `(conv_user_diff_weights_memory); :ref:`dnnl_memory_destroy `(conv_diff_bias_memory); :ref:`dnnl_memory_destroy `(conv_bwd_internal_src_memory); :ref:`dnnl_primitive_destroy `(conv_bwd_reorder_src); :ref:`dnnl_memory_destroy `(conv_internal_diff_dst_memory); :ref:`dnnl_primitive_destroy `(conv_reorder_diff_dst); :ref:`dnnl_memory_destroy `(conv_internal_diff_weights_memory); :ref:`dnnl_primitive_destroy `(conv_reorder_diff_weights); :ref:`dnnl_primitive_destroy `(conv_bwd_weights); free(conv_diff_bias_buffer); free(conv_user_diff_weights_buffer); :ref:`dnnl_engine_destroy `(engine); } int main(int argc, char **argv) { simple_net(); printf("Example passed on CPU.\n"); return 0; }