// This C++ API example demonstrates how to build a GNMT model for training.
#include <cstring>
#include <math.h>
#include <numeric>
#include "example_utils.hpp"
// Example input: N0 sequences of length T0 plus N1 sequences of length T1,
// held together as one (N0 + N1)-wide batch padded to T0 time steps.
// The four rand() draws below run during static initialization, in this
// order, so their values depend on the (unseeded) rand() sequence.
const int N0 = rand() % 31 + 1; // in [1, 31]
const int N1 = rand() % 31 + 1; // in [1, 31]
const int T0 = rand() % 31 + 32; // in [32, 62], so T0 > T1 always holds
const int T1 = rand() % 31 + 1; // in [1, 31]

// The padded batch is processed as two pieces split along time:
//   leftmost  - the first T1 steps of all N0 + N1 sequences;
//   rightmost - the remaining T0 - T1 steps of the first N0 (longer) ones.
const int leftmost_batch = N1 + N0;
const int rightmost_batch = N0;
const int leftmost_seq_length = T1;
const int rightmost_seq_length = T0 - leftmost_seq_length;

// Channel count shared by every layer input and output.
const int common_feature_size = 1024;

// A single LSTM layer; an LSTM cell has four gates.
const int common_n_layers = 1;
const int lstm_n_gates = 4;
// NOTE(review): the enclosing function header is missing from this source.
// main passes `simple_net` to handle_example_errors, so the statements from
// here on are presumably the body of simple_net. `engine_kind` below and the
// stream `s` used by the reorders further down are not declared in any
// visible line.
// Create an engine of the requested kind with index 0.
auto eng =
engine(engine_kind, 0);
// Hard-coded to true, so the backward-pass setup after the
// `if (!is_training) return;` guard always runs.
bool is_training = true;
// Primitive lists for the forward and backward passes.
// NOTE(review): neither vector is appended to anywhere in the visible code.
std::vector<primitive> fwd_net;
std::vector<primitive> bwd_net;
// NOTE(review): the declaration heads of the three dimension lists below are
// missing, leaving bare initializer fragments. By shape and later use they are
// presumably net_src_dims (the whole padded input: T0 steps of N0 + N1
// sequences), leftmost_src_layer_dims and rightmost_src_layer_dims - confirm.
// Entries appear to follow (time, batch, channels) order, matching the tnc tag
// used for the layer tensors.
T0,
N0 + N1,
common_feature_size
};
leftmost_seq_length,
leftmost_batch,
common_feature_size
};
rightmost_seq_length,
rightmost_batch,
common_feature_size
};
// NOTE(review): the declaration heads of the three dimension lists below are
// missing. By shape, and by the ldigo / ldgo tags used with these dims later,
// they are presumably common_weights_layer_dims, common_weights_iter_dims and
// common_bias_dims - confirm.
// The two weight shapes are identical because every channel count here is
// common_feature_size.
common_n_layers,
1,
common_feature_size,
lstm_n_gates,
common_feature_size
};
common_n_layers,
1,
common_feature_size,
lstm_n_gates,
common_feature_size
};
common_n_layers,
1,
lstm_n_gates,
common_feature_size
};
// NOTE(review): the declaration heads of the six dimension lists below are
// missing. By shape and later use they are presumably, in order:
// leftmost_dst_layer_dims, rightmost_dst_layer_dims, leftmost_dst_iter_dims,
// leftmost_dst_iter_c_dims, rightmost_src_iter_dims and
// rightmost_src_iter_c_dims - confirm.
// The first two mirror the source-layer shapes. The last four have the shape
// (layers, 1, batch, channels): two use leftmost_batch, two rightmost_batch.
leftmost_seq_length,
leftmost_batch,
common_feature_size
};
rightmost_seq_length,
rightmost_batch,
common_feature_size
};
common_n_layers,
1,
leftmost_batch,
common_feature_size
};
common_n_layers,
1,
leftmost_batch,
common_feature_size
};
common_n_layers,
1,
rightmost_batch,
common_feature_size
};
common_n_layers,
1,
rightmost_batch,
common_feature_size
};
// NOTE(review): the head of this lambda is missing; it is presumably
// tz_volume(tz_dims), which is called on the dims lists throughout.
// Returns the product of all entries of tz_dims (the element count). The
// product starts from 1, so an empty list yields 1.
return std::accumulate(tz_dims.begin(), tz_dims.end(), (
memory::dim)1,
std::multiplies<memory::dim>());
};
// NOTE(review): fragmentary. `formatted_md` has lost its initializer and body.
// By its uses it presumably takes dims plus a format tag and yields a memory
// descriptor.
auto formatted_md
};
// NOTE(review): body of a helper whose head is missing; by its uses this is
// presumably generic_md(dimensions). It requests tag::any, so the primitive
// may choose its own layout. The reorder checks later in this function cover
// the case where that layout differs from the user buffers.
return formatted_md(dimensions, tag::any);
};
// Whole padded input (net_src_dims), every element set to 1.0f.
std::vector<float> net_src(tz_volume(net_src_dims), 1.0f);
// NOTE(review): this declaration has lost its initializer. By analogy with
// net_diff_dst_memory it is presumably a dnnl::memory over net_src_dims in tnc
// layout - confirm.
auto net_src_memory
write_to_dnnl_memory(net_src.data(), net_src_memory);
// Leftmost view: starts at offset {0, 0, 0} of the full input.
auto user_leftmost_src_layer_md = net_src_memory.get_desc().submemory_desc(
leftmost_src_layer_dims, {0, 0, 0});
// Rightmost view: offset by leftmost_seq_length along the first dimension
// (time, assuming tnc), at batch and channel offset 0.
auto user_rightmost_src_layer_md
= net_src_memory.get_desc().submemory_desc(rightmost_src_layer_dims,
{leftmost_seq_length, 0, 0});
// Both primitives get the same memory object. The sub-memory descriptors
// above (passed to the layer descriptors) are what distinguish their regions.
auto leftmost_src_layer_memory = net_src_memory;
auto rightmost_src_layer_memory = net_src_memory;
std::vector<float> user_common_weights_layer(
tz_volume(common_weights_layer_dims), 1.0f);
{common_weights_layer_dims, dt::f32, tag::ldigo}, eng);
write_to_dnnl_memory(
user_common_weights_layer.data(), user_common_weights_layer_memory);
std::vector<float> user_common_weights_iter(
tz_volume(common_weights_iter_dims), 1.0f);
{{common_weights_iter_dims}, dt::f32, tag::ldigo}, eng);
write_to_dnnl_memory(
user_common_weights_layer.data(), user_common_weights_iter_memory);
std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f);
auto user_common_bias_memory
=
dnnl::memory({{common_bias_dims}, dt::f32, tag::ldgo}, eng);
write_to_dnnl_memory(user_common_bias.data(), user_common_bias_memory);
std::vector<float> user_leftmost_dst_layer(
tz_volume(leftmost_dst_layer_dims), 1.0f);
auto user_leftmost_dst_layer_memory
=
dnnl::memory({{leftmost_dst_layer_dims}, dt::f32, tag::tnc}, eng);
write_to_dnnl_memory(
user_leftmost_dst_layer.data(), user_leftmost_dst_layer_memory);
std::vector<float> user_rightmost_dst_layer(
tz_volume(rightmost_dst_layer_dims), 1.0f);
{{rightmost_dst_layer_dims}, dt::f32, tag::tnc}, eng);
write_to_dnnl_memory(
user_rightmost_dst_layer.data(), user_rightmost_dst_layer_memory);
// NOTE(review): the head of this argument list (and possibly some leading
// arguments) is missing. It is presumably the forward descriptor of the
// leftmost layer - confirm. The surviving arguments are:
// - the leftmost input sub-view;
// - tag::any (via generic_md) for the weights, bias and final hidden/cell
//   states;
// - an explicit tnc layout for the output layer.
user_leftmost_src_layer_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
formatted_md(leftmost_dst_layer_dims, tag::tnc),
generic_md(leftmost_dst_iter_dims),
generic_md(leftmost_dst_iter_c_dims)
);
// NOTE(review): the next three declarations have lost their initializers.
// They are presumably the leftmost forward primitive descriptor and the final
// hidden / cell state memories created from it.
auto leftmost_prim_desc
auto leftmost_dst_iter_memory
auto leftmost_dst_iter_c_memory
// The rightmost layer's initial state aliases the leftmost layer's final
// state. Both rightmost_src_iter memories reuse the leftmost dst iter
// memories, described by a sub-view at offset {0, 0, 0, 0}.
// NOTE(review): the head of each sub-view expression below (presumably a
// submemory_desc call on the leftmost dst iter descriptor) is missing.
auto rightmost_src_iter_md
rightmost_src_iter_dims,
{0, 0, 0, 0});
auto rightmost_src_iter_memory = leftmost_dst_iter_memory;
auto rightmost_src_iter_c_md
rightmost_src_iter_c_dims,
{0, 0, 0, 0});
auto rightmost_src_iter_c_memory = leftmost_dst_iter_c_memory;
// NOTE(review): the head of this argument list is missing. The list also ends
// with a trailing comma, so at least one trailing argument was lost. These
// arguments describe the rightmost layer:
// - its input sub-view;
// - the initial hidden/cell state views aliased onto the leftmost layer's
//   final state;
// - the same weight/bias shapes as the leftmost layer;
// - a tnc output.
user_rightmost_src_layer_md,
rightmost_src_iter_md,
rightmost_src_iter_c_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
formatted_md(rightmost_dst_layer_dims, tag::tnc),
);
// NOTE(review): initializer missing; presumably the rightmost forward
// primitive descriptor built from the argument list above.
auto rightmost_prim_desc
auto common_weights_layer_memory = user_common_weights_layer_memory;
if (leftmost_prim_desc.weights_layer_desc()
!= common_weights_layer_memory.get_desc()) {
common_weights_layer_memory
=
dnnl::memory(leftmost_prim_desc.weights_layer_desc(), eng);
reorder(user_common_weights_layer_memory, common_weights_layer_memory)
.
execute(s, user_common_weights_layer_memory,
common_weights_layer_memory);
}
auto common_weights_iter_memory = user_common_weights_iter_memory;
if (leftmost_prim_desc.weights_iter_desc()
!= common_weights_iter_memory.get_desc()) {
common_weights_iter_memory
=
dnnl::memory(leftmost_prim_desc.weights_iter_desc(), eng);
reorder(user_common_weights_iter_memory, common_weights_iter_memory)
.
execute(s, user_common_weights_iter_memory,
common_weights_iter_memory);
}
auto common_bias_memory = user_common_bias_memory;
if (leftmost_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
common_bias_memory =
dnnl::memory(leftmost_prim_desc.bias_desc(), eng);
reorder(user_common_bias_memory, common_bias_memory)
.
execute(s, user_common_bias_memory, common_bias_memory);
}
// Forward destination layer memories. Start from the user buffers. If a
// forward primitive chose a different dst layout, allocate a buffer in that
// layout and reorder the user data into it on stream s.
auto leftmost_dst_layer_memory = user_leftmost_dst_layer_memory;
if (leftmost_prim_desc.dst_layer_desc()
        != leftmost_dst_layer_memory.get_desc()) {
    leftmost_dst_layer_memory
            = dnnl::memory(leftmost_prim_desc.dst_layer_desc(), eng);
    reorder(user_leftmost_dst_layer_memory, leftmost_dst_layer_memory)
            .execute(s, user_leftmost_dst_layer_memory,
                    leftmost_dst_layer_memory);
}
auto rightmost_dst_layer_memory = user_rightmost_dst_layer_memory;
if (rightmost_prim_desc.dst_layer_desc()
        != rightmost_dst_layer_memory.get_desc()) {
    rightmost_dst_layer_memory
            = dnnl::memory(rightmost_prim_desc.dst_layer_desc(), eng);
    reorder(user_rightmost_dst_layer_memory, rightmost_dst_layer_memory)
            .execute(s, user_rightmost_dst_layer_memory,
                    rightmost_dst_layer_memory);
}
// NOTE(review): only the closing `};` of a lambda survives here. By its uses
// it is presumably create_ws, which takes a forward primitive descriptor and
// returns a memory for that primitive's workspace.
};
// One workspace per forward primitive. They are presumably consumed again by
// the matching backward primitive during training.
auto leftmost_workspace_memory = create_ws(leftmost_prim_desc);
auto rightmost_workspace_memory = create_ws(rightmost_prim_desc);
// The rest of this function sets up the backward pass. It is skipped when
// is_training is false (it is hard-coded to true above).
if (!is_training) return;
// Gradient with respect to the whole padded input: same dims as net_src,
// every element set to 1.0f.
std::vector<float> net_diff_src(tz_volume(net_src_dims), 1.0f);
// NOTE(review): initializer missing; presumably built like net_src_memory.
auto net_diff_src_memory
write_to_dnnl_memory(net_diff_src.data(), net_diff_src_memory);
// Same leftmost/rightmost split as the forward source: the leftmost view
// starts at offset 0, the rightmost view leftmost_seq_length steps later.
auto user_leftmost_diff_src_layer_md
= net_diff_src_memory.get_desc().submemory_desc(
leftmost_src_layer_dims, {0, 0, 0});
auto user_rightmost_diff_src_layer_md
= net_diff_src_memory.get_desc().submemory_desc(
rightmost_src_layer_dims,
{leftmost_seq_length, 0, 0});
// Both backward primitives get the same memory object; the sub-memory
// descriptors select their regions.
auto leftmost_diff_src_layer_memory = net_diff_src_memory;
auto rightmost_diff_src_layer_memory = net_diff_src_memory;
// User-side gradient buffers for the shared layer weights and bias. Both are
// filled with 1.0f and use ldigo / ldgo layouts.
std::vector<float> user_common_diff_weights_layer(
        tz_volume(common_weights_layer_dims), 1.0f);
auto user_common_diff_weights_layer_memory = dnnl::memory(
        formatted_md(common_weights_layer_dims, tag::ldigo), eng);
write_to_dnnl_memory(user_common_diff_weights_layer.data(),
        user_common_diff_weights_layer_memory);

std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f);
auto user_common_diff_bias_memory
        = dnnl::memory(formatted_md(common_bias_dims, tag::ldgo), eng);
write_to_dnnl_memory(
        user_common_diff_bias.data(), user_common_diff_bias_memory);
// NOTE(review): the declaration head of this dims list is missing. By its use
// right below it is presumably net_diff_dst_dims, with the same
// (T0, N0 + N1, channels) shape as the input.
T0,
N0 + N1,
common_feature_size
};
// Gradient arriving at the whole padded output, filled with 1.0f, in tnc.
std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f);
auto net_diff_dst_memory
=
dnnl::memory(formatted_md(net_diff_dst_dims, tag::tnc), eng);
write_to_dnnl_memory(net_diff_dst.data(), net_diff_dst_memory);
// Split the output gradient the same way as the input: the leftmost view
// starts at offset 0, the rightmost view leftmost_seq_length steps later.
auto user_leftmost_diff_dst_layer_md
= net_diff_dst_memory.get_desc().submemory_desc(
leftmost_dst_layer_dims, {0, 0, 0});
auto user_rightmost_diff_dst_layer_md
= net_diff_dst_memory.get_desc().submemory_desc(
rightmost_dst_layer_dims,
{leftmost_seq_length, 0, 0});
// Both backward primitives get the same memory object; the sub-memory
// descriptors select their regions.
auto leftmost_diff_dst_layer_memory = net_diff_dst_memory;
auto rightmost_diff_dst_layer_memory = net_diff_dst_memory;
// NOTE(review): the head of this argument list is missing. It is presumably
// the backward descriptor of the leftmost layer (leftmost_layer_bwd_desc,
// which is used just below). It repeats the forward arguments, then adds the
// gradient counterparts:
// - the leftmost diff-src and diff-dst sub-views;
// - tag::any for the weight, bias and (presumably) final-state gradients.
user_leftmost_src_layer_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
formatted_md(leftmost_dst_layer_dims, tag::tnc),
generic_md(leftmost_dst_iter_dims),
generic_md(leftmost_dst_iter_c_dims),
user_leftmost_diff_src_layer_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
user_leftmost_diff_dst_layer_md,
generic_md(leftmost_dst_iter_dims),
generic_md(leftmost_dst_iter_c_dims)
);
// NOTE(review): head missing. These are the trailing arguments of the call
// that creates leftmost_bwd_prim_desc (used below) from the backward
// descriptor, the engine and the forward primitive descriptor.
leftmost_layer_bwd_desc, eng, leftmost_prim_desc);
// Gradients for the leftmost layer's final hidden and cell states. They are
// allocated in whatever layout the leftmost backward primitive selected.
auto leftmost_diff_dst_iter_memory
=
dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_desc(), eng);
auto leftmost_diff_dst_iter_c_memory
=
dnnl::memory(leftmost_bwd_prim_desc.diff_dst_iter_c_desc(), eng);
// The rightmost layer's initial-state gradients alias those memories.
// Presumably, then, the gradient the rightmost backward primitive produces
// for its initial state feeds the leftmost backward primitive as its
// final-state gradient.
// NOTE(review): the head of each sub-view expression below (presumably a
// submemory_desc call at offset {0, 0, 0, 0}) is missing.
auto rightmost_diff_src_iter_md
rightmost_src_iter_dims,
{0, 0, 0, 0});
auto rightmost_diff_src_iter_memory = leftmost_diff_dst_iter_memory;
auto rightmost_diff_src_iter_c_md
rightmost_src_iter_c_dims,
{0, 0, 0, 0});
auto rightmost_diff_src_iter_c_memory = leftmost_diff_dst_iter_c_memory;
// NOTE(review): the head of this argument list is missing. The list also ends
// with a trailing comma, so at least one trailing argument was lost. It
// mirrors the rightmost forward arguments and adds the gradient views:
// - the rightmost diff-src sub-view;
// - the aliased initial-state gradient views;
// - tag::any for the weight/bias gradients;
// - the rightmost diff-dst sub-view.
user_rightmost_src_layer_md,
generic_md(rightmost_src_iter_dims),
generic_md(rightmost_src_iter_c_dims),
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
formatted_md(rightmost_dst_layer_dims, tag::tnc),
user_rightmost_diff_src_layer_md,
rightmost_diff_src_iter_md,
rightmost_diff_src_iter_c_md,
generic_md(common_weights_layer_dims),
generic_md(common_weights_iter_dims),
generic_md(common_bias_dims),
user_rightmost_diff_dst_layer_md,
);
// NOTE(review): head missing. These are the trailing arguments of the call
// that creates the rightmost backward primitive descriptor
// (rightmost_bwd_prim_desc, used below) from its backward descriptor, the
// engine and the forward primitive descriptor.
rightmost_layer_bwd_desc, eng, rightmost_prim_desc);
// Backward-pass inputs. Source layers are reused as-is. The shared weights and
// bias are reordered again only if the backward primitive wants a layout
// different from the forward one. Only leftmost_bwd_prim_desc is consulted, so
// these memories are presumably shared by both backward primitives.
auto leftmost_src_layer_bwd_memory = leftmost_src_layer_memory;
auto rightmost_src_layer_bwd_memory = rightmost_src_layer_memory;

auto common_weights_layer_bwd_memory = common_weights_layer_memory;
if (leftmost_bwd_prim_desc.weights_layer_desc()
        != leftmost_prim_desc.weights_layer_desc()) {
    common_weights_layer_bwd_memory
            = memory(leftmost_bwd_prim_desc.weights_layer_desc(), eng);
    reorder(common_weights_layer_memory, common_weights_layer_bwd_memory)
            .execute(s, common_weights_layer_memory,
                    common_weights_layer_bwd_memory);
}

auto common_weights_iter_bwd_memory = common_weights_iter_memory;
if (leftmost_bwd_prim_desc.weights_iter_desc()
        != leftmost_prim_desc.weights_iter_desc()) {
    common_weights_iter_bwd_memory
            = memory(leftmost_bwd_prim_desc.weights_iter_desc(), eng);
    reorder(common_weights_iter_memory, common_weights_iter_bwd_memory)
            .execute(s, common_weights_iter_memory,
                    common_weights_iter_bwd_memory);
}

auto common_bias_bwd_memory = common_bias_memory;
if (leftmost_bwd_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
    // Allocate the backward-layout bias before reordering into it.
    common_bias_bwd_memory
            = memory(leftmost_bwd_prim_desc.bias_desc(), eng);
    reorder(common_bias_memory, common_bias_bwd_memory)
            .execute(s, common_bias_memory, common_bias_bwd_memory);
}
// Gradient memories for the shared layer weights and bias. When the backward
// primitive wants a layout different from the user buffers, allocate memory
// in that layout and set a flag. The flag lets the gradients be reordered back
// into the user buffers at the end of this function.
auto common_diff_weights_layer_memory
        = user_common_diff_weights_layer_memory;
auto reorder_common_diff_weights_layer = false;
if (leftmost_bwd_prim_desc.diff_weights_layer_desc()
        != common_diff_weights_layer_memory.get_desc()) {
    common_diff_weights_layer_memory = dnnl::memory(
            leftmost_bwd_prim_desc.diff_weights_layer_desc(), eng);
    reorder_common_diff_weights_layer = true;
}

auto common_diff_bias_memory = user_common_diff_bias_memory;
auto reorder_common_diff_bias = false;
if (leftmost_bwd_prim_desc.diff_bias_desc()
        != common_diff_bias_memory.get_desc()) {
    common_diff_bias_memory
            = dnnl::memory(leftmost_bwd_prim_desc.diff_bias_desc(), eng);
    reorder_common_diff_bias = true;
}
auto leftmost_dst_layer_bwd_memory = leftmost_dst_layer_memory;
if (leftmost_bwd_prim_desc.dst_layer_desc()
!= leftmost_dst_layer_bwd_memory.get_desc()) {
leftmost_dst_layer_bwd_memory
=
dnnl::memory(leftmost_bwd_prim_desc.dst_layer_desc(), eng);
reorder(leftmost_dst_layer_memory, leftmost_dst_layer_bwd_memory)
.
execute(s, leftmost_dst_layer_memory,
leftmost_dst_layer_bwd_memory);
}
auto rightmost_dst_layer_bwd_memory = rightmost_dst_layer_memory;
if (rightmost_bwd_prim_desc.dst_layer_desc()
!= rightmost_dst_layer_bwd_memory.get_desc()) {
rightmost_dst_layer_bwd_memory
=
dnnl::memory(rightmost_bwd_prim_desc.dst_layer_desc(), eng);
reorder(rightmost_dst_layer_memory, rightmost_dst_layer_bwd_memory)
.
execute(s, rightmost_dst_layer_memory,
rightmost_dst_layer_bwd_memory);
}
leftmost_bwd_prim_desc.diff_weights_iter_desc(), eng);
auto leftmost_dst_iter_bwd_memory = leftmost_dst_iter_memory;
if (leftmost_bwd_prim_desc.dst_iter_desc()
!= leftmost_dst_iter_bwd_memory.
get_desc()) {
leftmost_dst_iter_bwd_memory
=
dnnl::memory(leftmost_bwd_prim_desc.dst_iter_desc(), eng);
reorder(leftmost_dst_iter_memory, leftmost_dst_iter_bwd_memory)
.
execute(s, leftmost_dst_iter_memory,
leftmost_dst_iter_bwd_memory);
}
auto leftmost_dst_iter_c_bwd_memory = leftmost_dst_iter_c_memory;
if (leftmost_bwd_prim_desc.dst_iter_c_desc()
!= leftmost_dst_iter_c_bwd_memory.get_desc()) {
leftmost_dst_iter_c_bwd_memory
=
dnnl::memory(leftmost_bwd_prim_desc.dst_iter_c_desc(), eng);
reorder(leftmost_dst_iter_c_memory, leftmost_dst_iter_c_bwd_memory)
.
execute(s, leftmost_dst_iter_c_memory,
leftmost_dst_iter_c_bwd_memory);
}
// NOTE(review): only the tails of argument-map entries survive here. The
// enclosing maps and the primitive executions they feed are missing.
// common_diff_weights_layer_memory and common_diff_weights_iter_memory each
// appear twice, presumably once per backward primitive.
rightmost_diff_src_iter_c_memory},
common_diff_weights_layer_memory},
common_diff_weights_iter_memory},
common_diff_weights_layer_memory},
common_diff_weights_iter_memory},
// Hand the weight and bias gradients back in the user layouts. Where the
// backward primitive used its own layout (flags set earlier), reorder from the
// primitive-layout memory into the matching user memory on stream s.
// NOTE(review): no wait on `s` is visible before the function ends - confirm
// the results are synchronized if they are read afterwards.
if (reorder_common_diff_weights_layer) {
    auto diff_weights_layer_to_user = reorder(common_diff_weights_layer_memory,
            user_common_diff_weights_layer_memory);
    diff_weights_layer_to_user.execute(s, common_diff_weights_layer_memory,
            user_common_diff_weights_layer_memory);
}
if (reorder_common_diff_bias) {
    auto diff_bias_to_user
            = reorder(common_diff_bias_memory, user_common_diff_bias_memory);
    diff_bias_to_user.execute(
            s, common_diff_bias_memory, user_common_diff_bias_memory);
}
// Closes the enclosing function, whose header is missing from this source.
}
// Entry point. Parse the engine kind from the command-line arguments, then run
// simple_net through handle_example_errors. Its result becomes the process
// exit code.
int main(int argc, char **argv) {
    const auto selected_kind = parse_engine_kind(argc, argv);
    return handle_example_errors(simple_net, selected_kind);
}