Deep Neural Network Library (DNNL)  1.1.3
Performance library for Deep Learning
rnn_training_f32.cpp

This C++ API example demonstrates how to build GNMT model training.

/*******************************************************************************
* Copyright 2018-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstring>
#include <math.h>
#include <numeric>
#include "example_utils.hpp"
using namespace dnnl;
// User input is:
// N0 sequences of length T0
const int N0 = 1 + rand() % 31;
// N1 sequences of length T1
const int N1 = 1 + rand() % 31;
// Assume T0 > T1
const int T0 = 31 + 1 + rand() % 31;
const int T1 = 1 + rand() % 31;
// Memory required to hold it: N0 * T0 + N1 * T1
// However it is possible to have these coming
// as padded chunks in larger memory:
// e.g. (N0 + N1) * T0
// We don't need to compact the data before processing,
// we can address the chunks via sub-memory and
// process the data via two RNN primitives:
// of time lengths T1 and T0 - T1.
// The leftmost primitive will process N0 + N1 subsequences of length T1
// The rightmost primitive will process remaining N0 subsequences
// of T0 - T1 length
const int leftmost_batch = N0 + N1;
const int rightmost_batch = N0;
const int leftmost_seq_length = T1;
const int rightmost_seq_length = T0 - T1;
// Number of channels
const int common_feature_size = 1024;
// RNN primitive characteristics
const int common_n_layers = 1;
const int lstm_n_gates = 4;
void simple_net(engine::kind engine_kind) {
using tag = memory::format_tag;
using dt = memory::data_type;
auto eng = engine(engine_kind, 0);
stream s(eng);
bool is_training = true;
auto fwd_inf_train = is_training ? prop_kind::forward_training
std::vector<primitive> fwd_net;
std::vector<primitive> bwd_net;
// Input tensor holds two batches with different sequence lengths.
// Shorter sequences are padded
memory::dims net_src_dims = {
T0, // time, maximum sequence length
N0 + N1, // n, total batch size
common_feature_size // c, common number of channels
};
// Two RNN primitives for different sequence lengths,
// one unidirectional layer, LSTM-based
memory::dims leftmost_src_layer_dims = {
leftmost_seq_length, // time
leftmost_batch, // n
common_feature_size // c
};
memory::dims rightmost_src_layer_dims = {
rightmost_seq_length, // time
rightmost_batch, // n
common_feature_size // c
};
memory::dims common_weights_layer_dims = {
common_n_layers, // layers
1, // directions
common_feature_size, // input feature size
lstm_n_gates, // gates number
common_feature_size // output feature size
};
memory::dims common_weights_iter_dims = {
common_n_layers, // layers
1, // directions
common_feature_size, // input feature size
lstm_n_gates, // gates number
common_feature_size // output feature size
};
memory::dims common_bias_dims = {
common_n_layers, // layers
1, // directions
lstm_n_gates, // gates number
common_feature_size // output feature size
};
memory::dims leftmost_dst_layer_dims = {
leftmost_seq_length, // time
leftmost_batch, // n
common_feature_size // c
};
memory::dims rightmost_dst_layer_dims = {
rightmost_seq_length, // time
rightmost_batch, // n
common_feature_size // c
};
// leftmost primitive passes its states to the next RNN iteration
// so it needs dst_iter parameter.
//
// rightmost primitive will consume these as src_iter and will access the
// memory via a sub-memory because it will have different batch dimension.
// We have arranged our primitives so that
// leftmost_batch >= rightmost_batch, and so the rightmost data will fit
// into the memory allocated for the leftmost.
memory::dims leftmost_dst_iter_dims = {
common_n_layers, // layers
1, // directions
leftmost_batch, // n
common_feature_size // c
};
memory::dims leftmost_dst_iter_c_dims = {
common_n_layers, // layers
1, // directions
leftmost_batch, // n
common_feature_size // c
};
memory::dims rightmost_src_iter_dims = {
common_n_layers, // layers
1, // directions
rightmost_batch, // n
common_feature_size // c
};
memory::dims rightmost_src_iter_c_dims = {
common_n_layers, // layers
1, // directions
rightmost_batch, // n
common_feature_size // c
};
// multiplication of tensor dimensions
auto tz_volume = [=](memory::dims tz_dims) {
return std::accumulate(tz_dims.begin(), tz_dims.end(), (memory::dim)1,
std::multiplies<memory::dim>());
};
// Create auxillary f32 memory descriptor
// based on user- supplied dimensions and layout.
auto formatted_md
= [=](memory::dims dimensions, memory::format_tag layout) {
return memory::desc {{dimensions}, dt::f32, layout};
};
// Create auxillary generic f32 memory descriptor
// based on supplied dimensions, with format_tag::any.
auto generic_md = [=](memory::dims dimensions) {
return formatted_md(dimensions, tag::any);
};
//
// I/O memory, coming from user
//
// Net input
std::vector<float> net_src(tz_volume(net_src_dims), 1.0f);
// NOTE: in this example we study input sequences with variable batch
// dimension, which get processed by two separate RNN primitives, thus
// the destination memory for the two will have different shapes: batch
// is the second dimension currently: see format_tag::tnc.
// We are not copying the output to some common user provided memory as we
// suggest that the user should rather keep the two output memories separate
// throughout the whole topology and only reorder to something else as
// needed.
// So there's no common net_dst, but there are two destinations instead:
// leftmost_dst_layer_memory
// rightmost_dst_layer_memory
// Memory for the user allocated memory
// Suppose user data is in tnc format.
auto net_src_memory
= dnnl::memory({{net_src_dims}, dt::f32, tag::tnc}, eng);
write_to_dnnl_memory(net_src.data(), net_src_memory);
// src_layer memory of the leftmost and rightmost RNN primitives
// are accessed through the respective sub-memories in larger memory.
// View primitives compute the strides to accommodate for padding.
auto user_leftmost_src_layer_md = net_src_memory.get_desc().submemory_desc(
leftmost_src_layer_dims, {0, 0, 0}); // t, n, c offsets
auto user_rightmost_src_layer_md
= net_src_memory.get_desc().submemory_desc(rightmost_src_layer_dims,
{leftmost_seq_length, 0, 0}); // t, n, c offsets
auto leftmost_src_layer_memory = net_src_memory;
auto rightmost_src_layer_memory = net_src_memory;
// Other user provided memory arrays, descriptors and primitives with the
// data layouts chosen by user. We'll have to reorder if RNN
// primitive prefers it in a different format.
std::vector<float> user_common_weights_layer(
tz_volume(common_weights_layer_dims), 1.0f);
auto user_common_weights_layer_memory = dnnl::memory(
{common_weights_layer_dims, dt::f32, tag::ldigo}, eng);
write_to_dnnl_memory(
user_common_weights_layer.data(), user_common_weights_layer_memory);
std::vector<float> user_common_weights_iter(
tz_volume(common_weights_iter_dims), 1.0f);
auto user_common_weights_iter_memory = dnnl::memory(
{{common_weights_iter_dims}, dt::f32, tag::ldigo}, eng);
write_to_dnnl_memory(
user_common_weights_layer.data(), user_common_weights_iter_memory);
std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f);
auto user_common_bias_memory
= dnnl::memory({{common_bias_dims}, dt::f32, tag::ldgo}, eng);
write_to_dnnl_memory(user_common_bias.data(), user_common_bias_memory);
std::vector<float> user_leftmost_dst_layer(
tz_volume(leftmost_dst_layer_dims), 1.0f);
auto user_leftmost_dst_layer_memory
= dnnl::memory({{leftmost_dst_layer_dims}, dt::f32, tag::tnc}, eng);
write_to_dnnl_memory(
user_leftmost_dst_layer.data(), user_leftmost_dst_layer_memory);
std::vector<float> user_rightmost_dst_layer(
tz_volume(rightmost_dst_layer_dims), 1.0f);
auto user_rightmost_dst_layer_memory = dnnl::memory(
{{rightmost_dst_layer_dims}, dt::f32, tag::tnc}, eng);
write_to_dnnl_memory(
user_rightmost_dst_layer.data(), user_rightmost_dst_layer_memory);
// Describe layer, forward pass, leftmost primitive.
// There are no primitives to the left from here,
// so src_iter_desc needs to be zero memory desc
lstm_forward::desc leftmost_layer_desc(fwd_inf_train, // aprop_kind
rnn_direction::unidirectional_left2right, // direction
user_leftmost_src_layer_md, // src_layer_desc
memory::desc(), // src_iter_desc
memory::desc(), // src_iter_c_desc
generic_md(common_weights_layer_dims), // weights_layer_desc
generic_md(common_weights_iter_dims), // weights_iter_desc
generic_md(common_bias_dims), // bias_desc
formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc
generic_md(leftmost_dst_iter_dims), // dst_iter_desc
generic_md(leftmost_dst_iter_c_dims) // dst_iter_c_desc
);
// Describe primitive
auto leftmost_prim_desc
= dnnl::lstm_forward::primitive_desc(leftmost_layer_desc, eng);
//
// Need to connect leftmost and rightmost via "iter" parameters.
// We allocate memory here based on the shapes provided by RNN primitive.
//
auto leftmost_dst_iter_memory
= dnnl::memory(leftmost_prim_desc.dst_iter_desc(), eng);
auto leftmost_dst_iter_c_memory
= dnnl::memory(leftmost_prim_desc.dst_iter_c_desc(), eng);
// rightmost src_iter will be a sub-memory of dst_iter of leftmost
auto rightmost_src_iter_md
= leftmost_dst_iter_memory.get_desc().submemory_desc(
rightmost_src_iter_dims,
{0, 0, 0, 0}); // l, d, n, c offsets
auto rightmost_src_iter_memory = leftmost_dst_iter_memory;
auto rightmost_src_iter_c_md
= leftmost_dst_iter_c_memory.get_desc().submemory_desc(
rightmost_src_iter_c_dims,
{0, 0, 0, 0}); // l, d, n, c offsets
auto rightmost_src_iter_c_memory = leftmost_dst_iter_c_memory;
// Now rightmost primitive
// There are no primitives to the right from here,
// so dst_iter_desc is explicit zero memory desc
lstm_forward::desc rightmost_layer_desc(fwd_inf_train, // aprop_kind
rnn_direction::unidirectional_left2right, // direction
user_rightmost_src_layer_md, // src_layer_desc
rightmost_src_iter_md, // src_iter_desc
rightmost_src_iter_c_md, // src_iter_c_desc
generic_md(common_weights_layer_dims), // weights_layer_desc
generic_md(common_weights_iter_dims), // weights_iter_desc
generic_md(common_bias_dims), // bias_desc
formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc
memory::desc(), // dst_iter_desc
memory::desc() // dst_iter_c_desc
);
auto rightmost_prim_desc
= lstm_forward::primitive_desc(rightmost_layer_desc, eng);
//
// Weights and biases, layer memory
// Same layout should work across the layer, no reorders
// needed between leftmost and rigthmost, only reordering
// user memory to the RNN-friendly shapes.
//
auto common_weights_layer_memory = user_common_weights_layer_memory;
if (leftmost_prim_desc.weights_layer_desc()
!= common_weights_layer_memory.get_desc()) {
common_weights_layer_memory
= dnnl::memory(leftmost_prim_desc.weights_layer_desc(), eng);
reorder(user_common_weights_layer_memory, common_weights_layer_memory)
.execute(s, user_common_weights_layer_memory,
common_weights_layer_memory);
}
auto common_weights_iter_memory = user_common_weights_iter_memory;
if (leftmost_prim_desc.weights_iter_desc()
!= common_weights_iter_memory.get_desc()) {
common_weights_iter_memory
= dnnl::memory(leftmost_prim_desc.weights_iter_desc(), eng);
reorder(user_common_weights_iter_memory, common_weights_iter_memory)
.execute(s, user_common_weights_iter_memory,
common_weights_iter_memory);
}
auto common_bias_memory = user_common_bias_memory;
if (leftmost_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
common_bias_memory = dnnl::memory(leftmost_prim_desc.bias_desc(), eng);
reorder(user_common_bias_memory, common_bias_memory)
.execute(s, user_common_bias_memory, common_bias_memory);
}
//
// Destination layer memory
//
auto leftmost_dst_layer_memory = user_leftmost_dst_layer_memory;
if (leftmost_prim_desc.dst_layer_desc()
!= leftmost_dst_layer_memory.get_desc()) {
leftmost_dst_layer_memory
= dnnl::memory(leftmost_prim_desc.dst_layer_desc(), eng);
reorder(user_leftmost_dst_layer_memory, leftmost_dst_layer_memory)
.execute(s, user_leftmost_dst_layer_memory,
leftmost_dst_layer_memory);
}
auto rightmost_dst_layer_memory = user_rightmost_dst_layer_memory;
if (rightmost_prim_desc.dst_layer_desc()
!= rightmost_dst_layer_memory.get_desc()) {
rightmost_dst_layer_memory
= dnnl::memory(rightmost_prim_desc.dst_layer_desc(), eng);
reorder(user_rightmost_dst_layer_memory, rightmost_dst_layer_memory)
.execute(s, user_rightmost_dst_layer_memory,
rightmost_dst_layer_memory);
}
// We also create workspace memory based on the information from
// the workspace_primitive_desc(). This is needed for internal
// communication between forward and backward primitives during
// training.
auto create_ws = [=](dnnl::lstm_forward::primitive_desc &pd) {
return dnnl::memory(pd.workspace_desc(), eng);
};
auto leftmost_workspace_memory = create_ws(leftmost_prim_desc);
auto rightmost_workspace_memory = create_ws(rightmost_prim_desc);
// Construct the RNN primitive objects
lstm_forward leftmost_layer(leftmost_prim_desc);
leftmost_layer.execute(s,
{{DNNL_ARG_SRC_LAYER, leftmost_src_layer_memory},
{DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory},
{DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory},
{DNNL_ARG_BIAS, common_bias_memory},
{DNNL_ARG_DST_LAYER, leftmost_dst_layer_memory},
{DNNL_ARG_DST_ITER, leftmost_dst_iter_memory},
{DNNL_ARG_DST_ITER_C, leftmost_dst_iter_c_memory},
{DNNL_ARG_WORKSPACE, leftmost_workspace_memory}});
lstm_forward rightmost_layer(rightmost_prim_desc);
rightmost_layer.execute(s,
{{DNNL_ARG_SRC_LAYER, rightmost_src_layer_memory},
{DNNL_ARG_SRC_ITER, rightmost_src_iter_memory},
{DNNL_ARG_SRC_ITER_C, rightmost_src_iter_c_memory},
{DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory},
{DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory},
{DNNL_ARG_BIAS, common_bias_memory},
{DNNL_ARG_DST_LAYER, rightmost_dst_layer_memory},
{DNNL_ARG_WORKSPACE, rightmost_workspace_memory}});
// No backward pass for inference
if (!is_training) return;
//
// Backward primitives will reuse memory from forward
// and allocate/describe specifics here. Only relevant for training.
//
// User-provided memory for backward by data output
std::vector<float> net_diff_src(tz_volume(net_src_dims), 1.0f);
auto net_diff_src_memory
= dnnl::memory(formatted_md(net_src_dims, tag::tnc), eng);
write_to_dnnl_memory(net_diff_src.data(), net_diff_src_memory);
// diff_src follows the same layout we have for net_src
auto user_leftmost_diff_src_layer_md
= net_diff_src_memory.get_desc().submemory_desc(
leftmost_src_layer_dims, {0, 0, 0}); // t, n, c offsets
auto user_rightmost_diff_src_layer_md
= net_diff_src_memory.get_desc().submemory_desc(
rightmost_src_layer_dims,
{leftmost_seq_length, 0, 0}); // t, n, c offsets
auto leftmost_diff_src_layer_memory = net_diff_src_memory;
auto rightmost_diff_src_layer_memory = net_diff_src_memory;
// User-provided memory for backpropagation by weights
std::vector<float> user_common_diff_weights_layer(
tz_volume(common_weights_layer_dims), 1.0f);
auto user_common_diff_weights_layer_memory = dnnl::memory(
formatted_md(common_weights_layer_dims, tag::ldigo), eng);
write_to_dnnl_memory(user_common_diff_weights_layer.data(),
user_common_diff_weights_layer_memory);
std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f);
auto user_common_diff_bias_memory
= dnnl::memory(formatted_md(common_bias_dims, tag::ldgo), eng);
write_to_dnnl_memory(
user_common_diff_bias.data(), user_common_diff_bias_memory);
// User-provided input to the backward primitive.
// To be updated by the user after forward pass using some cost function.
memory::dims net_diff_dst_dims = {
T0, // time
N0 + N1, // n
common_feature_size // c
};
// Suppose user data is in tnc format.
std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f);
auto net_diff_dst_memory
= dnnl::memory(formatted_md(net_diff_dst_dims, tag::tnc), eng);
write_to_dnnl_memory(net_diff_dst.data(), net_diff_dst_memory);
// diff_dst_layer memory of the leftmost and rightmost RNN primitives
// are accessed through the respective sub-memory in larger memory.
// View primitives compute the strides to accommodate for padding.
auto user_leftmost_diff_dst_layer_md
= net_diff_dst_memory.get_desc().submemory_desc(
leftmost_dst_layer_dims, {0, 0, 0}); // t, n, c offsets
auto user_rightmost_diff_dst_layer_md
= net_diff_dst_memory.get_desc().submemory_desc(
rightmost_dst_layer_dims,
{leftmost_seq_length, 0, 0}); // t, n, c offsets
auto leftmost_diff_dst_layer_memory = net_diff_dst_memory;
auto rightmost_diff_dst_layer_memory = net_diff_dst_memory;
// Backward leftmost primitive descriptor
lstm_backward::desc leftmost_layer_bwd_desc(
prop_kind::backward, // aprop_kind
rnn_direction::unidirectional_left2right, // direction
user_leftmost_src_layer_md, // src_layer_desc
memory::desc(), // src_iter_desc
memory::desc(), // src_iter_c_desc
generic_md(common_weights_layer_dims), // weights_layer_desc
generic_md(common_weights_iter_dims), // weights_iter_desc
generic_md(common_bias_dims), // bias_desc
formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc
generic_md(leftmost_dst_iter_dims), // dst_iter_desc
generic_md(leftmost_dst_iter_c_dims), // dst_iter_c_desc
user_leftmost_diff_src_layer_md, // diff_src_layer_desc
memory::desc(), // diff_src_iter_desc
memory::desc(), // diff_src_iter_c_desc
generic_md(common_weights_layer_dims), // diff_weights_layer_desc
generic_md(common_weights_iter_dims), // diff_weights_iter_desc
generic_md(common_bias_dims), // diff_bias_desc
user_leftmost_diff_dst_layer_md, // diff_dst_layer_desc
generic_md(leftmost_dst_iter_dims), // diff_dst_iter_desc
generic_md(leftmost_dst_iter_c_dims) // diff_dst_iter_c_desc
);
auto leftmost_bwd_prim_desc = lstm_backward::primitive_desc(
leftmost_layer_bwd_desc, eng, leftmost_prim_desc);
// As the batch dimensions are different between leftmost and rightmost
// we need to use a sub-memory. rightmost needs less memory, so it will
// be a sub-memory of leftmost.
auto leftmost_diff_dst_iter_memory
= dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_desc(), eng);
auto leftmost_diff_dst_iter_c_memory
= dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_c_desc(), eng);
auto rightmost_diff_src_iter_md
= leftmost_diff_dst_iter_memory.get_desc().submemory_desc(
rightmost_src_iter_dims,
{0, 0, 0, 0}); // l, d, n, c offsets
auto rightmost_diff_src_iter_memory = leftmost_diff_dst_iter_memory;
auto rightmost_diff_src_iter_c_md
= leftmost_diff_dst_iter_c_memory.get_desc().submemory_desc(
rightmost_src_iter_c_dims,
{0, 0, 0, 0}); // l, d, n, c offsets
auto rightmost_diff_src_iter_c_memory = leftmost_diff_dst_iter_c_memory;
// Backward rightmost primitive descriptor
lstm_backward::desc rightmost_layer_bwd_desc(
prop_kind::backward, // aprop_kind
rnn_direction::unidirectional_left2right, // direction
user_rightmost_src_layer_md, // src_layer_desc
generic_md(rightmost_src_iter_dims), // src_iter_desc
generic_md(rightmost_src_iter_c_dims), // src_iter_c_desc
generic_md(common_weights_layer_dims), // weights_layer_desc
generic_md(common_weights_iter_dims), // weights_iter_desc
generic_md(common_bias_dims), // bias_desc
formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc
memory::desc(), // dst_iter_desc
memory::desc(), // dst_iter_c_desc
user_rightmost_diff_src_layer_md, // diff_src_layer_desc
rightmost_diff_src_iter_md, // diff_src_iter_desc
rightmost_diff_src_iter_c_md, // diff_src_iter_c_desc
generic_md(common_weights_layer_dims), // diff_weights_layer_desc
generic_md(common_weights_iter_dims), // diff_weights_iter_desc
generic_md(common_bias_dims), // diff_bias_desc
user_rightmost_diff_dst_layer_md, // diff_dst_layer_desc
memory::desc(), // diff_dst_iter_desc
memory::desc() // diff_dst_iter_c_desc
);
auto rightmost_bwd_prim_desc = lstm_backward::primitive_desc(
rightmost_layer_bwd_desc, eng, rightmost_prim_desc);
//
// Memory for backward pass
//
// src layer uses the same memory as forward
auto leftmost_src_layer_bwd_memory = leftmost_src_layer_memory;
auto rightmost_src_layer_bwd_memory = rightmost_src_layer_memory;
// Memory for weights and biases for backward pass
// Try to use the same memory between forward and backward, but
// sometimes reorders are needed.
auto common_weights_layer_bwd_memory = common_weights_layer_memory;
if (leftmost_bwd_prim_desc.weights_layer_desc()
!= leftmost_prim_desc.weights_layer_desc()) {
common_weights_layer_bwd_memory
= memory(leftmost_bwd_prim_desc.weights_layer_desc(), eng);
reorder(common_weights_layer_memory, common_weights_layer_bwd_memory)
.execute(s, common_weights_layer_memory,
common_weights_layer_bwd_memory);
}
auto common_weights_iter_bwd_memory = common_weights_iter_memory;
if (leftmost_bwd_prim_desc.weights_iter_desc()
!= leftmost_prim_desc.weights_iter_desc()) {
common_weights_iter_bwd_memory
= memory(leftmost_bwd_prim_desc.weights_iter_desc(), eng);
reorder(common_weights_iter_memory, common_weights_iter_bwd_memory)
.execute(s, common_weights_iter_memory,
common_weights_iter_bwd_memory);
}
auto common_bias_bwd_memory = common_bias_memory;
if (leftmost_bwd_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
common_bias_bwd_memory
= dnnl::memory(leftmost_bwd_prim_desc.bias_desc(), eng);
reorder(common_bias_memory, common_bias_bwd_memory)
.execute(s, common_bias_memory, common_bias_bwd_memory);
}
// diff_weights and biases
auto common_diff_weights_layer_memory
= user_common_diff_weights_layer_memory;
auto reorder_common_diff_weights_layer = false;
if (leftmost_bwd_prim_desc.diff_weights_layer_desc()
!= common_diff_weights_layer_memory.get_desc()) {
common_diff_weights_layer_memory = dnnl::memory(
leftmost_bwd_prim_desc.diff_weights_layer_desc(), eng);
reorder_common_diff_weights_layer = true;
}
auto common_diff_bias_memory = user_common_diff_bias_memory;
auto reorder_common_diff_bias = false;
if (leftmost_bwd_prim_desc.diff_bias_desc()
!= common_diff_bias_memory.get_desc()) {
common_diff_bias_memory
= dnnl::memory(leftmost_bwd_prim_desc.diff_bias_desc(), eng);
reorder_common_diff_bias = true;
}
// dst_layer memory for backward pass
auto leftmost_dst_layer_bwd_memory = leftmost_dst_layer_memory;
if (leftmost_bwd_prim_desc.dst_layer_desc()
!= leftmost_dst_layer_bwd_memory.get_desc()) {
leftmost_dst_layer_bwd_memory
= dnnl::memory(leftmost_bwd_prim_desc.dst_layer_desc(), eng);
reorder(leftmost_dst_layer_memory, leftmost_dst_layer_bwd_memory)
.execute(s, leftmost_dst_layer_memory,
leftmost_dst_layer_bwd_memory);
}
auto rightmost_dst_layer_bwd_memory = rightmost_dst_layer_memory;
if (rightmost_bwd_prim_desc.dst_layer_desc()
!= rightmost_dst_layer_bwd_memory.get_desc()) {
rightmost_dst_layer_bwd_memory
= dnnl::memory(rightmost_bwd_prim_desc.dst_layer_desc(), eng);
reorder(rightmost_dst_layer_memory, rightmost_dst_layer_bwd_memory)
.execute(s, rightmost_dst_layer_memory,
rightmost_dst_layer_bwd_memory);
}
// Similar to forward, the backward primitives are connected
// via "iter" parameters.
auto common_diff_weights_iter_memory = dnnl::memory(
leftmost_bwd_prim_desc.diff_weights_iter_desc(), eng);
auto leftmost_dst_iter_bwd_memory = leftmost_dst_iter_memory;
if (leftmost_bwd_prim_desc.dst_iter_desc()
!= leftmost_dst_iter_bwd_memory.get_desc()) {
leftmost_dst_iter_bwd_memory
= dnnl::memory(leftmost_bwd_prim_desc.dst_iter_desc(), eng);
reorder(leftmost_dst_iter_memory, leftmost_dst_iter_bwd_memory)
.execute(s, leftmost_dst_iter_memory,
leftmost_dst_iter_bwd_memory);
}
auto leftmost_dst_iter_c_bwd_memory = leftmost_dst_iter_c_memory;
if (leftmost_bwd_prim_desc.dst_iter_c_desc()
!= leftmost_dst_iter_c_bwd_memory.get_desc()) {
leftmost_dst_iter_c_bwd_memory
= dnnl::memory(leftmost_bwd_prim_desc.dst_iter_c_desc(), eng);
reorder(leftmost_dst_iter_c_memory, leftmost_dst_iter_c_bwd_memory)
.execute(s, leftmost_dst_iter_c_memory,
leftmost_dst_iter_c_bwd_memory);
}
// Construct the RNN primitive objects for backward
lstm_backward rightmost_layer_bwd(rightmost_bwd_prim_desc);
rightmost_layer_bwd.execute(s,
{{DNNL_ARG_SRC_LAYER, rightmost_src_layer_bwd_memory},
{DNNL_ARG_SRC_ITER, rightmost_src_iter_memory},
{DNNL_ARG_SRC_ITER_C, rightmost_src_iter_c_memory},
{DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory},
{DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory},
{DNNL_ARG_BIAS, common_bias_bwd_memory},
{DNNL_ARG_DST_LAYER, rightmost_dst_layer_bwd_memory},
{DNNL_ARG_DIFF_SRC_LAYER, rightmost_diff_src_layer_memory},
{DNNL_ARG_DIFF_SRC_ITER, rightmost_diff_src_iter_memory},
{DNNL_ARG_DIFF_SRC_ITER_C,
rightmost_diff_src_iter_c_memory},
{DNNL_ARG_DIFF_WEIGHTS_LAYER,
common_diff_weights_layer_memory},
{DNNL_ARG_DIFF_WEIGHTS_ITER,
common_diff_weights_iter_memory},
{DNNL_ARG_DIFF_BIAS, common_diff_bias_memory},
{DNNL_ARG_DIFF_DST_LAYER, rightmost_diff_dst_layer_memory},
{DNNL_ARG_WORKSPACE, rightmost_workspace_memory}});
lstm_backward leftmost_layer_bwd(leftmost_bwd_prim_desc);
leftmost_layer_bwd.execute(s,
{{DNNL_ARG_SRC_LAYER, leftmost_src_layer_bwd_memory},
{DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory},
{DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory},
{DNNL_ARG_BIAS, common_bias_bwd_memory},
{DNNL_ARG_DST_LAYER, leftmost_dst_layer_bwd_memory},
{DNNL_ARG_DST_ITER, leftmost_dst_iter_bwd_memory},
{DNNL_ARG_DST_ITER_C, leftmost_dst_iter_c_bwd_memory},
{DNNL_ARG_DIFF_SRC_LAYER, leftmost_diff_src_layer_memory},
{DNNL_ARG_DIFF_WEIGHTS_LAYER,
common_diff_weights_layer_memory},
{DNNL_ARG_DIFF_WEIGHTS_ITER,
common_diff_weights_iter_memory},
{DNNL_ARG_DIFF_BIAS, common_diff_bias_memory},
{DNNL_ARG_DIFF_DST_LAYER, leftmost_diff_dst_layer_memory},
{DNNL_ARG_DIFF_DST_ITER, leftmost_diff_dst_iter_memory},
{DNNL_ARG_DIFF_DST_ITER_C, leftmost_diff_dst_iter_c_memory},
{DNNL_ARG_WORKSPACE, leftmost_workspace_memory}});
if (reorder_common_diff_weights_layer) {
reorder(common_diff_weights_layer_memory,
user_common_diff_weights_layer_memory)
.execute(s, common_diff_weights_layer_memory,
user_common_diff_weights_layer_memory);
}
if (reorder_common_diff_bias) {
reorder(common_diff_bias_memory, user_common_diff_bias_memory)
.execute(s, common_diff_bias_memory,
user_common_diff_bias_memory);
}
//
// User updates weights and bias using diffs
//
s.wait();
}
int main(int argc, char **argv) {
try {
simple_net(parse_engine_kind(argc, argv));
std::cout << "Simple net f32 training example passed!\n";
} catch (error &e) {
std::cerr << "status: " << e.status << std::endl;
std::cerr << "message: " << e.message << std::endl;
return 1;
}
return 0;
}