rnn_training_f32.cpp

This C++ API example demonstrates how to build the training pass of a GNMT model. Annotated version: RNN f32 training example

/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/


#include <cstring>
#include <math.h>
#include <numeric>
#include <utility>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

// Problem sizes shared by the whole example.
//
// The sizes are drawn with rand() during static initialization. No srand()
// call is visible in this file, so the sequence is presumably the default
// one and the sizes are the same on every run on a given platform.
//
// User input is:
//     N0 sequences of length T0 (N0 is in [1, 31])
const int N0 = 1 + rand() % 31;
//     N1 sequences of length T1 (N1 is in [1, 31])
const int N1 = 1 + rand() % 31;
// Assume T0 > T1. The ranges guarantee it: T0 is in [32, 62] while T1 is in
// [1, 31], so T0 - T1 is always at least 1.
const int T0 = 31 + 1 + rand() % 31;
const int T1 = 1 + rand() % 31;

// Memory required to hold it: N0 * T0 + N1 * T1
// However it is possible to have these coming
// as padded chunks in larger memory:
//      e.g. (N0 + N1) * T0
// We don't need to compact the data before processing,
// we can address the chunks via sub-memory and
// process the data via two RNN primitives:
//     of time lengths T1 and T0 - T1.
// The leftmost primitive will process N0 + N1 subsequences of length T1
// The rightmost primitive will process remaining N0 subsequences
// of T0 - T1 length
const int leftmost_batch = N0 + N1;
const int rightmost_batch = N0;

const int leftmost_seq_length = T1;
const int rightmost_seq_length = T0 - T1;

// Number of channels (feature size used for both input and output of the
// LSTM cell in this example)
const int common_feature_size = 1024;

// RNN primitive characteristics
// A single layer is used; the gates dimension of the weights and bias
// tensors is lstm_n_gates.
const int common_n_layers = 1;
const int lstm_n_gates = 4;

// Builds and runs one forward and one backward (training) pass of a
// single-layer, unidirectional LSTM over a padded batch that holds sequences
// of two different lengths.
//
// The padded input (T0 x (N0 + N1) x C, tnc layout) is split along the time
// axis into two sub-memories that are processed by two chained primitives:
//   - "leftmost":  the first T1 time steps, batch N0 + N1;
//   - "rightmost": the remaining T0 - T1 time steps, batch N0.
// In the forward pass the leftmost primitive hands its final states to the
// rightmost one through the dst_iter / dst_iter_c memories. The backward pass
// executes the primitives in the opposite order (rightmost first) and
// connects them through the diff_src_iter / diff_dst_iter memories.
//
// engine_kind selects the kind of engine (device index 0) on which all
// memories and primitives are created and executed.
void simple_net(engine::kind engine_kind) {
    using tag = memory::format_tag;
    using dt = memory::data_type;

    auto eng = engine(engine_kind, 0);
    stream s(eng);

    bool is_training = true;
    auto fwd_inf_train = is_training ? prop_kind::forward_training
                                     : prop_kind::forward_inference;

    // Input tensor holds two batches with different sequence lengths.
    // Shorter sequences are padded
    memory::dims net_src_dims = {
            T0, // time, maximum sequence length
            N0 + N1, // n, total batch size
            common_feature_size // c, common number of channels
    };

    // Two RNN primitives for different sequence lengths,
    // one unidirectional layer, LSTM-based
    memory::dims leftmost_src_layer_dims = {
            leftmost_seq_length, // time
            leftmost_batch, // n
            common_feature_size // c
    };
    memory::dims rightmost_src_layer_dims = {
            rightmost_seq_length, // time
            rightmost_batch, // n
            common_feature_size // c
    };
    memory::dims common_weights_layer_dims = {
            common_n_layers, // layers
            1, // directions
            common_feature_size, // input feature size
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims common_weights_iter_dims = {
            common_n_layers, // layers
            1, // directions
            common_feature_size, // input feature size
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims common_bias_dims = {
            common_n_layers, // layers
            1, // directions
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims leftmost_dst_layer_dims = {
            leftmost_seq_length, // time
            leftmost_batch, // n
            common_feature_size // c
    };
    memory::dims rightmost_dst_layer_dims = {
            rightmost_seq_length, // time
            rightmost_batch, // n
            common_feature_size // c
    };

    // leftmost primitive passes its states to the next RNN iteration
    // so it needs dst_iter parameter.
    //
    // rightmost primitive will consume these as src_iter and will access the
    // memory via a sub-memory because it will have different batch dimension.
    // We have arranged our primitives so that
    // leftmost_batch >= rightmost_batch, and so the rightmost data will fit
    // into the memory allocated for the leftmost.
    memory::dims leftmost_dst_iter_dims = {
            common_n_layers, // layers
            1, // directions
            leftmost_batch, // n
            common_feature_size // c
    };
    memory::dims leftmost_dst_iter_c_dims = {
            common_n_layers, // layers
            1, // directions
            leftmost_batch, // n
            common_feature_size // c
    };
    memory::dims rightmost_src_iter_dims = {
            common_n_layers, // layers
            1, // directions
            rightmost_batch, // n
            common_feature_size // c
    };
    memory::dims rightmost_src_iter_c_dims = {
            common_n_layers, // layers
            1, // directions
            rightmost_batch, // n
            common_feature_size // c
    };

    // Number of elements in a tensor: the product of its dimensions.
    // The dims are taken by const reference to avoid copying the vector.
    auto tz_volume = [](const memory::dims &tz_dims) {
        return std::accumulate(tz_dims.begin(), tz_dims.end(), (memory::dim)1,
                std::multiplies<memory::dim>());
    };

    // Create auxiliary f32 memory descriptor
    // based on user-supplied dimensions and layout.
    auto formatted_md
            = [=](const memory::dims &dimensions, memory::format_tag layout) {
                  return memory::desc {{dimensions}, dt::f32, layout};
              };
    // Create auxiliary generic f32 memory descriptor
    // based on supplied dimensions, with format_tag::any.
    auto generic_md = [=](const memory::dims &dimensions) {
        return formatted_md(dimensions, tag::any);
    };

    //
    // I/O memory, coming from user
    //

    // Net input
    std::vector<float> net_src(tz_volume(net_src_dims), 1.0f);
    // NOTE: in this example we study input sequences with variable batch
    // dimension, which get processed by two separate RNN primitives, thus
    // the destination memory for the two will have different shapes: batch
    // is the second dimension currently: see format_tag::tnc.
    // We are not copying the output to some common user provided memory as we
    // suggest that the user should rather keep the two output memories separate
    // throughout the whole topology and only reorder to something else as
    // needed.
    // So there's no common net_dst, but there are two destinations instead:
    //    leftmost_dst_layer_memory
    //    rightmost_dst_layer_memory

    // Memory for the user allocated memory
    // Suppose user data is in tnc format.
    auto net_src_memory
            = dnnl::memory({{net_src_dims}, dt::f32, tag::tnc}, eng);
    write_to_dnnl_memory(net_src.data(), net_src_memory);
    // src_layer memory of the leftmost and rightmost RNN primitives
    // are accessed through the respective sub-memories in larger memory.
    // The sub-memory descriptors are derived from the descriptor of the
    // padded parent memory; the rightmost one starts at time offset
    // leftmost_seq_length.
    auto user_leftmost_src_layer_md = net_src_memory.get_desc().submemory_desc(
            leftmost_src_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto user_rightmost_src_layer_md
            = net_src_memory.get_desc().submemory_desc(rightmost_src_layer_dims,
                    {leftmost_seq_length, 0, 0}); // t, n, c offsets
    // Both primitives are handed the same parent memory object; the
    // sub-memory descriptors given to the primitive descriptors select the
    // part each primitive works on.
    auto leftmost_src_layer_memory = net_src_memory;
    auto rightmost_src_layer_memory = net_src_memory;

    // Other user provided memory arrays, descriptors and primitives with the
    // data layouts chosen by user. We'll have to reorder if RNN
    // primitive prefers it in a different format.
    std::vector<float> user_common_weights_layer(
            tz_volume(common_weights_layer_dims), 1.0f);
    auto user_common_weights_layer_memory = dnnl::memory(
            {common_weights_layer_dims, dt::f32, tag::ldigo}, eng);
    write_to_dnnl_memory(
            user_common_weights_layer.data(), user_common_weights_layer_memory);

    std::vector<float> user_common_weights_iter(
            tz_volume(common_weights_iter_dims), 1.0f);
    auto user_common_weights_iter_memory = dnnl::memory(
            {{common_weights_iter_dims}, dt::f32, tag::ldigo}, eng);
    // The iter weights memory must be filled from the iter weights buffer
    // (not from the layer weights buffer).
    write_to_dnnl_memory(
            user_common_weights_iter.data(), user_common_weights_iter_memory);

    std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f);
    auto user_common_bias_memory
            = dnnl::memory({{common_bias_dims}, dt::f32, tag::ldgo}, eng);
    write_to_dnnl_memory(user_common_bias.data(), user_common_bias_memory);

    std::vector<float> user_leftmost_dst_layer(
            tz_volume(leftmost_dst_layer_dims), 1.0f);
    auto user_leftmost_dst_layer_memory
            = dnnl::memory({{leftmost_dst_layer_dims}, dt::f32, tag::tnc}, eng);
    write_to_dnnl_memory(
            user_leftmost_dst_layer.data(), user_leftmost_dst_layer_memory);

    std::vector<float> user_rightmost_dst_layer(
            tz_volume(rightmost_dst_layer_dims), 1.0f);
    auto user_rightmost_dst_layer_memory = dnnl::memory(
            {{rightmost_dst_layer_dims}, dt::f32, tag::tnc}, eng);
    write_to_dnnl_memory(
            user_rightmost_dst_layer.data(), user_rightmost_dst_layer_memory);

    // Describe layer, forward pass, leftmost primitive.
    // There are no primitives to the left from here,
    // so src_iter_desc needs to be zero memory desc
    auto leftmost_prim_desc = lstm_forward::primitive_desc(eng, // engine
            fwd_inf_train, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_leftmost_src_layer_md, // src_layer_desc
            memory::desc(), // src_iter_desc
            memory::desc(), // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc
            generic_md(leftmost_dst_iter_dims), // dst_iter_desc
            generic_md(leftmost_dst_iter_c_dims) // dst_iter_c_desc
    );

    //
    // Need to connect leftmost and rightmost via "iter" parameters.
    // We allocate memory here based on the shapes provided by RNN primitive.
    //
    auto leftmost_dst_iter_memory
            = dnnl::memory(leftmost_prim_desc.dst_iter_desc(), eng);
    auto leftmost_dst_iter_c_memory
            = dnnl::memory(leftmost_prim_desc.dst_iter_c_desc(), eng);

    // rightmost src_iter will be a sub-memory of dst_iter of leftmost
    auto rightmost_src_iter_md
            = leftmost_dst_iter_memory.get_desc().submemory_desc(
                    rightmost_src_iter_dims,
                    {0, 0, 0, 0}); // l, d, n, c offsets
    auto rightmost_src_iter_memory = leftmost_dst_iter_memory;

    auto rightmost_src_iter_c_md
            = leftmost_dst_iter_c_memory.get_desc().submemory_desc(
                    rightmost_src_iter_c_dims,
                    {0, 0, 0, 0}); // l, d, n, c offsets
    auto rightmost_src_iter_c_memory = leftmost_dst_iter_c_memory;

    // Now rightmost primitive
    // There are no primitives to the right from here,
    // so dst_iter_desc is explicit zero memory desc
    auto rightmost_prim_desc = lstm_forward::primitive_desc(eng, // engine
            fwd_inf_train, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_rightmost_src_layer_md, // src_layer_desc
            rightmost_src_iter_md, // src_iter_desc
            rightmost_src_iter_c_md, // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc
            memory::desc(), // dst_iter_desc
            memory::desc() // dst_iter_c_desc
    );

    //
    // Weights and biases, layer memory
    // Same layout should work across the layer, no reorders
    // needed between leftmost and rightmost, only reordering
    // user memory to the RNN-friendly shapes.
    //

    auto common_weights_layer_memory = user_common_weights_layer_memory;
    if (leftmost_prim_desc.weights_layer_desc()
            != common_weights_layer_memory.get_desc()) {
        common_weights_layer_memory
                = dnnl::memory(leftmost_prim_desc.weights_layer_desc(), eng);
        reorder(user_common_weights_layer_memory, common_weights_layer_memory)
                .execute(s, user_common_weights_layer_memory,
                        common_weights_layer_memory);
    }

    auto common_weights_iter_memory = user_common_weights_iter_memory;
    if (leftmost_prim_desc.weights_iter_desc()
            != common_weights_iter_memory.get_desc()) {
        common_weights_iter_memory
                = dnnl::memory(leftmost_prim_desc.weights_iter_desc(), eng);
        reorder(user_common_weights_iter_memory, common_weights_iter_memory)
                .execute(s, user_common_weights_iter_memory,
                        common_weights_iter_memory);
    }

    auto common_bias_memory = user_common_bias_memory;
    if (leftmost_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
        common_bias_memory = dnnl::memory(leftmost_prim_desc.bias_desc(), eng);
        reorder(user_common_bias_memory, common_bias_memory)
                .execute(s, user_common_bias_memory, common_bias_memory);
    }

    //
    // Destination layer memory
    //

    auto leftmost_dst_layer_memory = user_leftmost_dst_layer_memory;
    if (leftmost_prim_desc.dst_layer_desc()
            != leftmost_dst_layer_memory.get_desc()) {
        leftmost_dst_layer_memory
                = dnnl::memory(leftmost_prim_desc.dst_layer_desc(), eng);
        reorder(user_leftmost_dst_layer_memory, leftmost_dst_layer_memory)
                .execute(s, user_leftmost_dst_layer_memory,
                        leftmost_dst_layer_memory);
    }

    auto rightmost_dst_layer_memory = user_rightmost_dst_layer_memory;
    if (rightmost_prim_desc.dst_layer_desc()
            != rightmost_dst_layer_memory.get_desc()) {
        rightmost_dst_layer_memory
                = dnnl::memory(rightmost_prim_desc.dst_layer_desc(), eng);
        reorder(user_rightmost_dst_layer_memory, rightmost_dst_layer_memory)
                .execute(s, user_rightmost_dst_layer_memory,
                        rightmost_dst_layer_memory);
    }

    // We also create workspace memory based on the information from
    // the workspace_desc() of each forward primitive descriptor. This is
    // needed for internal communication between forward and backward
    // primitives during training.
    auto create_ws = [=](const dnnl::lstm_forward::primitive_desc &pd) {
        return dnnl::memory(pd.workspace_desc(), eng);
    };
    auto leftmost_workspace_memory = create_ws(leftmost_prim_desc);
    auto rightmost_workspace_memory = create_ws(rightmost_prim_desc);

    // Construct the RNN primitive objects and run the forward pass,
    // leftmost first so that its dst_iter states are available to rightmost.
    lstm_forward leftmost_layer(leftmost_prim_desc);
    leftmost_layer.execute(s,
            {{DNNL_ARG_SRC_LAYER, leftmost_src_layer_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory},
                    {DNNL_ARG_BIAS, common_bias_memory},
                    {DNNL_ARG_DST_LAYER, leftmost_dst_layer_memory},
                    {DNNL_ARG_DST_ITER, leftmost_dst_iter_memory},
                    {DNNL_ARG_DST_ITER_C, leftmost_dst_iter_c_memory},
                    {DNNL_ARG_WORKSPACE, leftmost_workspace_memory}});

    lstm_forward rightmost_layer(rightmost_prim_desc);
    rightmost_layer.execute(s,
            {{DNNL_ARG_SRC_LAYER, rightmost_src_layer_memory},
                    {DNNL_ARG_SRC_ITER, rightmost_src_iter_memory},
                    {DNNL_ARG_SRC_ITER_C, rightmost_src_iter_c_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory},
                    {DNNL_ARG_BIAS, common_bias_memory},
                    {DNNL_ARG_DST_LAYER, rightmost_dst_layer_memory},
                    {DNNL_ARG_WORKSPACE, rightmost_workspace_memory}});

    // No backward pass for inference. Wait for the submitted forward
    // primitives before returning, so that the memory objects are not
    // released while work may still be pending on the stream.
    if (!is_training) {
        s.wait();
        return;
    }

    //
    // Backward primitives will reuse memory from forward
    // and allocate/describe specifics here. Only relevant for training.
    //

    // User-provided memory for backward by data output
    std::vector<float> net_diff_src(tz_volume(net_src_dims), 1.0f);
    auto net_diff_src_memory
            = dnnl::memory(formatted_md(net_src_dims, tag::tnc), eng);
    write_to_dnnl_memory(net_diff_src.data(), net_diff_src_memory);

    // diff_src follows the same layout we have for net_src
    auto user_leftmost_diff_src_layer_md
            = net_diff_src_memory.get_desc().submemory_desc(
                    leftmost_src_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto user_rightmost_diff_src_layer_md
            = net_diff_src_memory.get_desc().submemory_desc(
                    rightmost_src_layer_dims,
                    {leftmost_seq_length, 0, 0}); // t, n, c offsets
    auto leftmost_diff_src_layer_memory = net_diff_src_memory;
    auto rightmost_diff_src_layer_memory = net_diff_src_memory;

    // User-provided memory for backpropagation by weights
    std::vector<float> user_common_diff_weights_layer(
            tz_volume(common_weights_layer_dims), 1.0f);
    auto user_common_diff_weights_layer_memory = dnnl::memory(
            formatted_md(common_weights_layer_dims, tag::ldigo), eng);
    write_to_dnnl_memory(user_common_diff_weights_layer.data(),
            user_common_diff_weights_layer_memory);

    std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f);
    auto user_common_diff_bias_memory
            = dnnl::memory(formatted_md(common_bias_dims, tag::ldgo), eng);
    write_to_dnnl_memory(
            user_common_diff_bias.data(), user_common_diff_bias_memory);

    // User-provided input to the backward primitive.
    // To be updated by the user after forward pass using some cost function.
    memory::dims net_diff_dst_dims = {
            T0, // time
            N0 + N1, // n
            common_feature_size // c
    };
    // Suppose user data is in tnc format.
    std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f);
    auto net_diff_dst_memory
            = dnnl::memory(formatted_md(net_diff_dst_dims, tag::tnc), eng);
    write_to_dnnl_memory(net_diff_dst.data(), net_diff_dst_memory);
    // diff_dst_layer memory of the leftmost and rightmost RNN primitives
    // are accessed through the respective sub-memory in larger memory.
    // As for the source, the sub-memory descriptors are derived from the
    // descriptor of the padded parent memory.
    auto user_leftmost_diff_dst_layer_md
            = net_diff_dst_memory.get_desc().submemory_desc(
                    leftmost_dst_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto user_rightmost_diff_dst_layer_md
            = net_diff_dst_memory.get_desc().submemory_desc(
                    rightmost_dst_layer_dims,
                    {leftmost_seq_length, 0, 0}); // t, n, c offsets
    auto leftmost_diff_dst_layer_memory = net_diff_dst_memory;
    auto rightmost_diff_dst_layer_memory = net_diff_dst_memory;

    // Backward leftmost primitive descriptor
    auto leftmost_bwd_prim_desc = lstm_backward::primitive_desc(eng, // engine
            prop_kind::backward, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_leftmost_src_layer_md, // src_layer_desc
            memory::desc(), // src_iter_desc
            memory::desc(), // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(leftmost_dst_layer_dims, tag::tnc), // dst_layer_desc
            generic_md(leftmost_dst_iter_dims), // dst_iter_desc
            generic_md(leftmost_dst_iter_c_dims), // dst_iter_c_desc
            user_leftmost_diff_src_layer_md, // diff_src_layer_desc
            memory::desc(), // diff_src_iter_desc
            memory::desc(), // diff_src_iter_c_desc
            generic_md(common_weights_layer_dims), // diff_weights_layer_desc
            generic_md(common_weights_iter_dims), // diff_weights_iter_desc
            generic_md(common_bias_dims), // diff_bias_desc
            user_leftmost_diff_dst_layer_md, // diff_dst_layer_desc
            generic_md(leftmost_dst_iter_dims), // diff_dst_iter_desc
            generic_md(leftmost_dst_iter_c_dims), // diff_dst_iter_c_desc
            leftmost_prim_desc // hint from forward pass
    );

    // As the batch dimensions are different between leftmost and rightmost
    // we need to use a sub-memory. rightmost needs less memory, so it will
    // be a sub-memory of leftmost.
    // NOTE(review): these two memories are allocated but not initialized
    // here; only the sub-memory covering the first rightmost_batch entries
    // appears to be written (by the rightmost backward primitive) before
    // the leftmost backward primitive reads them - confirm whether the
    // remaining entries should be zeroed.
    auto leftmost_diff_dst_iter_memory
            = dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_desc(), eng);
    auto leftmost_diff_dst_iter_c_memory
            = dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_c_desc(), eng);

    auto rightmost_diff_src_iter_md
            = leftmost_diff_dst_iter_memory.get_desc().submemory_desc(
                    rightmost_src_iter_dims,
                    {0, 0, 0, 0}); // l, d, n, c offsets
    auto rightmost_diff_src_iter_memory = leftmost_diff_dst_iter_memory;

    auto rightmost_diff_src_iter_c_md
            = leftmost_diff_dst_iter_c_memory.get_desc().submemory_desc(
                    rightmost_src_iter_c_dims,
                    {0, 0, 0, 0}); // l, d, n, c offsets
    auto rightmost_diff_src_iter_c_memory = leftmost_diff_dst_iter_c_memory;

    // Backward rightmost primitive descriptor
    auto rightmost_bwd_prim_desc = lstm_backward::primitive_desc(eng, // engine
            prop_kind::backward, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_rightmost_src_layer_md, // src_layer_desc
            generic_md(rightmost_src_iter_dims), // src_iter_desc
            generic_md(rightmost_src_iter_c_dims), // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(rightmost_dst_layer_dims, tag::tnc), // dst_layer_desc
            memory::desc(), // dst_iter_desc
            memory::desc(), // dst_iter_c_desc
            user_rightmost_diff_src_layer_md, // diff_src_layer_desc
            rightmost_diff_src_iter_md, // diff_src_iter_desc
            rightmost_diff_src_iter_c_md, // diff_src_iter_c_desc
            generic_md(common_weights_layer_dims), // diff_weights_layer_desc
            generic_md(common_weights_iter_dims), // diff_weights_iter_desc
            generic_md(common_bias_dims), // diff_bias_desc
            user_rightmost_diff_dst_layer_md, // diff_dst_layer_desc
            memory::desc(), // diff_dst_iter_desc
            memory::desc(), // diff_dst_iter_c_desc
            rightmost_prim_desc // hint from forward pass
    );

    //
    // Memory for backward pass
    //

    // src layer uses the same memory as forward
    auto leftmost_src_layer_bwd_memory = leftmost_src_layer_memory;
    auto rightmost_src_layer_bwd_memory = rightmost_src_layer_memory;

    // Memory for weights and biases for backward pass
    // Try to use the same memory between forward and backward, but
    // sometimes reorders are needed.
    auto common_weights_layer_bwd_memory = common_weights_layer_memory;
    if (leftmost_bwd_prim_desc.weights_layer_desc()
            != leftmost_prim_desc.weights_layer_desc()) {
        common_weights_layer_bwd_memory
                = memory(leftmost_bwd_prim_desc.weights_layer_desc(), eng);
        reorder(common_weights_layer_memory, common_weights_layer_bwd_memory)
                .execute(s, common_weights_layer_memory,
                        common_weights_layer_bwd_memory);
    }

    auto common_weights_iter_bwd_memory = common_weights_iter_memory;
    if (leftmost_bwd_prim_desc.weights_iter_desc()
            != leftmost_prim_desc.weights_iter_desc()) {
        common_weights_iter_bwd_memory
                = memory(leftmost_bwd_prim_desc.weights_iter_desc(), eng);
        reorder(common_weights_iter_memory, common_weights_iter_bwd_memory)
                .execute(s, common_weights_iter_memory,
                        common_weights_iter_bwd_memory);
    }

    auto common_bias_bwd_memory = common_bias_memory;
    if (leftmost_bwd_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
        common_bias_bwd_memory
                = dnnl::memory(leftmost_bwd_prim_desc.bias_desc(), eng);
        reorder(common_bias_memory, common_bias_bwd_memory)
                .execute(s, common_bias_memory, common_bias_bwd_memory);
    }

    // diff_weights and biases
    // When the primitive wants a layout different from the user one, the
    // primitive writes into a separate memory and the result is reordered
    // back into the user memory after the backward pass (see the flags).
    auto common_diff_weights_layer_memory
            = user_common_diff_weights_layer_memory;
    auto reorder_common_diff_weights_layer = false;
    if (leftmost_bwd_prim_desc.diff_weights_layer_desc()
            != common_diff_weights_layer_memory.get_desc()) {
        common_diff_weights_layer_memory = dnnl::memory(
                leftmost_bwd_prim_desc.diff_weights_layer_desc(), eng);
        reorder_common_diff_weights_layer = true;
    }

    auto common_diff_bias_memory = user_common_diff_bias_memory;
    auto reorder_common_diff_bias = false;
    if (leftmost_bwd_prim_desc.diff_bias_desc()
            != common_diff_bias_memory.get_desc()) {
        common_diff_bias_memory
                = dnnl::memory(leftmost_bwd_prim_desc.diff_bias_desc(), eng);
        reorder_common_diff_bias = true;
    }

    // dst_layer memory for backward pass
    auto leftmost_dst_layer_bwd_memory = leftmost_dst_layer_memory;
    if (leftmost_bwd_prim_desc.dst_layer_desc()
            != leftmost_dst_layer_bwd_memory.get_desc()) {
        leftmost_dst_layer_bwd_memory
                = dnnl::memory(leftmost_bwd_prim_desc.dst_layer_desc(), eng);
        reorder(leftmost_dst_layer_memory, leftmost_dst_layer_bwd_memory)
                .execute(s, leftmost_dst_layer_memory,
                        leftmost_dst_layer_bwd_memory);
    }

    auto rightmost_dst_layer_bwd_memory = rightmost_dst_layer_memory;
    if (rightmost_bwd_prim_desc.dst_layer_desc()
            != rightmost_dst_layer_bwd_memory.get_desc()) {
        rightmost_dst_layer_bwd_memory
                = dnnl::memory(rightmost_bwd_prim_desc.dst_layer_desc(), eng);
        reorder(rightmost_dst_layer_memory, rightmost_dst_layer_bwd_memory)
                .execute(s, rightmost_dst_layer_memory,
                        rightmost_dst_layer_bwd_memory);
    }

    // Similar to forward, the backward primitives are connected
    // via "iter" parameters.
    // NOTE(review): this memory is allocated but not initialized before it
    // is passed as DNNL_ARG_DIFF_WEIGHTS_ITER - confirm whether the backward
    // primitive overwrites or accumulates into it.
    auto common_diff_weights_iter_memory = dnnl::memory(
            leftmost_bwd_prim_desc.diff_weights_iter_desc(), eng);

    auto leftmost_dst_iter_bwd_memory = leftmost_dst_iter_memory;
    if (leftmost_bwd_prim_desc.dst_iter_desc()
            != leftmost_dst_iter_bwd_memory.get_desc()) {
        leftmost_dst_iter_bwd_memory
                = dnnl::memory(leftmost_bwd_prim_desc.dst_iter_desc(), eng);
        reorder(leftmost_dst_iter_memory, leftmost_dst_iter_bwd_memory)
                .execute(s, leftmost_dst_iter_memory,
                        leftmost_dst_iter_bwd_memory);
    }

    auto leftmost_dst_iter_c_bwd_memory = leftmost_dst_iter_c_memory;
    if (leftmost_bwd_prim_desc.dst_iter_c_desc()
            != leftmost_dst_iter_c_bwd_memory.get_desc()) {
        leftmost_dst_iter_c_bwd_memory
                = dnnl::memory(leftmost_bwd_prim_desc.dst_iter_c_desc(), eng);
        reorder(leftmost_dst_iter_c_memory, leftmost_dst_iter_c_bwd_memory)
                .execute(s, leftmost_dst_iter_c_memory,
                        leftmost_dst_iter_c_bwd_memory);
    }

    // Construct the RNN primitive objects for backward.
    // The rightmost primitive runs first: it produces diff_src_iter, which
    // the leftmost primitive then consumes as diff_dst_iter.
    lstm_backward rightmost_layer_bwd(rightmost_bwd_prim_desc);
    rightmost_layer_bwd.execute(s,
            {{DNNL_ARG_SRC_LAYER, rightmost_src_layer_bwd_memory},
                    {DNNL_ARG_SRC_ITER, rightmost_src_iter_memory},
                    {DNNL_ARG_SRC_ITER_C, rightmost_src_iter_c_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory},
                    {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory},
                    {DNNL_ARG_BIAS, common_bias_bwd_memory},
                    {DNNL_ARG_DST_LAYER, rightmost_dst_layer_bwd_memory},
                    {DNNL_ARG_DIFF_SRC_LAYER, rightmost_diff_src_layer_memory},
                    {DNNL_ARG_DIFF_SRC_ITER, rightmost_diff_src_iter_memory},
                    {DNNL_ARG_DIFF_SRC_ITER_C,
                            rightmost_diff_src_iter_c_memory},
                    {DNNL_ARG_DIFF_WEIGHTS_LAYER,
                            common_diff_weights_layer_memory},
                    {DNNL_ARG_DIFF_WEIGHTS_ITER,
                            common_diff_weights_iter_memory},
                    {DNNL_ARG_DIFF_BIAS, common_diff_bias_memory},
                    {DNNL_ARG_DIFF_DST_LAYER, rightmost_diff_dst_layer_memory},
                    {DNNL_ARG_WORKSPACE, rightmost_workspace_memory}});

    lstm_backward leftmost_layer_bwd(leftmost_bwd_prim_desc);
    leftmost_layer_bwd.execute(s,
            {{DNNL_ARG_SRC_LAYER, leftmost_src_layer_bwd_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory},
                    {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory},
                    {DNNL_ARG_BIAS, common_bias_bwd_memory},
                    {DNNL_ARG_DST_LAYER, leftmost_dst_layer_bwd_memory},
                    {DNNL_ARG_DST_ITER, leftmost_dst_iter_bwd_memory},
                    {DNNL_ARG_DST_ITER_C, leftmost_dst_iter_c_bwd_memory},
                    {DNNL_ARG_DIFF_SRC_LAYER, leftmost_diff_src_layer_memory},
                    {DNNL_ARG_DIFF_WEIGHTS_LAYER,
                            common_diff_weights_layer_memory},
                    {DNNL_ARG_DIFF_WEIGHTS_ITER,
                            common_diff_weights_iter_memory},
                    {DNNL_ARG_DIFF_BIAS, common_diff_bias_memory},
                    {DNNL_ARG_DIFF_DST_LAYER, leftmost_diff_dst_layer_memory},
                    {DNNL_ARG_DIFF_DST_ITER, leftmost_diff_dst_iter_memory},
                    {DNNL_ARG_DIFF_DST_ITER_C, leftmost_diff_dst_iter_c_memory},
                    {DNNL_ARG_WORKSPACE, leftmost_workspace_memory}});

    // Bring the computed diffs back into the user-chosen layouts if the
    // primitives worked on differently formatted memories.
    if (reorder_common_diff_weights_layer) {
        reorder(common_diff_weights_layer_memory,
                user_common_diff_weights_layer_memory)
                .execute(s, common_diff_weights_layer_memory,
                        user_common_diff_weights_layer_memory);
    }

    if (reorder_common_diff_bias) {
        reorder(common_diff_bias_memory, user_common_diff_bias_memory)
                .execute(s, common_diff_bias_memory,
                        user_common_diff_bias_memory);
    }

    //
    // User updates weights and bias using diffs
    //

    // Wait for everything submitted to the stream to complete.
    s.wait();
}

int main(int argc, char **argv) {
    return handle_example_errors(simple_net, parse_engine_kind(argc, argv));
}