// This example demonstrates memory format propagation, which is critical for
// the performance of deep learning applications.
#include <iostream>
#include <sstream>
#include <string>
// Runs `conv` and then `pool` on a CPU stream. User-side memory objects are
// reordered into the descriptors reported by `conv_pd` whenever the two
// differ, and the pooling result is reordered back into the user-side
// destination memory when needed.
//
// NOTE(review): several statements in this function appear truncated in this
// view. The declarations of cpu_engine, conv_src_md, conv_weights_md, conv_pd,
// conv and pool are not visible, and a number of lines are fragments of
// larger expressions. Confirm against the complete file before relying on
// these comments.
void cpu_memory_format_propagation_tutorial() {
// Stream that every execute() call below is submitted to.
stream cpu_stream(cpu_engine);
// Problem sizes; OC is set equal to IC.
// Presumably batch, height, width, input/output channels and kernel
// height/width — TODO confirm.
const int N = 1, H = 14, W = 14, IC = 256, OC = IC, KH = 3, KW = 3;
// NOTE(review): the two closers below have no visible opening statement.
);
);
// The convolution and pooling destination descriptors start as copies of
// conv_src_md and conv_dst_md respectively.
auto conv_dst_md = conv_src_md;
auto pool_dst_md = conv_dst_md;
// NOTE(review): argument-list fragment; the call it belongs to is not visible.
// It takes the three conv_* descriptors and ends with cpu_engine.
conv_src_md, conv_weights_md, conv_dst_md,
{1, 1},
{1, 1}, {1, 1}},
cpu_engine);
// NOTE(review): second argument-list fragment; it consumes conv_pd.dst_desc()
// and pool_dst_md, and uses {KH, KW} as one of its parameters.
conv_pd.dst_desc(), pool_dst_md,
{1, 1}, {KH, KW},
{1, 1}, {1, 1}},
cpu_engine);
// User-side memory objects. The descriptor initializers are cut off after the
// dims in this view, so the data type and layout are not visible here.
auto src_mem =
memory({{N, IC, H, W},
cpu_engine);
auto weights_mem =
memory({{IC, OC, KH, KW},
cpu_engine);
auto dst_mem =
memory({{N, IC, H, W},
cpu_engine);
// A reorder is needed whenever the descriptor reported by conv_pd differs
// from the descriptor of the corresponding user-side memory.
bool need_reorder_src = conv_pd.src_desc() != src_mem.get_desc();
bool need_reorder_weights = conv_pd.weights_desc() != weights_mem.
get_desc();
bool need_reorder_dst = conv_pd.dst_desc() != dst_mem.
get_desc();
// When a reorder is needed, create a fresh memory object with conv_pd's
// descriptor; otherwise use the user-side memory object directly.
auto conv_src_mem = need_reorder_src
?
memory(conv_pd.src_desc(), cpu_engine)
: src_mem;
auto conv_weights_mem = need_reorder_weights
?
memory(conv_pd.weights_desc(), cpu_engine)
: weights_mem;
// The convolution output always gets its own memory object with conv_pd's
// destination descriptor.
auto conv_dst_mem =
memory(conv_pd.dst_desc(), cpu_engine);
// NOTE(review): the `?` branch of this conditional is missing in this view.
auto pool_dst_mem = need_reorder_dst
: dst_mem;
// Reorder the user source into conv_src_mem, only when the descriptors differ.
if (need_reorder_src) {
auto reorder_src =
reorder(src_mem, conv_src_mem);
reorder_src.execute(cpu_stream, {
{MKLDNN_ARG_FROM, src_mem},
{MKLDNN_ARG_TO, conv_src_mem}
});
}
// Same for the weights.
if (need_reorder_weights) {
auto reorder_weights =
reorder(weights_mem, conv_weights_mem);
reorder_weights.execute(cpu_stream, {
{MKLDNN_ARG_FROM, weights_mem},
{MKLDNN_ARG_TO, conv_weights_mem}
});
}
// Memory object built from conv_pd.scratchpad_desc().
// NOTE(review): it is not passed to conv.execute() in the visible code;
// verify whether that is intentional.
auto conv_scratchpad_mem =
memory(conv_pd.scratchpad_desc(), cpu_engine);
// Run `conv` on the (possibly reordered) source and weights.
conv.execute(cpu_stream, {
{MKLDNN_ARG_SRC, conv_src_mem},
{MKLDNN_ARG_WEIGHTS, conv_weights_mem},
{MKLDNN_ARG_DST, conv_dst_mem}
});
// `pool` reads conv_dst_mem directly, with no reorder in between.
pool.execute(cpu_stream, {
{MKLDNN_ARG_SRC, conv_dst_mem},
{MKLDNN_ARG_DST, pool_dst_mem}
});
// Reorder the pooling result back into the user-side destination, only when
// the descriptors differ.
if (need_reorder_dst) {
auto reorder_dst =
reorder(pool_dst_mem, dst_mem);
reorder_dst.execute(cpu_stream, {
{MKLDNN_ARG_FROM, pool_dst_mem},
{MKLDNN_ARG_TO, dst_mem}
});
}
}
int main(int argc, char **argv) {
try {
cpu_memory_format_propagation_tutorial();
std::cerr <<
"Intel MKL-DNN error: " << e.
what() << std::endl
<<
"Error status: " << mkldnn_status2str(e.
status) << std::endl;
return 1;
} catch (std::string &e) {
std::cerr << "Error in the example: " << e << std::endl;
return 2;
}
std::cout << "Example passes" << std::endl;
return 0;
}