The batch normalization primitive performs a forward or backward batch normalization operation on tensors with two or more dimensions.
The batch normalization operation is defined by the following formulas. We show formulas only for 2D spatial data, which generalize straightforwardly to higher and lower dimensions. Variable names follow the standard Naming Conventions.
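\[ \dst(n, c, h, w) = \gamma(c) \cdot \frac{\src(n, c, h, w) - \mu(c)} {\sqrt{\sigma^2(c) + \varepsilon}} + \beta(c), \]
where \(\gamma(c)\) and \(\beta(c)\) are the optional scale and shift for channel \(c\) (see the dnnl_use_scaleshift flag), \(\mu(c)\) and \(\sigma^2(c)\) are the mean and variance for channel \(c\) (see the dnnl_use_global_stats flag), and \(\varepsilon\) is a constant that improves numerical stability.
Mean and variance are either computed at run time or provided by the user. When they are computed at run time, the following formulas are used:
\[ \mu(c) = \frac{1}{NHW} \sum\limits_{nhw} \src(n, c, h, w), \\ \sigma^2(c) = \frac{1}{NHW} \sum\limits_{nhw} \left(\src(n, c, h, w) - \mu(c)\right)^2. \]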
The \(\gamma(c)\) and \(\beta(c)\) tensors are considered learnable.
In training mode, the primitive also optionally supports fusion with ReLU activation with zero negative slope applied to the result (see dnnl_fuse_norm_relu flag).
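As an illustration of the forward formulas, the following is a naive reference sketch, not the library's implementation. It assumes a plain layout in which all spatial dimensions are flattened into one axis of size SP (element (n, c, s) at offset (n * C + c) * SP + s, as in dnnl_nchw), computes the population variance (dividing by N * SP), and treats empty gamma/beta as \(\gamma = 1\), \(\beta = 0\). The names bnorm_forward_ref and BnormFwdResult are hypothetical, and the relu_mask member is only a stand-in for the workspace: the real workspace is an opaque buffer whose contents this sketch does not try to reproduce.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result holder for the naive reference below.
struct BnormFwdResult {
    std::vector<float> dst;               // N * C * SP values
    std::vector<float> mean;              // mu(c), C values
    std::vector<float> variance;          // sigma^2(c), C values
    std::vector<std::uint8_t> relu_mask;  // 1 where dst > 0; models the workspace (assumption)
};

// Naive forward batch normalization for a plain N x C x SP tensor, where SP is
// the product of all spatial dimensions and the layout is n, then c, then sp.
// Empty gamma/beta means "no scale and shift" (gamma = 1, beta = 0).
BnormFwdResult bnorm_forward_ref(const std::vector<float>& src, std::size_t N,
                                 std::size_t C, std::size_t SP,
                                 const std::vector<float>& gamma,
                                 const std::vector<float>& beta, float eps,
                                 bool fuse_relu) {
    assert(src.size() == N * C * SP);
    const bool use_scaleshift = !gamma.empty();
    assert(!use_scaleshift || (gamma.size() == C && beta.size() == C));

    BnormFwdResult r;
    r.dst.resize(src.size());
    r.mean.assign(C, 0.0f);
    r.variance.assign(C, 0.0f);
    if (fuse_relu) r.relu_mask.assign(src.size(), 0);

    const double M = static_cast<double>(N * SP);  // values per channel
    for (std::size_t c = 0; c < C; ++c) {
        // mu(c): average over the batch and all spatial points.
        double sum = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < SP; ++s) sum += src[(n * C + c) * SP + s];
        const double mu = sum / M;

        // sigma^2(c): population variance (divides by M, not M - 1).
        double sq = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < SP; ++s) {
                const double d = src[(n * C + c) * SP + s] - mu;
                sq += d * d;
            }
        const double var = sq / M;
        r.mean[c] = static_cast<float>(mu);
        r.variance[c] = static_cast<float>(var);

        const double inv_std = 1.0 / std::sqrt(var + eps);
        const double g = use_scaleshift ? gamma[c] : 1.0;
        const double b = use_scaleshift ? beta[c] : 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < SP; ++s) {
                const std::size_t i = (n * C + c) * SP + s;
                double y = g * (src[i] - mu) * inv_std + b;
                if (fuse_relu) {
                    r.relu_mask[i] = static_cast<std::uint8_t>(y > 0.0);
                    if (y < 0.0) y = 0.0;  // ReLU with zero negative slope
                }
                r.dst[i] = static_cast<float>(y);
            }
    }
    return r;
}
```

For instance, with src = {1, 2, 10, 20, 3, 4, 30, 40}, N = C = SP = 2, and no scale and shift, the sketch yields mean = {2.5, 25} and variance = {1.25, 125}.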
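Using the mean and variance computed by the primitive in training mode, the running mean and variance \(\hat\mu\) and \(\hat\sigma^2\) can be computed as follows, where the factor \(\alpha\) is chosen by the user:
\[ \hat\mu := \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu, \\ \hat\sigma^2 := \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2. \]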
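A direct transcription of the running-statistics update is sketched below. The helper name update_running_stats is hypothetical; the update is performed by the caller from the \(\mu\) and \(\sigma^2\) outputs, with \(\alpha\) chosen by the caller.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Exponential moving average of the per-channel statistics:
//   mu_hat     := alpha * mu_hat     + (1 - alpha) * mu
//   sigma2_hat := alpha * sigma2_hat + (1 - alpha) * sigma2
// Hypothetical helper; alpha is chosen by the caller.
void update_running_stats(std::vector<float>& running_mean,
                          std::vector<float>& running_var,
                          const std::vector<float>& mean,
                          const std::vector<float>& variance, float alpha) {
    assert(running_mean.size() == mean.size());
    assert(running_var.size() == variance.size());
    for (std::size_t c = 0; c < mean.size(); ++c) {
        running_mean[c] = alpha * running_mean[c] + (1.0f - alpha) * mean[c];
        running_var[c] = alpha * running_var[c] + (1.0f - alpha) * variance[c];
    }
}
```

Because the run-time variance divides by the number of values per channel \(M\) (population variance), a framework that tracks an unbiased running variance would first scale \(\sigma^2\) by \(M / (M - 1)\); the sketch leaves that choice to the caller.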
When executed in training mode with the dnnl_fuse_norm_relu flag set, the primitive produces a workspace memory as one extra output. This memory is required to compute the backward propagation. When the primitive is executed with propagation kind dnnl_forward_inference, the workspace is not produced, and the behavior is the same as that of a batch normalization primitive with ReLU as a post-op (see the post-ops description below).
The backward propagation computes \(\diffsrc(n, c, h, w)\), \(\diffgamma(c)^*\), and \(\diffbeta(c)^*\) based on \(\diffdst(n, c, h, w)\), \(\src(n, c, h, w)\), \(\mu(c)\), \(\sigma^2(c)\), \(\gamma(c)^*\), and \(\beta(c)^*\).
The tensors marked with an asterisk are used only when the primitive is configured to use \(\gamma(c)\) and \(\beta(c)\) (i.e., dnnl_use_scaleshift is set).
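The backward formulas are not spelled out here. The sketch below uses the textbook gradient of the forward formula, with \(\hat x = (\src - \mu(c)) / \sqrt{\sigma^2(c) + \varepsilon}\), sums over \(n, h, w\), and \(M\) values per channel: \(\diffbeta(c) = \sum \diffdst\), \(\diffgamma(c) = \sum \diffdst \cdot \hat x\), and \(\diffsrc = \frac{\gamma(c)}{\sqrt{\sigma^2(c) + \varepsilon}} \left(\diffdst - \diffbeta(c)/M - \hat x \cdot \diffgamma(c)/M\right)\). When dnnl_use_global_stats is set, \(\mu\) and \(\sigma^2\) are treated as constants and the two correction terms are dropped. This is a naive reference under those assumptions, not the library's code; bnorm_backward_ref and BnormBwdResult are hypothetical names, and a plain layout with the spatial dimensions flattened into one axis of size SP is assumed.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result holder for the naive reference below.
struct BnormBwdResult {
    std::vector<float> diff_src;    // N * C * SP values
    std::vector<float> diff_gamma;  // C values
    std::vector<float> diff_beta;   // C values
};

// Naive backward batch normalization for a plain N x C x SP tensor (layout n,
// then c, then sp). Empty gamma means "no scale and shift" (gamma = 1).
// Without global stats, mu and sigma^2 are assumed to come from this same src.
BnormBwdResult bnorm_backward_ref(const std::vector<float>& diff_dst,
                                  const std::vector<float>& src,
                                  const std::vector<float>& mean,
                                  const std::vector<float>& variance,
                                  const std::vector<float>& gamma,
                                  std::size_t N, std::size_t C, std::size_t SP,
                                  float eps, bool use_global_stats) {
    assert(src.size() == N * C * SP && diff_dst.size() == src.size());
    assert(mean.size() == C && variance.size() == C);
    assert(gamma.empty() || gamma.size() == C);

    BnormBwdResult r;
    r.diff_src.resize(src.size());
    r.diff_gamma.assign(C, 0.0f);
    r.diff_beta.assign(C, 0.0f);

    const double M = static_cast<double>(N * SP);  // values per channel
    for (std::size_t c = 0; c < C; ++c) {
        const double inv_std = 1.0 / std::sqrt(static_cast<double>(variance[c]) + eps);
        // diff_beta(c) = sum of diff_dst; diff_gamma(c) = sum of diff_dst * x_hat.
        double dbeta = 0.0, dgamma = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < SP; ++s) {
                const std::size_t i = (n * C + c) * SP + s;
                const double x_hat = (src[i] - mean[c]) * inv_std;
                dbeta += diff_dst[i];
                dgamma += diff_dst[i] * x_hat;
            }
        r.diff_beta[c] = static_cast<float>(dbeta);
        r.diff_gamma[c] = static_cast<float>(dgamma);

        const double g = gamma.empty() ? 1.0 : gamma[c];
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < SP; ++s) {
                const std::size_t i = (n * C + c) * SP + s;
                double d = diff_dst[i];
                if (!use_global_stats) {
                    // Correction terms from the dependence of mu and sigma^2 on src.
                    const double x_hat = (src[i] - mean[c]) * inv_std;
                    d -= dbeta / M + x_hat * dgamma / M;
                }
                r.diff_src[i] = static_cast<float>(g * inv_std * d);
            }
    }
    return r;
}
```

For src = {1, 2, 3, 4} (N = 2, C = 1, SP = 2) with \(\mu = 2.5\), \(\sigma^2 = 1.25\), \(\varepsilon = 0\), and diff_dst = {1, 0, 0, 0}, the per-channel diff_src values sum to zero, as expected when the mean is computed from the batch.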
Depending on the flags and propagation kind, the batch normalization primitive requires different inputs and outputs. For clarity, a summary is shown below.
| Flags | dnnl_forward_inference | dnnl_forward_training | dnnl_backward | dnnl_backward_data |
|---|---|---|---|---|
| dnnl_normalization_flags_none | Inputs: \(\src\); Outputs: \(\dst\) | Inputs: \(\src\); Outputs: \(\dst\), \(\mu\), \(\sigma^2\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\); Outputs: \(\diffsrc\) | Same as for dnnl_backward |
| dnnl_use_global_stats | Inputs: \(\src\), \(\mu\), \(\sigma^2\); Outputs: \(\dst\) | Inputs: \(\src\), \(\mu\), \(\sigma^2\); Outputs: \(\dst\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\); Outputs: \(\diffsrc\) | Same as for dnnl_backward |
| dnnl_use_scaleshift | Inputs: \(\src\), \(\gamma\), \(\beta\); Outputs: \(\dst\) | Inputs: \(\src\), \(\gamma\), \(\beta\); Outputs: \(\dst\), \(\mu\), \(\sigma^2\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Outputs: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) | Not supported |
| dnnl_use_global_stats \| dnnl_use_scaleshift | Inputs: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Outputs: \(\dst\) | Inputs: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Outputs: \(\dst\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Outputs: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) | Not supported |
| flags \| dnnl_fuse_norm_relu | Inputs: same as with flags; Outputs: same as with flags | Inputs: same as with flags; Outputs: same as with flags, workspace | Inputs: same as with flags, workspace; Outputs: same as with flags | Same as for dnnl_backward if flags do not contain dnnl_use_scaleshift; not supported otherwise |
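The forward columns of the summary table can be read as a small function of the flags bitmask. In the sketch below, NormFlag and its bit values are hypothetical stand-ins for the library's dnnl_use_global_stats, dnnl_use_scaleshift, and dnnl_fuse_norm_relu constants (the real values come from the library headers); only the table logic is modeled.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for dnnl_use_global_stats, dnnl_use_scaleshift and
// dnnl_fuse_norm_relu. The bit values are illustrative, not the library's.
enum NormFlag : unsigned {
    kFlagsNone = 0u,
    kUseGlobalStats = 1u << 0,
    kUseScaleShift = 1u << 1,
    kFuseNormRelu = 1u << 2,
};

// Forward inputs per the summary table (same for inference and training).
std::vector<std::string> forward_inputs(unsigned flags) {
    std::vector<std::string> in = {"src"};
    if (flags & kUseGlobalStats) {
        in.push_back("mean");
        in.push_back("variance");
    }
    if (flags & kUseScaleShift) in.push_back("scale_shift");
    return in;
}

// Forward outputs: statistics are produced only in training without global
// stats; the workspace only in training with the fused ReLU.
std::vector<std::string> forward_outputs(unsigned flags, bool training) {
    std::vector<std::string> out = {"dst"};
    if (training && !(flags & kUseGlobalStats)) {
        out.push_back("mean");
        out.push_back("variance");
    }
    if (training && (flags & kFuseNormRelu)) out.push_back("workspace");
    return out;
}

// Backward-data propagation is not supported when scale and shift are used.
bool backward_data_supported(unsigned flags) {
    assert((flags & ~(kUseGlobalStats | kUseScaleShift | kFuseNormRelu)) == 0u);
    return (flags & kUseScaleShift) == 0u;
}
```

For example, forward_outputs(kFuseNormRelu, true) returns {"dst", "mean", "variance", "workspace"}, while the same flags with training set to false return only {"dst"}.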
When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.
| Primitive input/output | Execution argument index |
|---|---|
| \(\src\) | DNNL_ARG_SRC |
| \(\gamma, \beta\) | DNNL_ARG_SCALE_SHIFT |
| mean (\(\mu\)) | DNNL_ARG_MEAN |
| variance (\(\sigma^2\)) | DNNL_ARG_VARIANCE |
| \(\dst\) | DNNL_ARG_DST |
| workspace | DNNL_ARG_WORKSPACE |
| \(\diffdst\) | DNNL_ARG_DIFF_DST |
| \(\diffsrc\) | DNNL_ARG_DIFF_SRC |
| \(\diffgamma, \diffbeta\) | DNNL_ARG_DIFF_SCALE_SHIFT |
The different flavors of the primitive are controlled in part by the flags parameter that is passed to the operation descriptor initialization function (e.g., dnnl::batch_normalization_forward::desc::desc()). Multiple flags can be set using the bitwise OR operator (|).
The memory format and data type of src and dst are assumed to be the same, and in the API they are typically referred to as data (e.g., see data_desc in dnnl::batch_normalization_forward::desc::desc()). The same is true for diff_src and diff_dst. The corresponding memory descriptors are referred to as diff_data_desc.
When the primitive is fused with ReLU in training mode (dnnl_fuse_norm_relu), the forward propagation has one additional output, workspace, that should be passed during the backward propagation.
The operation supports the following combinations of data types:
| Propagation | Source / Destination | Mean / Variance / ScaleShift |
|---|---|---|
| forward / backward | f32, bf16 | f32 |
| forward | f16 | f32 |
| forward | s8 | f32 |
The mean (\(\mu\)) and variance (\(\sigma^2\)) are separate 1D tensors of size \(C\).
The format of the corresponding memory object must be dnnl_x (dnnl_a).
If used, the scale (\(\gamma\)) and shift (\(\beta\)) are combined in a single 2D tensor of shape \(2 \times C\).
The format of the corresponding memory object must be dnnl_nc (dnnl_ab).
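A sketch of filling the \(2 \times C\) scale-and-shift buffer in dnnl_nc (dnnl_ab) order follows. It assumes, in line with the "scale and shift" ordering above, that row 0 holds \(\gamma\) and row 1 holds \(\beta\), so element (i, c) sits at offset i * C + c; pack_scale_shift is a hypothetical helper, not a library function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Packs per-channel scale (gamma) and shift (beta) into one 2 x C buffer in
// dnnl_nc (dnnl_ab) order. Assumption: row 0 holds gamma and row 1 holds beta,
// so element (i, c) is stored at offset i * C + c.
std::vector<float> pack_scale_shift(const std::vector<float>& gamma,
                                    const std::vector<float>& beta) {
    assert(gamma.size() == beta.size());
    const std::size_t C = gamma.size();
    std::vector<float> packed(2 * C);
    for (std::size_t c = 0; c < C; ++c) {
        packed[0 * C + c] = gamma[c];  // row 0: scale
        packed[1 * C + c] = beta[c];   // row 1: shift
    }
    return packed;
}
```

The mean and variance need no packing: each is a contiguous array of C floats in dnnl_x (dnnl_a) format.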
Like other CNN primitives, the batch normalization primitive expects data to be an \(N \times C \times SP_n \times \cdots \times SP_0\) tensor.
The batch normalization primitive is optimized for the following memory formats:
| Spatial | Logical tensor | Implementations optimized for memory formats |
|---|---|---|
| 0D | NC | dnnl_nc (dnnl_ab) |
| 1D | NCW | dnnl_ncw (dnnl_abc), dnnl_nwc (dnnl_acb), optimized^ |
| 2D | NCHW | dnnl_nchw (dnnl_abcd), dnnl_nhwc (dnnl_acdb), optimized^ |
| 3D | NCDHW | dnnl_ncdhw (dnnl_abcde), dnnl_ndhwc (dnnl_acdeb), optimized^ |
Here optimized^ means the format that comes out of any preceding compute-intensive primitive.
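The letters in the format tags above list logical dimensions from outermost to innermost in memory, with a = N, b = C, c = H, and d = W for 2D spatial data. The sketch below computes the linear offset of element (n, c, h, w) for dnnl_nchw (dnnl_abcd) and dnnl_nhwc (dnnl_acdb); Dims, offset_nchw, and offset_nhwc are hypothetical helpers, and dense tensors without padding are assumed.

```cpp
#include <cassert>
#include <cstddef>

// Logical shape of a dense 4D tensor N x C x H x W (hypothetical helper type).
struct Dims {
    std::size_t N, C, H, W;
};

// dnnl_nchw (dnnl_abcd): W is innermost, then H, then C, then N.
std::size_t offset_nchw(const Dims& d, std::size_t n, std::size_t c,
                        std::size_t h, std::size_t w) {
    assert(n < d.N && c < d.C && h < d.H && w < d.W);
    return ((n * d.C + c) * d.H + h) * d.W + w;
}

// dnnl_nhwc (dnnl_acdb): C is innermost, so the channels of one pixel are adjacent.
std::size_t offset_nhwc(const Dims& d, std::size_t n, std::size_t c,
                        std::size_t h, std::size_t w) {
    assert(n < d.N && c < d.C && h < d.H && w < d.W);
    return ((n * d.H + h) * d.W + w) * d.C + c;
}
```

For a 2 x 3 x 4 x 5 tensor, element (0, 1, 0, 0) sits at offset 20 in nchw but at offset 1 in nhwc, while the last element is at offset 119 in both.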
Post-ops and attributes enable you to modify the behavior of the batch normalization primitive by chaining certain operations after the batch normalization operation. The following post-ops are supported by batch normalization primitives:
| Propagation | Type | Operation | Description |
|---|---|---|---|
| forward | post-op | eltwise | Applies an Eltwise operation to the result (currently only the dnnl_eltwise_relu algorithm is supported) |
Post-ops are intended for inference: using ReLU as a post-op does not produce the additional workspace output that is required to compute the backward propagation correctly. Hence, for training, use the dnnl_fuse_norm_relu flag directly.
For backward propagation, use the same memory format for src, diff_dst, and diff_src (the formats of diff_dst and diff_src are always the same because of the API). Different formats are functionally supported but lead to highly suboptimal performance.
| Engine | Name | Comments |
|---|---|---|
| CPU/GPU | Batch Normalization Primitive Example | This C++ API example demonstrates how to create and execute a Batch Normalization primitive in forward training propagation mode. Key optimizations included in this example: |
