RNN f32 inference example

This C++ API example demonstrates how to build GNMT model inference.

Example code: cpu_rnn_inference_f32.cpp

For the encoder we use:

  • one primitive for the bidirectional layer of the encoder

  • one primitive for all remaining unidirectional layers in the encoder

For the decoder we use:

  • one primitive for the first iteration

  • one primitive for all subsequent iterations in the decoder. Note that in this example, this primitive computes the states in place.

  • the attention mechanism, which is implemented separately because oneDNN does not yet support context vectors

Initialize a CPU engine and stream. The last parameter in the call represents the index of the engine.

auto cpu_engine = engine(engine::kind::cpu, 0);
stream s(cpu_engine);

Declare encoder net and decoder net

// Primitives of each network together with their execution arguments.
// The vectors are parallel: *_net_args[p] holds the arguments for *_net[p].
std::vector<primitive> encoder_net;
std::vector<primitive> decoder_net;
std::vector<std::unordered_map<int, memory>> encoder_net_args;
std::vector<std::unordered_map<int, memory>> decoder_net_args;

// Host-side sequence buffers, filled with 1.0f.
// net_src holds batch * src_seq_length_max * feature_size floats and is later
// wrapped as the encoder's tnc src_layer; net_dst is sized for
// tgt_seq_length_max steps (its use is not shown in this excerpt).
std::vector<float> net_src(batch * src_seq_length_max * feature_size, 1.0f);
std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 1.0f);

Encoder

Initialize Encoder Memory

// Shapes for the bidirectional encoder layer.
// The "2" in the weights/bias shapes is the direction count. dst_layer is
// 2 * feature_size wide because the outputs of both directions are
// concatenated.
memory::dims enc_bidir_src_layer_tz {src_seq_length_max, batch, feature_size};
memory::dims enc_bidir_weights_layer_tz {
        enc_bidir_n_layers, 2, feature_size, lstm_n_gates, feature_size};
memory::dims enc_bidir_weights_iter_tz {
        enc_bidir_n_layers, 2, feature_size, lstm_n_gates, feature_size};
memory::dims enc_bidir_bias_tz {
        enc_bidir_n_layers, 2, lstm_n_gates, feature_size};
memory::dims enc_bidir_dst_layer_tz {
        src_seq_length_max, batch, 2 * feature_size};

Encoder: 1 bidirectional layer and 7 unidirectional layers

Create the memory for user data

auto user_enc_bidir_src_layer_md = dnnl::memory::desc(
        {enc_bidir_src_layer_tz}, dnnl::memory::data_type::f32,
        dnnl::memory::format_tag::tnc);

auto user_enc_bidir_wei_layer_md = dnnl::memory::desc(
        {enc_bidir_weights_layer_tz}, dnnl::memory::data_type::f32,
        dnnl::memory::format_tag::ldigo);

auto user_enc_bidir_wei_iter_md = dnnl::memory::desc(
        {enc_bidir_weights_iter_tz}, dnnl::memory::data_type::f32,
        dnnl::memory::format_tag::ldigo);

auto user_enc_bidir_bias_md = dnnl::memory::desc({enc_bidir_bias_tz},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo);

auto user_enc_bidir_src_layer_memory = dnnl::memory(
        user_enc_bidir_src_layer_md, cpu_engine, net_src.data());
auto user_enc_bidir_wei_layer_memory
        = dnnl::memory(user_enc_bidir_wei_layer_md, cpu_engine,
                user_enc_bidir_wei_layer.data());
auto user_enc_bidir_wei_iter_memory
        = dnnl::memory(user_enc_bidir_wei_iter_md, cpu_engine,
                user_enc_bidir_wei_iter.data());
auto user_enc_bidir_bias_memory = dnnl::memory(
        user_enc_bidir_bias_md, cpu_engine, user_enc_bidir_bias.data());

Create memory descriptors for RNN data without a specified layout (format_tag::any), so that the primitive can choose its preferred layout

auto enc_bidir_wei_layer_md = memory::desc({enc_bidir_weights_layer_tz},
        memory::data_type::f32, memory::format_tag::any);

auto enc_bidir_wei_iter_md = memory::desc({enc_bidir_weights_iter_tz},
        memory::data_type::f32, memory::format_tag::any);

auto enc_bidir_dst_layer_md = memory::desc({enc_bidir_dst_layer_tz},
        memory::data_type::f32, memory::format_tag::any);

Create bidirectional RNN

auto enc_bidir_prim_desc = lstm_forward::primitive_desc(cpu_engine,
        prop_kind::forward_inference, rnn_direction::bidirectional_concat,
        user_enc_bidir_src_layer_md, memory::desc(), memory::desc(),
        enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
        user_enc_bidir_bias_md, enc_bidir_dst_layer_md, memory::desc(),
        memory::desc());

Create memory for input data and use reorders to reorder user data to internal representation

auto enc_bidir_wei_layer_memory
        = memory(enc_bidir_prim_desc.weights_layer_desc(), cpu_engine);
auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc(
        user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory);
reorder(enc_bidir_wei_layer_reorder_pd)
        .execute(s, user_enc_bidir_wei_layer_memory,
                enc_bidir_wei_layer_memory);

Encoder : add the bidirectional rnn primitive with related arguments into encoder_net

// Append the bidirectional LSTM to the encoder and record its arguments.
// SRC_LAYER and BIAS are bound to the user memories. WEIGHTS_LAYER is bound
// to the reordered weights.
encoder_net.push_back(lstm_forward(enc_bidir_prim_desc));
encoder_net_args.push_back(
        {{DNNL_ARG_SRC_LAYER, user_enc_bidir_src_layer_memory},
                {DNNL_ARG_WEIGHTS_LAYER, enc_bidir_wei_layer_memory},
                {DNNL_ARG_WEIGHTS_ITER, enc_bidir_wei_iter_memory},
                {DNNL_ARG_BIAS, user_enc_bidir_bias_memory},
                {DNNL_ARG_DST_LAYER, enc_bidir_dst_layer_memory}});

Encoder: unidirectional layers

The first unidirectional layer maps the 2 * feature_size output of the bidirectional layer to a feature_size output

// Host weights for the first unidirectional layer (1 layer, 1 direction),
// all set to 1.0f. The weights_layer input dimension is 2 * feature_size
// because this layer consumes the concatenated bidirectional output.
// NOTE(review): sizes match an ldigo / ldgo layout like the bidirectional
// weights; the descriptors are not shown here, so confirm.
std::vector<float> user_enc_uni_first_wei_layer(
        1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f);
std::vector<float> user_enc_uni_first_wei_iter(
        1 * 1 * feature_size * lstm_n_gates * feature_size, 1.0f);
std::vector<float> user_enc_uni_first_bias(
        1 * 1 * lstm_n_gates * feature_size, 1.0f);

Encoder: create the unidirectional RNN for the first cell

auto enc_uni_first_prim_desc = lstm_forward::primitive_desc(cpu_engine,
        prop_kind::forward_inference,
        rnn_direction::unidirectional_left2right, enc_bidir_dst_layer_md,
        memory::desc(), memory::desc(), enc_uni_first_wei_layer_md,
        enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
        enc_uni_first_dst_layer_md, memory::desc(), memory::desc());

Encoder : add the first unidirectional rnn primitive with related arguments into encoder_net

// Append the first unidirectional LSTM. Its SRC_LAYER is the bidirectional
// layer's DST_LAYER memory, chaining the two primitives.
// TODO: add a reorder when they will be available
encoder_net.push_back(lstm_forward(enc_uni_first_prim_desc));
encoder_net_args.push_back(
        {{DNNL_ARG_SRC_LAYER, enc_bidir_dst_layer_memory},
                {DNNL_ARG_WEIGHTS_LAYER, enc_uni_first_wei_layer_memory},
                {DNNL_ARG_WEIGHTS_ITER, enc_uni_first_wei_iter_memory},
                {DNNL_ARG_BIAS, user_enc_uni_first_bias_memory},
                {DNNL_ARG_DST_LAYER, enc_uni_first_dst_layer_memory}});

Encoder : Remaining unidirectional layers

// Host weights for the remaining (enc_unidir_n_layers - 1) unidirectional
// layers, all set to 1.0f. One direction each; input and output widths are
// both feature_size.
std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
                * feature_size * lstm_n_gates * feature_size,
        1.0f);
std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
                * feature_size * lstm_n_gates * feature_size,
        1.0f);
std::vector<float> user_enc_uni_bias(
        (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);

Encoder: create the unidirectional RNN cell

auto enc_uni_prim_desc = lstm_forward::primitive_desc(cpu_engine,
        prop_kind::forward_inference,
        rnn_direction::unidirectional_left2right,
        enc_uni_first_dst_layer_md, memory::desc(), memory::desc(),
        enc_uni_wei_layer_md, enc_uni_wei_iter_md, user_enc_uni_bias_md,
        enc_dst_layer_md, memory::desc(), memory::desc());

Encoder : add the unidirectional rnn primitive with related arguments into encoder_net

// Append the remaining unidirectional layers. enc_dst_layer_memory receives
// the final encoder output.
encoder_net.push_back(lstm_forward(enc_uni_prim_desc));
encoder_net_args.push_back(
        {{DNNL_ARG_SRC_LAYER, enc_uni_first_dst_layer_memory},
                {DNNL_ARG_WEIGHTS_LAYER, enc_uni_wei_layer_memory},
                {DNNL_ARG_WEIGHTS_ITER, enc_uni_wei_iter_memory},
                {DNNL_ARG_BIAS, user_enc_uni_bias_memory},
                {DNNL_ARG_DST_LAYER, enc_dst_layer_memory}});

Decoder with attention mechanism

Decoder : declare memory dimensions

// Host-side decoder and attention weights, all set to 1.0f.
// user_dec_wei_iter's input width is feature_size + feature_size because the
// recurrent input holds the hidden state together with the context vector.
// user_dec_dst has room for every target step (tgt_seq_length_max).
std::vector<float> user_dec_wei_layer(
        dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
        1.0f);
std::vector<float> user_dec_wei_iter(dec_n_layers * 1
                * (feature_size + feature_size) * lstm_n_gates
                * feature_size,
        1.0f);
std::vector<float> user_dec_bias(
        dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
std::vector<float> user_dec_dst(
        tgt_seq_length_max * batch * feature_size, 1.0f);
// Attention parameters: two feature_size x feature_size matrices and a
// feature_size-long alignment vector.
std::vector<float> user_weights_attention_src_layer(
        feature_size * feature_size, 1.0f);
std::vector<float> user_weights_annotation(
        feature_size * feature_size, 1.0f);
std::vector<float> user_weights_alignments(feature_size, 1.0f);

// Decoder weight shapes (one direction). weights_iter takes
// feature_size + feature_size inputs: the hidden state plus the context vector.
memory::dims user_dec_wei_layer_dims {
        dec_n_layers, 1, feature_size, lstm_n_gates, feature_size};
memory::dims user_dec_wei_iter_dims {dec_n_layers, 1,
        feature_size + feature_size, lstm_n_gates, feature_size};
memory::dims user_dec_bias_dims {dec_n_layers, 1, lstm_n_gates, feature_size};

// The decoder runs one target step at a time (t == 1).
memory::dims dec_src_layer_dims {1, batch, feature_size};
memory::dims dec_dst_layer_dims {1, batch, feature_size};
memory::dims dec_dst_iter_c_dims {dec_n_layers, 1, batch, feature_size};

We will use the same memory for dec_src_iter and dec_dst_iter. However, dec_src_iter has a context vector but dec_dst_iter does not. To resolve this, we will create one memory that holds the context vector as well as both the hidden and cell states. The dst_iter will be a sub-memory of this memory. Note that the cell state will be padded by feature_size values. However, we do not compute or access those.

// Full iteration state (hidden state followed by the context vector) and the
// narrower view that excludes the context.
memory::dims dec_dst_iter_dims {
        dec_n_layers, 1, batch, feature_size + feature_size};
memory::dims dec_dst_iter_noctx_dims {dec_n_layers, 1, batch, feature_size};

Decoder : create memory description

auto user_dec_wei_layer_md = dnnl::memory::desc({user_dec_wei_layer_dims},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldigo);
auto user_dec_wei_iter_md = dnnl::memory::desc({user_dec_wei_iter_dims},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldigo);
auto user_dec_bias_md = dnnl::memory::desc({user_dec_bias_dims},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo);
auto dec_dst_layer_md = dnnl::memory::desc({dec_dst_layer_dims},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::tnc);
auto dec_src_layer_md = dnnl::memory::desc({dec_src_layer_dims},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::tnc);
auto dec_dst_iter_md = dnnl::memory::desc({dec_dst_iter_dims},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldnc);
auto dec_dst_iter_c_md = dnnl::memory::desc({dec_dst_iter_c_dims},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldnc);

Decoder : Create memory

auto user_dec_wei_layer_memory = dnnl::memory(
        user_dec_wei_layer_md, cpu_engine, user_dec_wei_layer.data());
auto user_dec_wei_iter_memory = dnnl::memory(
        user_dec_wei_iter_md, cpu_engine, user_dec_wei_iter.data());
auto user_dec_bias_memory
        = dnnl::memory(user_dec_bias_md, cpu_engine, user_dec_bias.data());
auto user_dec_dst_layer_memory
        = dnnl::memory(dec_dst_layer_md, cpu_engine, user_dec_dst.data());
auto dec_src_layer_memory = dnnl::memory(dec_src_layer_md, cpu_engine);
auto dec_dst_iter_c_memory = dnnl::memory(dec_dst_iter_c_md, cpu_engine);

Decoder : As mentioned above, we create a view without context out of the memory with context.

// Allocate the full iteration state (hidden state + context vector, ldnc) and
// derive the descriptor of its context-free view: the first feature_size
// channels, starting at offset 0 in every dimension. submemory_desc expects
// one offset per dimension of the parent (4 for ldnc).
auto dec_dst_iter_memory = dnnl::memory(dec_dst_iter_md, cpu_engine);
auto dec_dst_iter_noctx_md = dec_dst_iter_md.submemory_desc(
        dec_dst_iter_noctx_dims, {0, 0, 0, 0});

Decoder : Create RNN decoder cell

auto dec_ctx_prim_desc = lstm_forward::primitive_desc(cpu_engine,
        prop_kind::forward_inference,
        rnn_direction::unidirectional_left2right, dec_src_layer_md,
        dec_dst_iter_md, dec_dst_iter_c_md, dec_wei_layer_md,
        dec_wei_iter_md, user_dec_bias_md, dec_dst_layer_md,
        dec_dst_iter_noctx_md, dec_dst_iter_c_md);

Decoder : reorder weight memory

auto dec_wei_layer_memory
        = memory(dec_ctx_prim_desc.weights_layer_desc(), cpu_engine);
auto dec_wei_layer_reorder_pd = reorder::primitive_desc(
        user_dec_wei_layer_memory, dec_wei_layer_memory);
reorder(dec_wei_layer_reorder_pd)
        .execute(s, user_dec_wei_layer_memory, dec_wei_layer_memory);

auto dec_wei_iter_memory
        = memory(dec_ctx_prim_desc.weights_iter_desc(), cpu_engine);
auto dec_wei_iter_reorder_pd = reorder::primitive_desc(
        user_dec_wei_iter_memory, dec_wei_iter_memory);
reorder(dec_wei_iter_reorder_pd)
        .execute(s, user_dec_wei_iter_memory, dec_wei_iter_memory);

Decoder : add the rnn primitive with related arguments into decoder_net

// Append the decoder LSTM. SRC_ITER and DST_ITER are both bound to
// dec_dst_iter_memory, and SRC_ITER_C and DST_ITER_C both to
// dec_dst_iter_c_memory, so the states are updated in place. The primitive
// writes dst_iter through the no-context sub-memory descriptor.
// TODO: add a reorder when they will be available
decoder_net.push_back(lstm_forward(dec_ctx_prim_desc));
decoder_net_args.push_back({{DNNL_ARG_SRC_LAYER, dec_src_layer_memory},
        {DNNL_ARG_SRC_ITER, dec_dst_iter_memory},
        {DNNL_ARG_SRC_ITER_C, dec_dst_iter_c_memory},
        {DNNL_ARG_WEIGHTS_LAYER, dec_wei_layer_memory},
        {DNNL_ARG_WEIGHTS_ITER, dec_wei_iter_memory},
        {DNNL_ARG_BIAS, user_dec_bias_memory},
        {DNNL_ARG_DST_LAYER, user_dec_dst_layer_memory},
        {DNNL_ARG_DST_ITER, dec_dst_iter_memory},
        {DNNL_ARG_DST_ITER_C, dec_dst_iter_c_memory}});

Execution

run encoder (1 stream)

for (size_t p = 0; p < encoder_net.size(); ++p)
    encoder_net.at(p).execute(s, encoder_net_args.at(p));

we compute the weighted annotations once before the decoder

// Project the final encoder output once; the result (weighted_annotations)
// is then passed to the attention computation of the decoder.
compute_weighted_annotations(weighted_annotations.data(), src_seq_length_max,
        batch, feature_size, user_weights_annotation.data(),
        static_cast<float *>(enc_dst_layer_memory.get_data_handle()));

We initialize src_layer to the embedding of the end-of-sequence character, which is assumed to be 0 here

// Zero the first decoder input. get_size() is the descriptor's size in
// bytes, so the whole src_layer buffer is cleared.
memset(dec_src_layer_memory.get_data_handle(), 0,
        dec_src_layer_memory.get_desc().get_size());

From now on, src points to the output of the last iteration

Compute attention context vector into the first layer src_iter

// Compute the attention context vector into the first layer's src_iter.
// The annotations must be the same tensor the weighted annotations were
// computed from: the final encoder output (enc_dst_layer_memory, feature_size
// channels). The bidirectional output has 2 * feature_size channels and would
// not match the feature_size passed here.
compute_attention(src_att_iter_handle, src_seq_length_max, batch,
        feature_size, user_weights_attention_src_layer.data(),
        src_att_layer_handle,
        static_cast<float *>(enc_dst_layer_memory.get_data_handle()),
        weighted_annotations.data(), user_weights_alignments.data());

copy the context vectors to all layers of src_iter

// Replicate the context vector across all dec_n_layers layers of src_iter.
copy_context(src_att_iter_handle, dec_n_layers, batch, feature_size);

run the decoder iteration

for (size_t p = 0; p < decoder_net.size(); ++p)
    decoder_net.at(p).execute(s, decoder_net_args.at(p));

Move the handle on the src/dst layer to the next iteration

// Advance one target step. This step's output becomes the next step's input,
// and the output handle moves forward by one tnc time slice
// (batch * feature_size floats) in user_dec_dst.
auto *dst_layer_handle
        = static_cast<float *>(user_dec_dst_layer_memory.get_data_handle());
dec_src_layer_memory.set_data_handle(dst_layer_handle);
user_dec_dst_layer_memory.set_data_handle(
        dst_layer_handle + batch * feature_size);