/// This example demonstrates best practices for optimizing application performance with oneDNN.
#include <iostream>
#include <stdexcept>
#include <vector>
#include "example_utils.hpp"
/// Fills every element of `m` with the value `v`.
///
/// The element count is derived from the memory descriptor's byte size, so
/// the helper works for any tensor shape or layout of `m`.
/// NOTE(review): assumes `m` holds f32 data (the host buffer is built from
/// floats) — confirm for each caller.
/// @param m  destination oneDNN memory object, written via write_to_dnnl_memory.
/// @param v  value written into every element.
void init_data(memory &m, float v) {
    // Number of f32 elements backing `m`; get_size() reports bytes.
    const auto size = m.get_desc().get_size() / sizeof(float);
    std::vector<float> data(size, v);
    write_to_dnnl_memory(data.data(), m);
}
// NOTE(review): the lines below are the tail of a function whose header and
// body are not visible in this view; the leading brace closes a scope that
// opens outside it. The function returns a local named `attr`. Presumably
// this is create_attr_with_relu_post_op(), which the fused variant calls —
// confirm against the full file and restore the missing lines.
}
return attr;
}
// Conv + ReLU, "naive" variant (the usage text in performance_profiling
// describes it as "NCHW format without fusion"): runs the convolution, then
// applies ReLU as a separate step via create_and_execute_relu() on user_dst.
// The activation is therefore not fused into the convolution.
// NOTE(review): part of the parameter list, the convolution descriptor /
// primitive setup, and the arguments to conv.execute() are missing from this
// view; restore them from the upstream example.
void conv_relu_naive(
const memory &user_src,
const memory &user_wei,
conv_dst_md, strides, padding, padding);
conv.execute(s,
// ReLU is applied to user_dst after the convolution (presumably in place,
// since only one memory object is passed — confirm).
create_and_execute_relu(user_dst, eng, s);
}
// NOTE(review): the header of this function and the start of the convolution
// descriptor setup are missing from this view. Judging by the usage text in
// performance_profiling ("blocked: format propagation without fusion"), this
// is presumably the body of conv_relu_blocked — confirm against the full file.
conv_dst_md, strides, padding, padding);
// Format propagation: when the primitive descriptor chose a memory layout that
// differs from the user's, allocate an intermediate buffer in the chosen
// layout. (The declarations of conv_src/conv_wei/conv_dst and any reorders
// from the user buffers are not visible here.)
if (conv_pd.src_desc() != user_src.
get_desc()) {
conv_src =
memory(conv_pd.src_desc(), eng);
}
if (conv_pd.weights_desc() != user_wei.
get_desc()) {
conv_wei =
memory(conv_pd.weights_desc(), eng);
}
// No braces: only the allocation statement below is conditional.
if (conv_pd.dst_desc() != user_dst.
get_desc())
conv_dst =
memory(conv_pd.dst_desc(), eng);
// NOTE(review): the execute() argument list is truncated in this view.
conv.execute(s,
// ReLU runs as a separate primitive on conv_dst (not fused into the conv).
create_and_execute_relu(conv_dst, eng, s);
// NOTE(review): as shown, this branch is empty. Presumably a reorder from
// conv_dst back into user_dst belongs here — confirm it was not lost.
if (conv_pd.dst_desc() != user_dst.
get_desc()) {
}
}
// NOTE(review): the header of this function and the start of the convolution
// descriptor setup are missing from this view. Judging by the usage text in
// performance_profiling ("fused: format propagation with fusion"), this is
// presumably the body of conv_relu_fused — confirm against the full file.
conv_dst_md, strides, padding, padding);
// Primitive attributes carrying the ReLU (presumably as a post-op, per the
// helper's name). No separate ReLU primitive is executed in this variant.
// NOTE(review): the primitive-descriptor creation that consumes `attr` is not
// visible here.
auto attr = create_attr_with_relu_post_op();
// Format propagation: allocate intermediate buffers in the layouts chosen by
// conv_pd whenever they differ from the user's layouts.
if (conv_pd.src_desc() != user_src.
get_desc()) {
conv_src =
memory(conv_pd.src_desc(), eng);
}
if (conv_pd.weights_desc() != user_wei.
get_desc()) {
conv_wei =
memory(conv_pd.weights_desc(), eng);
}
// No braces: only the allocation statement below is conditional.
if (conv_pd.dst_desc() != user_dst.
get_desc())
conv_dst =
memory(conv_pd.dst_desc(), eng);
// NOTE(review): the execute() argument list is truncated in this view.
conv.execute(s,
// NOTE(review): as shown, this branch is empty. Presumably a reorder from
// conv_dst back into user_dst belongs here — confirm it was not lost.
if (conv_pd.dst_desc() != user_dst.
get_desc()) {
}
}
/// Runs the selected Conv + ReLU implementation(s) and reports completion.
///
/// Implementation selection:
///   - argc <= 2: "validation", which runs naive, blocked and fused in turn;
///   - argc == 3: argv[2], one of "naive", "blocked", "fused", "validation";
///   - any other argc leaves the selection empty, so it fails the check below.
/// An unrecognized selection prints the usage text and throws.
/// @param engine_kind  engine kind for the run. NOTE(review): its use is in
///                     lines that are not visible here — confirm.
/// @param argc, argv   command-line arguments; only argv[2] is read here.
/// @throws std::invalid_argument if the implementation name is not recognized.
void performance_profiling(
engine::kind engine_kind,
int argc,
char **argv) {
// NOTE(review): the creation of eng, the stream s and the user_src /
// user_wei / user_dst objects is truncated in this view. Only a trailing
// "eng);" (presumably of each memory construction) survives.
eng);
eng);
eng);
// Fill with constants: src = 1, dst = -1, weights = 0.5.
init_data(user_src, 1);
init_data(user_dst, -1);
init_data(user_wei, .5);
std::string implementation;
if (argc <= 2)
implementation = "validation";
else if (argc == 3)
implementation = argv[2];
// Reject anything that is not one of the four known names, including the
// empty string left behind when argc > 3.
if (!(implementation == "validation" || implementation == "naive"
|| implementation == "blocked" || implementation == "fused")) {
std::cout << "The implementation can be one of:\n";
std::cout << " - naive: NCHW format without fusion\n";
std::cout << " - blocked: format propagation without fusion\n";
std::cout << " - fused: format propagation with fusion\n";
std::cout << " - validation: runs all implementations\n\n";
std::cout << "Validation will run if no parameters are specified.\n\n";
throw std::invalid_argument("Incorrect input arguments.");
}
// "validation" matches every branch below, so all three variants run in order.
if (implementation == "naive" || implementation == "validation") {
std::cout << "Implementation: naive.\n";
conv_relu_naive(user_src, user_wei, user_dst, eng, s);
std::cout << "Conv + ReLU w/ nchw format completed.\n";
}
if (implementation == "blocked" || implementation == "validation") {
std::cout << "Implementation: blocked.\n";
conv_relu_blocked(user_src, user_wei, user_dst, eng, s);
std::cout << "Conv + ReLU w/ blocked format completed.\n";
}
if (implementation == "fused" || implementation == "validation") {
std::cout << "Implementation: fused.\n";
conv_relu_fused(user_src, user_wei, user_dst, eng, s);
std::cout << "Conv + ReLU w/ fusing completed.\n";
}
}
int main(int argc, char **argv) {
engine::kind engine_kind = parse_engine_kind(argc, argv, 1);
return handle_example_errors(
performance_profiling, engine_kind, argc, argv);
}