// This C++ API example demonstrates how to build the training pass (forward and backward) of a GNMT model.
#include <cstring>
#include <iostream>
#include <math.h>
#include <numeric>
#include <string>
// Problem sizes for the example network.
//
// The first four values are drawn from rand() at static-initialization time,
// in declaration order (no srand() call is visible in this file):
//   N0, N1, T1 lie in [1, 31]; T0 lies in [32, 62].
// Because T0 is always greater than T1, rightmost_seq_length (T0 - T1) is
// always at least 1.
const int N0{rand() % 31 + 1};
const int N1{rand() % 31 + 1};
const int T0{rand() % 31 + 32};
const int T1{rand() % 31 + 1};

// Per-layer batch sizes and sequence lengths derived from the values above.
const int leftmost_batch{N0 + N1};
const int rightmost_batch{N0};
const int leftmost_seq_length{T1};
const int rightmost_seq_length{T0 - T1};

// Fixed sizes shared by both layers.
const int common_feature_size{1024};
const int common_n_layers{1};
const int lstm_n_gates{4};
// Runs a forward pass and, because is_training is true, a backward pass over
// two recurrent layers named "leftmost" and "rightmost" that share one set of
// weights and bias. Takes no parameters and returns nothing.
//
// NOTE(review): in this view of the file many statements are truncated (the
// initializer or the declaration part is missing) and several names used
// below are never declared here (e.g. cpu_engine, s, create_ws,
// leftmost_layer, rightmost_layer, leftmost_prim_desc, rightmost_prim_desc,
// fwd_inf_train, the *_bwd primitives). The comments describe only what the
// visible lines show; confirm details against the complete source.
void simple_net() {
// Hard-coded to true, so the early return after the forward pass is never
// taken in this function as written.
bool is_training = true;
// These two vectors are not referenced again in the visible code.
std::vector<primitive> fwd_net;
std::vector<primitive> bwd_net;
// Dimensions of the whole network input: {T0, N0 + N1, common_feature_size}.
memory::dims net_src_dims = {
T0,
N0 + N1,
common_feature_size
};
// Per-layer source dimensions: {sequence length, batch, feature size}.
memory::dims leftmost_src_layer_dims = {
leftmost_seq_length,
leftmost_batch,
common_feature_size
};
memory::dims rightmost_src_layer_dims = {
rightmost_seq_length,
rightmost_batch,
common_feature_size
};
// Weights (layer and iter) use the same five-element shape; the bias drops
// one feature-size axis.
memory::dims common_weights_layer_dims = {
common_n_layers,
1,
common_feature_size,
lstm_n_gates,
common_feature_size
};
memory::dims common_weights_iter_dims = {
common_n_layers,
1,
common_feature_size,
lstm_n_gates,
common_feature_size
};
memory::dims common_bias_dims = {
common_n_layers,
1,
lstm_n_gates,
common_feature_size
};
// Destination (output) dimensions mirror the source dimensions per layer.
memory::dims leftmost_dst_layer_dims = {
leftmost_seq_length,
leftmost_batch,
common_feature_size
};
memory::dims rightmost_dst_layer_dims = {
rightmost_seq_length,
rightmost_batch,
common_feature_size
};
// Recurrent-state dimensions: the leftmost layer's outputs (dst_iter,
// dst_iter_c) use leftmost_batch, the rightmost layer's inputs (src_iter,
// src_iter_c) use rightmost_batch.
memory::dims leftmost_dst_iter_dims = {
common_n_layers,
1,
leftmost_batch,
common_feature_size
};
memory::dims leftmost_dst_iter_c_dims = {
common_n_layers,
1,
leftmost_batch,
common_feature_size
};
memory::dims rightmost_src_iter_dims = {
common_n_layers,
1,
rightmost_batch,
common_feature_size
};
memory::dims rightmost_src_iter_c_dims = {
common_n_layers,
1,
rightmost_batch,
common_feature_size
};
// Helper: number of elements described by a dims vector (product of all
// entries, starting from 1).
auto tz_volume = [=](memory::dims tz_dims) {
return std::accumulate(tz_dims.begin(), tz_dims.end(), (memory::dim)1,
std::multiplies<memory::dim>());
};
// NOTE(review): the definition of formatted_md is truncated here; only the
// declaration start and a closing "};" are visible. From its uses below it
// is called with (dims, format tag).
auto formatted_md
};
// Helper: builds a descriptor through formatted_md with tag::any.
auto generic_md = [=](memory::dims dimensions) {
return formatted_md(dimensions, tag::any);
};
// Network input buffer, filled with 1.0f.
std::vector<float> net_src(tz_volume(net_src_dims), 1.0f);
// NOTE(review): the next few statements are truncated. The visible pieces
// pass offsets { 0, 0, 0 } with the leftmost dims and
// { leftmost_seq_length, 0, 0 } for the rightmost layer, which presumably
// select two views into the same input -- confirm against the full source.
auto net_src_memory
cpu_engine, net_src.data());
leftmost_src_layer_dims, { 0, 0, 0 });
auto user_rightmost_src_layer_md
{ leftmost_seq_length, 0, 0 });
// Both layers' source memories are initialized from net_src_memory.
auto leftmost_src_layer_memory = net_src_memory;
auto rightmost_src_layer_memory = net_src_memory;
// User-side weights, bias and destination buffers, all filled with 1.0f.
std::vector<float> user_common_weights_layer(
tz_volume(common_weights_layer_dims), 1.0f);
auto user_common_weights_layer_memory
cpu_engine, user_common_weights_layer.data());
std::vector<float> user_common_weights_iter(
tz_volume(common_weights_iter_dims), 1.0f);
// NOTE(review): this (truncated) statement uses common_weights_iter_dims but
// passes user_common_weights_layer.data() rather than
// user_common_weights_iter.data(); looks like a copy/paste slip -- the two
// buffers have the same element count, so verify intent.
{ { common_weights_iter_dims }, dt::f32, tag::ldigo }, cpu_engine,
user_common_weights_layer.data());
std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f);
auto user_common_bias_memory
cpu_engine, user_common_bias.data());
std::vector<float> user_leftmost_dst_layer(
tz_volume(leftmost_dst_layer_dims), 1.0f);
auto user_leftmost_dst_layer_memory
cpu_engine, user_leftmost_dst_layer.data());
std::vector<float> user_rightmost_dst_layer(
tz_volume(rightmost_dst_layer_dims), 1.0f);
{ { rightmost_dst_layer_dims }, dt::f32, tag::tnc }, cpu_engine,
user_rightmost_dst_layer.data());
// Argument list of the leftmost forward layer descriptor (the constructor
// call itself is truncated). No source-iter descriptors appear here; the
// destination layer is requested in tag::tnc, weights/bias/iter outputs via
// generic_md (tag::any).
fwd_inf_train,
rnn_direction::unidirectional_left2right,
user_leftmost_src_layer_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
formatted_md(leftmost_dst_layer_dims, tag::tnc),
generic_md(leftmost_dst_iter_dims),
generic_md(leftmost_dst_iter_c_dims)
);
leftmost_layer_desc, cpu_engine);
// NOTE(review): initializers of the following declarations are truncated.
auto leftmost_dst_iter_memory
auto leftmost_dst_iter_c_memory
auto rightmost_src_iter_md
rightmost_src_iter_dims,
{ 0, 0, 0, 0 });
// The rightmost layer's initial recurrent state is initialized from the
// leftmost layer's final-state memory objects.
auto rightmost_src_iter_memory = leftmost_dst_iter_memory;
auto rightmost_src_iter_c_md
rightmost_src_iter_c_dims,
{ 0, 0, 0, 0 });
auto rightmost_src_iter_c_memory = leftmost_dst_iter_c_memory;
// Argument list of the rightmost forward layer descriptor (call truncated).
// Unlike the leftmost one it takes source-iter descriptors.
// NOTE(review): the list ends with a trailing comma before ")", so at least
// one argument appears to be missing in this view.
fwd_inf_train,
rnn_direction::unidirectional_left2right,
user_rightmost_src_layer_md,
rightmost_src_iter_md,
rightmost_src_iter_c_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
formatted_md(rightmost_dst_layer_dims, tag::tnc),
);
rightmost_layer_desc, cpu_engine);
// For each of weights_layer, weights_iter, bias and the two dst_layer
// memories: start from the user memory and, if the primitive descriptor
// reports a different memory descriptor, run a reorder from the user memory
// into the working memory. NOTE(review): the assignments that would create
// the new working memory inside each "if" are truncated in this view.
auto common_weights_layer_memory = user_common_weights_layer_memory;
if (leftmost_prim_desc.weights_layer_desc()
!= common_weights_layer_memory.get_desc()) {
leftmost_prim_desc.weights_layer_desc(), cpu_engine);
reorder(user_common_weights_layer_memory, common_weights_layer_memory)
.
execute(s, user_common_weights_layer_memory,
common_weights_layer_memory);
}
auto common_weights_iter_memory = user_common_weights_iter_memory;
if (leftmost_prim_desc.weights_iter_desc()
!= common_weights_iter_memory.get_desc()) {
leftmost_prim_desc.weights_iter_desc(), cpu_engine);
reorder(user_common_weights_iter_memory, common_weights_iter_memory)
.execute(s, user_common_weights_iter_memory,
common_weights_iter_memory);
}
auto common_bias_memory = user_common_bias_memory;
if (leftmost_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
common_bias_memory
reorder(user_common_bias_memory, common_bias_memory)
.execute(s, user_common_bias_memory, common_bias_memory);
}
auto leftmost_dst_layer_memory = user_leftmost_dst_layer_memory;
if (leftmost_prim_desc.dst_layer_desc()
!= leftmost_dst_layer_memory.get_desc()) {
leftmost_prim_desc.dst_layer_desc(), cpu_engine);
reorder(user_leftmost_dst_layer_memory, leftmost_dst_layer_memory)
.execute(s, user_leftmost_dst_layer_memory,
leftmost_dst_layer_memory);
}
auto rightmost_dst_layer_memory = user_rightmost_dst_layer_memory;
if (rightmost_prim_desc.dst_layer_desc()
!= rightmost_dst_layer_memory.get_desc()) {
rightmost_prim_desc.dst_layer_desc(), cpu_engine);
reorder(user_rightmost_dst_layer_memory, rightmost_dst_layer_memory)
.execute(s, user_rightmost_dst_layer_memory,
rightmost_dst_layer_memory);
}
// NOTE(review): this "};" has no visible opener; presumably it closes the
// create_ws helper used just below, whose definition is missing here.
};
// One workspace memory per layer, obtained from create_ws.
auto leftmost_workspace_memory = create_ws(leftmost_prim_desc);
auto rightmost_workspace_memory = create_ws(rightmost_prim_desc);
// Forward execution, leftmost layer first. It receives no SRC_ITER arguments
// and produces DST_ITER / DST_ITER_C.
leftmost_layer.
execute(s,
{ { MKLDNN_ARG_SRC_LAYER, leftmost_src_layer_memory },
{ MKLDNN_ARG_WEIGHTS_LAYER, common_weights_layer_memory },
{ MKLDNN_ARG_WEIGHTS_ITER, common_weights_iter_memory },
{ MKLDNN_ARG_BIAS, common_bias_memory },
{ MKLDNN_ARG_DST_LAYER, leftmost_dst_layer_memory },
{ MKLDNN_ARG_DST_ITER, leftmost_dst_iter_memory },
{ MKLDNN_ARG_DST_ITER_C, leftmost_dst_iter_c_memory },
{ MKLDNN_ARG_WORKSPACE, leftmost_workspace_memory } });
// Forward execution, rightmost layer. It consumes SRC_ITER / SRC_ITER_C and
// is given no DST_ITER arguments. Weights and bias are the same memory
// objects passed to the leftmost layer.
rightmost_layer.execute(s,
{ { MKLDNN_ARG_SRC_LAYER, rightmost_src_layer_memory },
{ MKLDNN_ARG_SRC_ITER, rightmost_src_iter_memory },
{ MKLDNN_ARG_SRC_ITER_C, rightmost_src_iter_c_memory },
{ MKLDNN_ARG_WEIGHTS_LAYER, common_weights_layer_memory },
{ MKLDNN_ARG_WEIGHTS_ITER, common_weights_iter_memory },
{ MKLDNN_ARG_BIAS, common_bias_memory },
{ MKLDNN_ARG_DST_LAYER, rightmost_dst_layer_memory },
{ MKLDNN_ARG_WORKSPACE, rightmost_workspace_memory } });
// Everything below is the backward pass; skipped when not training.
if (!is_training)
return;
// Gradient buffer for the network input, same element count as net_src and
// filled with 1.0f. The per-layer diff-src descriptors use the same offsets
// as the forward source descriptors. NOTE(review): initializers truncated.
std::vector<float> net_diff_src(tz_volume(net_src_dims), 1.0f);
auto net_diff_src_memory
net_diff_src.data());
auto user_leftmost_diff_src_layer_md
leftmost_src_layer_dims, { 0, 0, 0 });
auto user_rightmost_diff_src_layer_md
rightmost_src_layer_dims,
{ leftmost_seq_length, 0, 0 });
auto leftmost_diff_src_layer_memory = net_diff_src_memory;
auto rightmost_diff_src_layer_memory = net_diff_src_memory;
// User-side gradient buffers for weights_layer and bias, filled with 1.0f.
std::vector<float> user_common_diff_weights_layer(
tz_volume(common_weights_layer_dims), 1.0f);
formatted_md(common_weights_layer_dims, tag::ldigo), cpu_engine,
user_common_diff_weights_layer.data());
std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f);
auto user_common_diff_bias_memory
cpu_engine, user_common_diff_bias.data());
// Gradient w.r.t. the network output: same dims as the network input.
memory::dims net_diff_dst_dims = {
T0,
N0 + N1,
common_feature_size
};
std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f);
auto net_diff_dst_memory
cpu_engine, net_diff_dst.data());
auto user_leftmost_diff_dst_layer_md
leftmost_dst_layer_dims, { 0, 0, 0 } );
auto user_rightmost_diff_dst_layer_md
rightmost_dst_layer_dims,
{ leftmost_seq_length, 0, 0 });
auto leftmost_diff_dst_layer_memory = net_diff_dst_memory;
auto rightmost_diff_dst_layer_memory = net_diff_dst_memory;
// Argument list of the leftmost backward layer descriptor (call truncated):
// the forward descriptors followed by the matching diff descriptors.
rnn_direction::unidirectional_left2right,
user_leftmost_src_layer_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
formatted_md(leftmost_dst_layer_dims, tag::tnc),
generic_md(leftmost_dst_iter_dims),
generic_md(leftmost_dst_iter_c_dims),
user_leftmost_diff_src_layer_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
user_leftmost_diff_dst_layer_md,
generic_md(leftmost_dst_iter_dims),
generic_md(leftmost_dst_iter_c_dims)
);
// The backward primitive descriptor is created with the forward primitive
// descriptor (leftmost_prim_desc) as an extra argument.
leftmost_layer_bwd_desc, cpu_engine, leftmost_prim_desc);
leftmost_bwd_prim_desc.diff_dst_iter_desc(), cpu_engine);
leftmost_bwd_prim_desc.diff_dst_iter_c_desc(), cpu_engine);
auto rightmost_diff_src_iter_md
rightmost_src_iter_dims,
{ 0, 0, 0, 0 });
// The rightmost layer's diff_src_iter memories are initialized from the
// leftmost layer's diff_dst_iter memories, mirroring the forward state
// hand-off.
auto rightmost_diff_src_iter_memory = leftmost_diff_dst_iter_memory;
auto rightmost_diff_src_iter_c_md
rightmost_src_iter_c_dims,
{ 0, 0, 0, 0 });
auto rightmost_diff_src_iter_c_memory = leftmost_diff_dst_iter_c_memory;
// Argument list of the rightmost backward layer descriptor (call truncated).
// NOTE(review): as in the forward case, the list ends with a trailing comma.
rnn_direction::unidirectional_left2right,
user_rightmost_src_layer_md,
generic_md(rightmost_src_iter_dims),
generic_md(rightmost_src_iter_c_dims),
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
formatted_md(rightmost_dst_layer_dims, tag::tnc),
user_rightmost_diff_src_layer_md,
rightmost_diff_src_iter_md,
rightmost_diff_src_iter_c_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
user_rightmost_diff_dst_layer_md,
);
rightmost_layer_bwd_desc, cpu_engine, rightmost_prim_desc);
// Backward-side inputs start from the forward memories and are reordered
// only when the backward primitive descriptor reports a different layout.
auto leftmost_src_layer_bwd_memory = leftmost_src_layer_memory;
auto rightmost_src_layer_bwd_memory = rightmost_src_layer_memory;
auto common_weights_layer_bwd_memory = common_weights_layer_memory;
if (leftmost_bwd_prim_desc.weights_layer_desc()
!= leftmost_prim_desc.weights_layer_desc()) {
common_weights_layer_bwd_memory =
memory(
leftmost_bwd_prim_desc.weights_layer_desc(), cpu_engine);
reorder(common_weights_layer_memory, common_weights_layer_bwd_memory)
.execute(s, common_weights_layer_memory,
common_weights_layer_bwd_memory);
}
auto common_weights_iter_bwd_memory = common_weights_iter_memory;
if (leftmost_bwd_prim_desc.weights_iter_desc()
!= leftmost_prim_desc.weights_iter_desc()) {
common_weights_iter_bwd_memory =
memory(
leftmost_bwd_prim_desc.weights_iter_desc(), cpu_engine);
reorder(common_weights_iter_memory, common_weights_iter_bwd_memory)
.execute(s, common_weights_iter_memory,
common_weights_iter_bwd_memory);
}
auto common_bias_bwd_memory = common_bias_memory;
if (leftmost_bwd_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
leftmost_bwd_prim_desc.bias_desc(), cpu_engine);
reorder(common_bias_memory, common_bias_bwd_memory)
.execute(s, common_bias_memory, common_bias_bwd_memory);
}
// Gradient outputs: no reorder is executed here. Instead a flag records that
// the working memory differs from the user memory, so the result can be
// reordered back into the user memory after the backward execution.
auto common_diff_weights_layer_memory
= user_common_diff_weights_layer_memory;
auto reorder_common_diff_weights_layer = false;
if (leftmost_bwd_prim_desc.diff_weights_layer_desc()
!= common_diff_weights_layer_memory.get_desc()) {
leftmost_bwd_prim_desc.diff_weights_layer_desc(), cpu_engine);
reorder_common_diff_weights_layer = true;
}
auto common_diff_bias_memory = user_common_diff_bias_memory;
auto reorder_common_diff_bias = false;
if (leftmost_bwd_prim_desc.diff_bias_desc()
!= common_diff_bias_memory.get_desc()) {
leftmost_bwd_prim_desc.diff_bias_desc(), cpu_engine);
reorder_common_diff_bias = true;
}
// Forward outputs needed by the backward primitives, reordered on demand.
auto leftmost_dst_layer_bwd_memory = leftmost_dst_layer_memory;
if (leftmost_bwd_prim_desc.dst_layer_desc()
!= leftmost_dst_layer_bwd_memory.get_desc()) {
leftmost_bwd_prim_desc.dst_layer_desc(), cpu_engine);
reorder(leftmost_dst_layer_memory, leftmost_dst_layer_bwd_memory)
.execute(s, leftmost_dst_layer_memory,
leftmost_dst_layer_bwd_memory);
}
auto rightmost_dst_layer_bwd_memory = rightmost_dst_layer_memory;
if (rightmost_bwd_prim_desc.dst_layer_desc()
!= rightmost_dst_layer_bwd_memory.get_desc()) {
rightmost_bwd_prim_desc.dst_layer_desc(), cpu_engine);
reorder(rightmost_dst_layer_memory, rightmost_dst_layer_bwd_memory)
.execute(s, rightmost_dst_layer_memory,
rightmost_dst_layer_bwd_memory);
}
// NOTE(review): truncated statement; presumably the declaration of
// common_diff_weights_iter_memory, which is used in both backward execute
// calls below but never declared in this view.
leftmost_bwd_prim_desc.diff_weights_iter_desc(), cpu_engine);
auto leftmost_dst_iter_bwd_memory = leftmost_dst_iter_memory;
if (leftmost_bwd_prim_desc.dst_iter_desc()
!= leftmost_dst_iter_bwd_memory.
get_desc()) {
leftmost_bwd_prim_desc.dst_iter_desc(), cpu_engine);
reorder(leftmost_dst_iter_memory, leftmost_dst_iter_bwd_memory)
.execute(s, leftmost_dst_iter_memory,
leftmost_dst_iter_bwd_memory);
}
auto leftmost_dst_iter_c_bwd_memory = leftmost_dst_iter_c_memory;
if (leftmost_bwd_prim_desc.dst_iter_c_desc()
!= leftmost_dst_iter_c_bwd_memory.get_desc()) {
leftmost_bwd_prim_desc.dst_iter_c_desc(), cpu_engine);
reorder(leftmost_dst_iter_c_memory, leftmost_dst_iter_c_bwd_memory)
.execute(s, leftmost_dst_iter_c_memory,
leftmost_dst_iter_c_bwd_memory);
}
// Backward execution runs in the reverse of the forward order: rightmost
// layer first, then leftmost. Both calls are given the same diff-weights and
// diff-bias memory objects and reuse the forward workspaces.
rightmost_layer_bwd.execute(s,
{ { MKLDNN_ARG_SRC_LAYER, rightmost_src_layer_bwd_memory },
{ MKLDNN_ARG_SRC_ITER, rightmost_src_iter_memory },
{ MKLDNN_ARG_SRC_ITER_C, rightmost_src_iter_c_memory },
{ MKLDNN_ARG_WEIGHTS_LAYER,
common_weights_layer_bwd_memory },
{ MKLDNN_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory },
{ MKLDNN_ARG_BIAS, common_bias_bwd_memory },
{ MKLDNN_ARG_DST_LAYER, rightmost_dst_layer_bwd_memory },
{ MKLDNN_ARG_DIFF_SRC_LAYER,
rightmost_diff_src_layer_memory },
{ MKLDNN_ARG_DIFF_SRC_ITER,
rightmost_diff_src_iter_memory },
{ MKLDNN_ARG_DIFF_SRC_ITER_C,
rightmost_diff_src_iter_c_memory },
{ MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
common_diff_weights_layer_memory },
{ MKLDNN_ARG_DIFF_WEIGHTS_ITER,
common_diff_weights_iter_memory },
{ MKLDNN_ARG_DIFF_BIAS, common_diff_bias_memory },
{ MKLDNN_ARG_DIFF_DST_LAYER,
rightmost_diff_dst_layer_memory },
{ MKLDNN_ARG_WORKSPACE, rightmost_workspace_memory } });
// Leftmost backward: additionally receives DIFF_DST_ITER / DIFF_DST_ITER_C,
// the memories the rightmost diff_src_iter memories were initialized from.
leftmost_layer_bwd.execute(s,
{ { MKLDNN_ARG_SRC_LAYER, leftmost_src_layer_bwd_memory },
{ MKLDNN_ARG_WEIGHTS_LAYER,
common_weights_layer_bwd_memory },
{ MKLDNN_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory },
{ MKLDNN_ARG_BIAS, common_bias_bwd_memory },
{ MKLDNN_ARG_DST_LAYER, leftmost_dst_layer_bwd_memory },
{ MKLDNN_ARG_DST_ITER, leftmost_dst_iter_bwd_memory },
{ MKLDNN_ARG_DST_ITER_C, leftmost_dst_iter_c_bwd_memory },
{ MKLDNN_ARG_DIFF_SRC_LAYER,
leftmost_diff_src_layer_memory },
{ MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
common_diff_weights_layer_memory },
{ MKLDNN_ARG_DIFF_WEIGHTS_ITER,
common_diff_weights_iter_memory },
{ MKLDNN_ARG_DIFF_BIAS, common_diff_bias_memory },
{ MKLDNN_ARG_DIFF_DST_LAYER,
leftmost_diff_dst_layer_memory },
{ MKLDNN_ARG_DIFF_DST_ITER, leftmost_diff_dst_iter_memory },
{ MKLDNN_ARG_DIFF_DST_ITER_C, leftmost_diff_dst_iter_c_memory },
{ MKLDNN_ARG_WORKSPACE, leftmost_workspace_memory } });
// Copy gradients back into the user-layout memories when the flags set
// earlier indicate the working memory was a different one.
if (reorder_common_diff_weights_layer) {
reorder(common_diff_weights_layer_memory,
user_common_diff_weights_layer_memory)
.execute(s, common_diff_weights_layer_memory,
user_common_diff_weights_layer_memory);
}
if (reorder_common_diff_bias) {
reorder(common_diff_bias_memory, user_common_diff_bias_memory)
.execute(s, common_diff_bias_memory,
user_common_diff_bias_memory);
}
}
// Program entry point: runs simple_net() and prints "ok" on success.
// argc and argv are not used in the visible body.
//
// NOTE(review): the "try" block has no visible closing brace or catch clause,
// and "e" is never declared in this view. The two std::cerr statements and
// "return 1" read like the body of an exception handler that reports
// e.status and e.message -- confirm the handler's exception type against the
// complete source.
int main(int argc, char **argv) {
try {
simple_net();
std::cout << "ok\n";
// Error reporting: writes the status and message fields of "e" to stderr.
std::cerr <<
"status: " << e.
status << std::endl;
std::cerr <<
"message: " << e.
message << std::endl;
// Exit code 1 after the error report.
return 1;
}
// Exit code 0 otherwise.
return 0;
}