// This C API example demonstrates how to build and train a few layers of an
// AlexNet model (convolution, ReLU, LRN and pooling, forward and backward).
#define _POSIX_C_SOURCE 200112L
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Mini-batch size: first dimension of the network input and output tensors.
#define BATCH 32
// Input channels of the convolution (second dimension of the input tensor).
#define IC 3
// Output channels of the convolution; also the channel dimension of the
// network output (net_dst_sizes).
#define OC 96
// Spatial size (height, width) of the network input / convolution source.
#define CONV_IH 227
#define CONV_IW 227
// Convolution output spatial size. With the 11x11 kernel used for the
// convolution weights, stride CONV_STRIDE and padding CONV_PAD:
// (227 - 11 + 2 * 0) / 4 + 1 = 55.
#define CONV_OH 55
#define CONV_OW 55
#define CONV_STRIDE 4
#define CONV_PAD 0
// Pooling output spatial size, i.e. the spatial size of the network output.
// NOTE(review): the pooling kernel size is not visible here; 27 matches
// (55 - 3) / 2 + 1 for a 3x3 kernel with stride 2 — confirm against
// pool_kernel.
#define POOL_OH 27
#define POOL_OW 27
#define POOL_STRIDE 2
#define POOL_PAD 0
/* Evaluate a oneDNN call exactly once. On any status other than
 * dnnl_success, print the failing expression, its source location and the
 * numeric status, then terminate the process with exit code 2.
 *
 * The temporary is named check_status_ (rather than a short name like `s`)
 * so it cannot capture a caller variable of the same name that appears inside
 * the argument expression. */
#define CHECK(f) \
    do { \
        dnnl_status_t check_status_ = (f); \
        if (check_status_ != dnnl_success) { \
            printf("[%s:%d] error: %s returns %d\n", __FILE__, __LINE__, #f, \
                    (int)check_status_); \
            exit(2); \
        } \
    } while (0)
/* Evaluate `expr` exactly once as a C truth value. If it is false (compares
 * equal to 0), print the failing expression and its source location, then
 * terminate the process with exit code 2.
 *
 * The condition is tested directly with !(expr) instead of being stored in an
 * int first: storing into an int would truncate non-integer truth values
 * (0.5 -> 0) and convert pointers to int. It also avoids a macro-local name
 * that could capture a caller variable used inside `expr`. */
#define CHECK_TRUE(expr) \
    do { \
        if (!(expr)) { \
            printf("[%s:%d] %s failed\n", __FILE__, __LINE__, #expr); \
            exit(2); \
        } \
    } while (0)
/* Return the product of the first `size` entries of `arr` as a size_t
 * element count (1 when size == 0). Used to size host buffers from
 * dnnl_dim_t dimension arrays.
 *
 * arr:  dimension sizes; only read, never modified.
 * size: number of leading entries of arr to multiply. */
static size_t product(const dnnl_dim_t *arr, size_t size) {
    size_t prod = 1;
    for (size_t i = 0; i < size; ++i)
        prod *= (size_t)arr[i]; /* dims are expected to be non-negative */
    return prod;
}
/* Fill a dense host buffer with a deterministic test pattern.
 *
 * The buffer is viewed as `dim` nested dimensions of sizes dims[0..dim-1],
 * stored contiguously in row-major order. Element j of that linear order
 * receives (float)(j % 1637). For dim == 4 this writes
 * data[((n * dims[1] + c) * dims[2] + h) * dims[3] + w] = index % 1637.
 * dim == 0 leaves the buffer untouched.
 *
 * data: buffer with room for dims[0] * ... * dims[dim - 1] floats.
 * dim:  number of entries in dims.
 * dims: size of each dimension. */
static void init_net_data(
        float *data, uint32_t dim, const dnnl_dim_t *dims) {
    if (dim == 0) return;

    dnnl_dim_t nelems = 1;
    for (uint32_t d = 0; d < dim; ++d)
        nelems *= dims[d];

    /* Row-major nesting of the per-dimension indices visits offsets
     * 0 .. nelems-1 in order, so one flat loop produces exactly the values
     * nested per-dimension loops would. */
    for (dnnl_dim_t j = 0; j < nelems; ++j)
        data[j] = (float)(j % 1637);
}
typedef struct {
int nargs;
} args_t;
/* Allocate room for `nargs` zero-initialized execution arguments in `node`
 * and record the count. Callers then fill node->args[0 .. nargs-1] with
 * set_arg(). The storage is released by free_arg_node().
 * Allocation failure is fatal (exit code 2, like the CHECK macros), because
 * callers write into the array unconditionally. */
static void prepare_arg_node(args_t *node, int nargs) {
    node->args = calloc((size_t)nargs, sizeof *node->args);
    if (node->args == NULL && nargs != 0) {
        printf("[%s:%d] failed to allocate %d execution arguments\n",
                __FILE__, __LINE__, nargs);
        exit(2);
    }
    node->nargs = nargs;
}
/* Release the argument array allocated by prepare_arg_node(). The node is
 * reset to an empty state so a repeated call is a harmless free(NULL)
 * rather than a double free. */
static void free_arg_node(args_t *node) {
    free(node->args);
    node->args = NULL;
    node->nargs = 0;
}
// NOTE(review): this definition is truncated in this copy. The parameters
// after `dims`, the declaration of `user_md`, the head of the call that
// receives `&user_md, dim, dims, dnnl_f32, user_fmt`, and whatever consumes
// the result are missing, so the block does not compile as shown.
// The call site init_data_memory(1, conv_bias_sizes, dnnl_x, dnnl_f32,
// engine, conv_bias, &conv_user_bias_memory) suggests it wraps a host float
// buffer in a memory object of the given format — TODO restore and confirm.
static void init_data_memory(uint32_t dim,
const dnnl_dim_t *dims,
&user_md, dim, dims,
dnnl_f32, user_fmt));
}
// NOTE(review): the head of this definition (return type, name and leading
// parameters) and the condition guarding the outer if/else are missing from
// this copy. Call sites use it as
//   CHECK(prepare_reorder(&user_mem, prim_md, engine, dir,
//                         &prim_mem, &reorder, &net_index, net, net_args))
// so this is presumably prepare_reorder, returning a dnnl_status_t.
int dir_is_user_to_prim,
uint32_t *net_index,
DNNL_MEMORY_ALLOCATE));
// Build the reorder in the requested direction: user layout -> primitive
// layout when dir_is_user_to_prim is non-zero, primitive -> user otherwise.
if (dir_is_user_to_prim) {
user_memory_md, user_mem_engine, prim_memory_md,
prim_engine, NULL));
} else {
prim_memory_md, prim_engine, user_memory_md,
user_mem_engine, NULL));
}
// Append the reorder at the current net slot. Its FROM/TO arguments follow
// the chosen direction. Then advance the caller's slot counter.
net[*net_index] = *reorder;
prepare_arg_node(&net_args[*net_index], 2);
set_arg(&net_args[*net_index].args[0], DNNL_ARG_FROM,
dir_is_user_to_prim ? *user_memory : *prim_memory);
set_arg(&net_args[*net_index].args[1], DNNL_ARG_TO,
dir_is_user_to_prim ? *prim_memory : *user_memory);
(*net_index)++;
} else {
// No reorder was created: outputs are NULL and *net_index is unchanged.
// Callers test these for NULL and fall back to the user memory.
*prim_memory = NULL;
*reorder = NULL;
}
// NOTE(review): callers wrap this function in CHECK(), but no return
// statement is visible before the closing brace. Confirm the full source
// returns a status (e.g. dnnl_success) here.
}
// NOTE(review): the header of the enclosing function is missing from this
// copy. So are the declarations of net_fwd/net_bwd, engine and stream, and
// most descriptor/primitive creation calls. The comments below describe only
// what the visible statements do.
//
// Forward and backward nets are parallel arrays of primitives and their
// execution arguments; n_fwd / n_bwd count the slots in use.
// The capacity is 10 slots per direction. The visible code fills at most 7
// forward and 8 backward slots (when every optional reorder is created).
uint32_t n_fwd = 0, n_bwd = 0;
args_t net_fwd_args[10], net_bwd_args[10];
// Host-side network input {BATCH, IC, CONV_IH, CONV_IW} and output
// {BATCH, OC, POOL_OH, POOL_OW}. The input gets the deterministic pattern
// from init_net_data; the output is zero-filled.
// NOTE(review): malloc results here and below are not checked for NULL, and
// no later use of net_dst is visible in this copy.
dnnl_dim_t net_src_sizes[4] = {BATCH, IC, CONV_IH, CONV_IW};
dnnl_dim_t net_dst_sizes[4] = {BATCH, OC, POOL_OH, POOL_OW};
float *net_src = (float *)malloc(product(net_src_sizes, 4) * sizeof(float));
float *net_dst = (float *)malloc(product(net_dst_sizes, 4) * sizeof(float));
init_net_data(net_src, 4, net_src_sizes);
memset(net_dst, 0, product(net_dst_sizes, 4) * sizeof(float));
// ---- Forward convolution: {BATCH, IC, 227, 227} with {OC, IC, 11, 11}
// weights -> {BATCH, OC, CONV_OH, CONV_OW}.
// The source reuses the network input buffer. Weights and bias get their own
// host buffers filled with the test pattern.
// NOTE(review): the declaration of conv_bias_sizes is not visible here.
dnnl_dim_t conv_user_weights_sizes[4] = {OC, IC, 11, 11};
dnnl_dim_t conv_user_dst_sizes[4] = {BATCH, OC, CONV_OH, CONV_OW};
dnnl_dim_t conv_strides[2] = {CONV_STRIDE, CONV_STRIDE};
dnnl_dim_t conv_padding[2] = {CONV_PAD, CONV_PAD};
float *conv_src = net_src;
float *conv_weights = (float *)malloc(
product(conv_user_weights_sizes, 4) * sizeof(float));
float *conv_bias
= (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float));
init_net_data(conv_weights, 4, conv_user_weights_sizes);
init_net_data(conv_bias, 1, conv_bias_sizes);
// Wrap the host buffers in user memory objects (most call heads are
// missing from this copy).
conv_user_bias_memory;
conv_src, &conv_user_src_memory);
conv_weights, &conv_user_weights_memory);
init_data_memory(1, conv_bias_sizes,
dnnl_x,
dnnl_f32, engine, conv_bias,
&conv_user_bias_memory);
// Scoped creation of the forward convolution primitive descriptor
// (conv_pd); most of the statements are missing from this copy.
{
conv_dst_md;
&conv_bias_md, &conv_dst_md, conv_strides, conv_padding,
conv_padding));
&conv_pd, &conv_any_desc, NULL, engine, NULL));
}
// Memories in the layouts chosen by the primitive. src/weights are
// filled by optional reorders from the user memories (appended to the
// forward net before the convolution itself). The destination is written
// directly by the convolution.
dnnl_memory_t conv_internal_src_memory, conv_internal_weights_memory,
conv_internal_dst_memory;
DNNL_MEMORY_ALLOCATE));
CHECK(prepare_reorder(&conv_user_src_memory, conv_src_md, engine, 1,
&conv_internal_src_memory, &conv_reorder_src, &n_fwd, net_fwd,
net_fwd_args));
CHECK(prepare_reorder(&conv_user_weights_memory, conv_weights_md, engine, 1,
&conv_internal_weights_memory, &conv_reorder_weights, &n_fwd,
net_fwd, net_fwd_args));
// Use the reordered memory when a reorder was created (non-NULL),
// otherwise the user memory directly.
? conv_internal_src_memory
: conv_user_src_memory;
? conv_internal_weights_memory
: conv_user_weights_memory;
// Append the convolution. Bias is taken from the user memory without a
// reorder.
net_fwd[n_fwd] = conv;
prepare_arg_node(&net_fwd_args[n_fwd], 4);
set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, conv_src_memory);
set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_WEIGHTS,
conv_weights_memory);
set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_BIAS, conv_user_bias_memory);
set_arg(&net_fwd_args[n_fwd].args[3], DNNL_ARG_DST,
conv_internal_dst_memory);
n_fwd++;
// ---- Forward ReLU: reads the convolution output, writes relu_dst_memory.
// NOTE(review): negative_slope is 1.0f. The same value is passed to the
// backward descriptor further down. If it is the leaky-ReLU slope, 1.0 makes
// the activation an identity, whereas a standard ReLU uses 0.0f — confirm
// this is intended.
float negative_slope = 1.0f;
&relu_dst_memory, relu_dst_md, engine, DNNL_MEMORY_ALLOCATE));
net_fwd[n_fwd] = relu;
prepare_arg_node(&net_fwd_args[n_fwd], 2);
set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC,
conv_internal_dst_memory);
set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, relu_dst_memory);
n_fwd++;
// ---- Forward LRN on the ReLU output. The LRN parameters below are reused
// by the backward LRN descriptor. The workspace memory is also handed to
// the backward LRN primitive.
uint32_t local_size = 5;
float alpha = 0.0001f;
float beta = 0.75f;
float k = 1.0f;
&lrn_dst_memory, lrn_dst_md, engine, DNNL_MEMORY_ALLOCATE));
&lrn_ws_memory, lrn_ws_md, engine, DNNL_MEMORY_ALLOCATE));
net_fwd[n_fwd] = lrn;
prepare_arg_node(&net_fwd_args[n_fwd], 3);
set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, relu_dst_memory);
set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, lrn_dst_memory);
set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_WORKSPACE, lrn_ws_memory);
n_fwd++;
// ---- Forward pooling on the LRN output; its workspace is reused by the
// backward pooling primitive.
dnnl_dim_t pool_strides[2] = {POOL_STRIDE, POOL_STRIDE};
dnnl_dim_t pool_padding[2] = {POOL_PAD, POOL_PAD};
&pool_user_dst_memory);
{
pool_kernel, pool_padding, pool_padding));
&pool_pd, &pool_desc, NULL, engine, NULL));
}
&pool_ws_memory, pool_ws_md, engine, DNNL_MEMORY_ALLOCATE));
// The optional reorder of the pooling output back to the user layout must
// run AFTER pooling, but prepare_reorder() appends at n_fwd. So:
//   1. reserve one slot for pooling (n_fwd += 1);
//   2. let prepare_reorder place the reorder, if any, after it;
//   3. rewind to the reserved slot (by 2 if a reorder was added, else 1);
//   4. store pooling there, then step past the reorder again.
n_fwd += 1;
CHECK(prepare_reorder(&pool_user_dst_memory, pool_dst_md, engine, 0,
&pool_internal_dst_memory, &pool_reorder_dst, &n_fwd, net_fwd,
net_fwd_args));
n_fwd -= pool_reorder_dst ? 2 : 1;
? pool_internal_dst_memory
: pool_user_dst_memory;
net_fwd[n_fwd] = pool;
prepare_arg_node(&net_fwd_args[n_fwd], 3);
set_arg(&net_fwd_args[n_fwd].args[0], DNNL_ARG_SRC, lrn_dst_memory);
set_arg(&net_fwd_args[n_fwd].args[1], DNNL_ARG_DST, pool_dst_memory);
set_arg(&net_fwd_args[n_fwd].args[2], DNNL_ARG_WORKSPACE, pool_ws_memory);
n_fwd++;
if (pool_reorder_dst) n_fwd += 1;
// ---- Backward pooling. The incoming gradient (net_diff_dst) is a host
// buffer shaped like the pooling output, filled with the test pattern.
// It is reordered to the primitive's layout if needed.
// NOTE(review): the declaration of pool_dst_sizes is not visible here.
float *net_diff_dst
= (float *)malloc(product(pool_dst_sizes, 4) * sizeof(float));
init_net_data(net_diff_dst, 4, pool_dst_sizes);
net_diff_dst, &pool_user_diff_dst_memory);
pool_diff_src_md, pool_diff_dst_md, pool_strides, pool_kernel,
pool_padding, pool_padding));
&pool_bwd_pd, &pool_bwd_desc, NULL, engine, pool_pd));
dnnl_memory_t pool_diff_dst_memory, pool_internal_diff_dst_memory;
CHECK(prepare_reorder(&pool_user_diff_dst_memory, pool_diff_dst_md, engine,
1, &pool_internal_diff_dst_memory, &pool_reorder_diff_dst, &n_bwd,
net_bwd, net_bwd_args));
pool_diff_dst_memory = pool_internal_diff_dst_memory
? pool_internal_diff_dst_memory
: pool_user_diff_dst_memory;
DNNL_MEMORY_ALLOCATE));
net_bwd[n_bwd] = pool_bwd;
prepare_arg_node(&net_bwd_args[n_bwd], 3);
set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_DIFF_DST,
pool_diff_dst_memory);
set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_WORKSPACE, pool_ws_memory);
set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_SRC,
pool_diff_src_memory);
n_bwd++;
// ---- Backward LRN. SRC is the forward LRN input (relu_dst_memory),
// DIFF_DST is the pooling gradient, WORKSPACE is the forward LRN workspace.
lrn_diff_dst_md, lrn_src_md, local_size, alpha, beta, k));
&lrn_bwd_pd, &lrn_bwd_desc, NULL, engine, lrn_pd));
DNNL_MEMORY_ALLOCATE));
net_bwd[n_bwd] = lrn_bwd;
prepare_arg_node(&net_bwd_args[n_bwd], 4);
set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC, relu_dst_memory);
set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
pool_diff_src_memory);
set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_WORKSPACE, lrn_ws_memory);
set_arg(&net_bwd_args[n_bwd].args[3], DNNL_ARG_DIFF_SRC,
lrn_diff_src_memory);
n_bwd++;
// ---- Backward ReLU. SRC is the forward ReLU input (the convolution
// output); DIFF_DST is the LRN gradient.
relu_diff_dst_md, relu_src_md, negative_slope, 0));
&relu_bwd_pd, &relu_bwd_desc, NULL, engine, relu_pd));
DNNL_MEMORY_ALLOCATE));
net_bwd[n_bwd] = relu_bwd;
prepare_arg_node(&net_bwd_args[n_bwd], 3);
set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC,
conv_internal_dst_memory);
set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
lrn_diff_src_memory);
set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_SRC,
relu_diff_src_memory);
n_bwd++;
// ---- Backward-weights convolution: computes weight and bias gradients
// into host buffers (uninitialized; they are outputs only).
float *conv_diff_bias_buffer
= (float *)malloc(product(conv_bias_sizes, 1) * sizeof(float));
float *conv_user_diff_weights_buffer = (float *)malloc(
product(conv_user_weights_sizes, 4) * sizeof(float));
conv_user_diff_weights_buffer, &conv_user_diff_weights_memory);
{
conv_diff_bias_md, conv_diff_dst_md;
&conv_diff_src_md, &conv_diff_weights_md, &conv_diff_bias_md,
&conv_diff_dst_md, conv_strides, conv_padding, conv_padding));
&conv_bwd_weights_desc, NULL, engine, conv_pd));
}
// Reorder the forward convolution's source and the ReLU gradient into the
// layouts the backward-weights primitive expects, if they differ.
// NOTE(review): conv_diff_src_md is presumably the source layout queried
// from the backward-weights descriptor; its definition is not visible here.
CHECK(prepare_reorder(&conv_src_memory, conv_diff_src_md, engine, 1,
&conv_bwd_internal_src_memory, &conv_bwd_reorder_src, &n_bwd,
net_bwd, net_bwd_args));
dnnl_memory_t conv_bwd_weights_src_memory = conv_bwd_internal_src_memory
? conv_bwd_internal_src_memory
: conv_src_memory;
CHECK(prepare_reorder(&relu_diff_src_memory, conv_diff_dst_md, engine, 1,
&conv_internal_diff_dst_memory, &conv_reorder_diff_dst, &n_bwd,
net_bwd, net_bwd_args));
dnnl_memory_t conv_diff_dst_memory = conv_internal_diff_dst_memory
? conv_internal_diff_dst_memory
: relu_diff_src_memory;
// Same slot-reservation trick as for forward pooling. The optional
// diff-weights reorder back to the user layout must follow the
// backward-weights primitive in the net.
n_bwd += 1;
CHECK(prepare_reorder(&conv_user_diff_weights_memory, conv_diff_weights_md,
engine, 0, &conv_internal_diff_weights_memory,
&conv_reorder_diff_weights, &n_bwd, net_bwd, net_bwd_args));
n_bwd -= conv_reorder_diff_weights ? 2 : 1;
dnnl_memory_t conv_diff_weights_memory = conv_internal_diff_weights_memory
? conv_internal_diff_weights_memory
: conv_user_diff_weights_memory;
// The diff-bias memory is created with a NULL handle. The following
// (truncated) call presumably attaches conv_diff_bias_buffer as its data.
&conv_diff_bias_memory, conv_diff_bias_md, engine, NULL));
conv_diff_bias_memory, conv_diff_bias_buffer));
net_bwd[n_bwd] = conv_bwd_weights;
prepare_arg_node(&net_bwd_args[n_bwd], 4);
set_arg(&net_bwd_args[n_bwd].args[0], DNNL_ARG_SRC,
conv_bwd_weights_src_memory);
set_arg(&net_bwd_args[n_bwd].args[1], DNNL_ARG_DIFF_DST,
conv_diff_dst_memory);
set_arg(&net_bwd_args[n_bwd].args[2], DNNL_ARG_DIFF_WEIGHTS,
conv_diff_weights_memory);
set_arg(&net_bwd_args[n_bwd].args[3], DNNL_ARG_DIFF_BIAS,
conv_diff_bias_memory);
n_bwd++;
if (conv_reorder_diff_weights) n_bwd += 1;
// ---- Training loop: each of n_iter iterations executes every forward
// primitive in order, then every backward primitive. It then accesses the
// user diff-weights and diff-bias memories through the two pointers below
// (the call heads are missing).
// NOTE(review): the inner loops declare their own `i`, shadowing the
// outer iteration counter (-Wshadow). net_output is unused in the visible
// code, and no unmapping of the accessed memories is visible.
void *net_diff_weights = NULL;
void *net_diff_bias = NULL;
int n_iter = 10;
for (int i = 0; i < n_iter; i++) {
for (uint32_t i = 0; i < n_fwd; ++i)
net_fwd_args[i].nargs, net_fwd_args[i].args));
void *net_output = NULL;
for (uint32_t i = 0; i < n_bwd; ++i)
net_bwd_args[i].nargs, net_bwd_args[i].args));
conv_user_diff_weights_memory, &net_diff_weights));
conv_diff_bias_memory, &net_diff_bias));
}
// ---- Cleanup of argument arrays and host buffers.
// NOTE(review): no destruction of primitives, memories, engine or stream
// is visible in this copy — confirm it exists in the full source.
for (uint32_t i = 0; i < n_fwd; ++i)
free_arg_node(&net_fwd_args[i]);
for (uint32_t i = 0; i < n_bwd; ++i)
free_arg_node(&net_bwd_args[i]);
free(net_src);
free(net_dst);
free(conv_weights);
free(conv_bias);
free(net_diff_dst);
free(conv_diff_bias_buffer);
free(conv_user_diff_weights_buffer);
}
// Entry point: prints "passed" or "failed" depending on `result` and returns
// `result` as the process exit status.
// NOTE(review): `result` is not declared anywhere in this copy. It is
// presumably the dnnl_status_t returned by the training routine above —
// TODO restore that call. argc/argv are unused. Returning the raw status
// means exit code 0 only if dnnl_success == 0 — confirm.
int main(int argc, char **argv) {
printf(
"%s\n", (result ==
dnnl_success) ?
"passed" :
"failed");
return result;
}