The batch normalization primitive performs a forward or backward batch normalization operation on 0D, 2D, or 3D spatial data.
The batch normalization operation is defined by the following formulas. We show formulas only for 2D spatial data; they generalize straightforwardly to higher and lower dimensions. Variable names follow the standard Naming Conventions.
\[ \dst(n, c, h, w) = \gamma(c) \cdot \frac{\src(n, c, h, w) - \mu(c)} {\sqrt{\sigma^2(c) + \varepsilon}} + \beta(c), \]
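where \(\gamma(c)\) and \(\beta(c)\) are the optional scale and shift for a channel (see the dnnl_use_scaleshift flag), \(\mu(c)\) and \(\sigma^2(c)\) are the mean and variance for a channel (see the dnnl_use_global_stats flag), and \(\varepsilon\) is a constant that improves numerical stability.
Mean and variance are computed at runtime or provided by a user. When mean and variance are computed at runtime, the following formulas are used:
\[ \mu(c) = \frac{1}{NHW} \sum\limits_{nhw} \src(n, c, h, w), \\ \sigma^2(c) = \frac{1}{NHW} \sum\limits_{nhw} {(\src(n, c, h, w) - \mu(c))}^2. \]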
The \(\gamma(c)\) and \(\beta(c)\) tensors are considered learnable.
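The forward computation can be illustrated with a plain C++ reference loop. This is a minimal sketch and not the library's implementation: it assumes a dense NCHW buffer, population (biased) mean and variance taken over the N, H, and W dimensions, and the names `BatchNormStats`, `compute_stats`, and `batch_norm_forward` are hypothetical.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Per-channel statistics (hypothetical helper type, not a library type).
struct BatchNormStats {
    std::vector<float> mean;     // mu(c)
    std::vector<float> variance; // sigma^2(c), population variance
};

// Computes mu(c) and sigma^2(c) over N, H, and W for a dense NCHW buffer.
BatchNormStats compute_stats(const std::vector<float> &src, std::size_t N,
        std::size_t C, std::size_t H, std::size_t W) {
    BatchNormStats stats;
    stats.mean.assign(C, 0.f);
    stats.variance.assign(C, 0.f);
    const std::size_t HW = H * W;
    const float count = static_cast<float>(N * HW);
    for (std::size_t c = 0; c < C; ++c) {
        float sum = 0.f;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < HW; ++i)
                sum += src[(n * C + c) * HW + i];
        stats.mean[c] = sum / count;

        float sq_sum = 0.f;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < HW; ++i) {
                const float d = src[(n * C + c) * HW + i] - stats.mean[c];
                sq_sum += d * d;
            }
        stats.variance[c] = sq_sum / count;
    }
    return stats;
}

// dst = gamma(c) * (src - mu(c)) / sqrt(sigma^2(c) + eps) + beta(c)
std::vector<float> batch_norm_forward(const std::vector<float> &src,
        const BatchNormStats &stats, const std::vector<float> &gamma,
        const std::vector<float> &beta, float eps, std::size_t N,
        std::size_t C, std::size_t H, std::size_t W) {
    std::vector<float> dst(src.size());
    const std::size_t HW = H * W;
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t c = 0; c < C; ++c) {
            const float inv_std = 1.f / std::sqrt(stats.variance[c] + eps);
            for (std::size_t i = 0; i < HW; ++i) {
                const std::size_t off = (n * C + c) * HW + i;
                dst[off] = gamma[c] * (src[off] - stats.mean[c]) * inv_std
                        + beta[c];
            }
        }
    return dst;
}
```

Passing user-provided statistics to `batch_norm_forward` instead of the result of `compute_stats` corresponds to the case where mean and variance are supplied by the user.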
In training mode, the primitive also optionally supports fusion with ReLU activation with zero negative slope applied to the result (see dnnl_fuse_norm_relu flag).
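Running mean and variance, \(\hat\mu\) and \(\hat\sigma^2\), can be updated from the computed \(\mu\) and \(\sigma^2\) with a smoothing factor \(\alpha\):
\[ \hat\mu := \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu, \\ \hat\sigma^2 := \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2. \]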
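The update of \(\hat\mu\) and \(\hat\sigma^2\) is done by the caller, not by the primitive. A minimal sketch of that step follows; the function name `update_running_stat` is hypothetical, and the same routine is assumed to be applied once to the mean and once to the variance.

```cpp
#include <cstddef>
#include <vector>

// running := alpha * running + (1 - alpha) * batch, per channel.
// Only the channels present in both vectors are updated.
void update_running_stat(std::vector<float> &running,
        const std::vector<float> &batch, float alpha) {
    for (std::size_t c = 0; c < running.size() && c < batch.size(); ++c)
        running[c] = alpha * running[c] + (1.f - alpha) * batch[c];
}
```

A value of \(\alpha\) close to 1 makes the running estimate change slowly, while a value close to 0 makes it track the most recent batch.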
If the primitive is created with ReLU fusion (i.e., dnnl_fuse_norm_relu is set), then for the propagation kind dnnl_forward_training it produces a workspace memory as one extra output. This memory is required to compute the backward propagation. When the primitive is executed with propagation kind dnnl_forward_inference, the workspace is not produced, and the behavior is the same as creating a batch normalization primitive with ReLU as a post-op (see section below).
The backward propagation computes \(\diffsrc(n, c, h, w)\), \(\diffgamma(c)^*\), and \(\diffbeta(c)^*\) based on \(\diffdst(n, c, h, w)\), \(\src(n, c, h, w)\), \(\mu(c)\), \(\sigma^2(c)\), \(\gamma(c) ^*\), and \(\beta(c) ^*\).
The tensors marked with an asterisk are used only when the primitive is configured to use \(\gamma(c)\) and \(\beta(c)\) (i.e., dnnl_use_scaleshift is set).
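For reference, the gradients can be written out explicitly. The formulas below are the standard batch normalization derivation and are not taken from this page, so treat them as an assumption about what the primitive computes. Let \(M = N \cdot H \cdot W\) and \(\hat{x}(n, c, h, w) = \frac{\src(n, c, h, w) - \mu(c)}{\sqrt{\sigma^2(c) + \varepsilon}}\). Then
\[ \diffbeta(c) = \sum\limits_{nhw} \diffdst(n, c, h, w), \\ \diffgamma(c) = \sum\limits_{nhw} \diffdst(n, c, h, w) \cdot \hat{x}(n, c, h, w). \]
When mean and variance are computed at runtime, they depend on \(\src\), and the gradient with respect to the source is
\[ \diffsrc(n, c, h, w) = \frac{\gamma(c)}{\sqrt{\sigma^2(c) + \varepsilon}} \cdot \left( \diffdst(n, c, h, w) - \frac{1}{M} \sum\limits_{n'h'w'} \diffdst(n', c, h', w') - \frac{\hat{x}(n, c, h, w)}{M} \sum\limits_{n'h'w'} \diffdst(n', c, h', w') \cdot \hat{x}(n', c, h', w') \right). \]
When mean and variance are provided by the user (dnnl_use_global_stats is set), they are constants with respect to \(\src\), and only the first term in the parentheses remains. If scale and shift are not used, \(\gamma(c) = 1\) in these expressions.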
Depending on the flags and propagation kind, the batch normalization primitive requires different inputs and outputs. For clarity, a summary is shown below.
Flags | dnnl_forward_inference | dnnl_forward_training | dnnl_backward | dnnl_backward_data |
---|---|---|---|---|
(none) | Inputs: \(\src\) Outputs: \(\dst\) | Inputs: \(\src\) Outputs: \(\dst\), \(\mu\), \(\sigma^2\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\) Outputs: \(\diffsrc\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\) Outputs: \(\diffsrc\) |
dnnl_use_global_stats | Inputs: \(\src\), \(\mu\), \(\sigma^2\) Outputs: \(\dst\) | Inputs: \(\src\), \(\mu\), \(\sigma^2\) Outputs: \(\dst\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\) Outputs: \(\diffsrc\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\) Outputs: \(\diffsrc\) |
dnnl_use_scaleshift | Inputs: \(\src\), \(\gamma\), \(\beta\) Outputs: \(\dst\) | Inputs: \(\src\), \(\gamma\), \(\beta\) Outputs: \(\dst\), \(\mu\), \(\sigma^2\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\) Outputs: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\) Outputs: \(\diffsrc\) |
dnnl_use_global_stats \| dnnl_use_scaleshift | Inputs: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\) Outputs: \(\dst\) | Inputs: \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\) Outputs: \(\dst\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\) Outputs: \(\diffsrc\), \(\diffgamma\), \(\diffbeta\) | Inputs: \(\diffdst\), \(\src\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\) Outputs: \(\diffsrc\) |
flags \| dnnl_fuse_norm_relu | Inputs: same as with flags Outputs: same as with flags | Inputs: same as with flags Outputs: same as with flags, Workspace | Inputs: same as with flags, Workspace Outputs: same as with flags | Inputs: same as with flags, Workspace Outputs: same as with flags |
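The table can also be read as a small decision procedure. The sketch below encodes it in plain C++; the enum names mirror the library's flag and propagation-kind names, but they are hypothetical stand-ins with illustrative values, and `required_args` is not a library function.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins for the library's flags; values are illustrative.
enum NormFlags : unsigned {
    use_global_stats = 1u << 0,
    use_scaleshift = 1u << 1,
    fuse_norm_relu = 1u << 2,
};

enum class PropKind {
    forward_inference, forward_training, backward, backward_data
};

struct Args {
    std::vector<std::string> inputs, outputs;
};

// Returns the inputs and outputs listed in the summary table.
Args required_args(unsigned flags, PropKind kind) {
    const bool global = (flags & use_global_stats) != 0;
    const bool scaleshift = (flags & use_scaleshift) != 0;
    const bool relu = (flags & fuse_norm_relu) != 0;
    const bool training = kind == PropKind::forward_training;
    const bool forward = training || kind == PropKind::forward_inference;

    Args a;
    if (forward) {
        a.inputs = {"src"};
        if (global) a.inputs.insert(a.inputs.end(), {"mean", "variance"});
        if (scaleshift) a.inputs.insert(a.inputs.end(), {"gamma", "beta"});
        a.outputs = {"dst"};
        // Statistics computed at runtime are exposed only in training.
        if (!global && training)
            a.outputs.insert(a.outputs.end(), {"mean", "variance"});
        if (relu && training) a.outputs.push_back("workspace");
    } else {
        // Mean and variance are always inputs on the backward pass.
        a.inputs = {"diff_dst", "src", "mean", "variance"};
        if (scaleshift) a.inputs.insert(a.inputs.end(), {"gamma", "beta"});
        if (relu) a.inputs.push_back("workspace");
        a.outputs = {"diff_src"};
        if (scaleshift && kind == PropKind::backward)
            a.outputs.insert(a.outputs.end(), {"diff_gamma", "diff_beta"});
    }
    return a;
}
```

For example, `required_args(use_global_stats | use_scaleshift, PropKind::backward)` yields six inputs and three outputs, matching the corresponding table cell.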
When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.
Primitive input/output | Execution argument index |
---|---|
\(\src\) | DNNL_ARG_SRC |
\(\gamma, \beta\) | DNNL_ARG_SCALE_SHIFT |
mean (\(\mu\)) | DNNL_ARG_MEAN |
variance (\(\sigma^2\)) | DNNL_ARG_VARIANCE |
\(\dst\) | DNNL_ARG_DST |
workspace | DNNL_ARG_WORKSPACE |
\(\diffdst\) | DNNL_ARG_DIFF_DST |
\(\diffsrc\) | DNNL_ARG_DIFF_SRC |
\(\diffgamma, \diffbeta\) | DNNL_ARG_DIFF_SCALE_SHIFT |
The behavior of the primitive is partially controlled by the flags parameter that is passed to the operation descriptor initialization function (e.g., dnnl::batch_normalization_forward::desc::desc()). Multiple flags can be set using the bitwise OR operator (|).
The memory format and data type for src and dst are assumed to be the same, and in the API they are typically referred to as data (e.g., see data_desc in dnnl::batch_normalization_forward::desc::desc()). The same is true for diff_src and diff_dst. The corresponding memory descriptors are referred to as diff_data_desc.
Both forward and backward propagation support in-place operations: src can be used as input and output for forward propagation, and diff_dst can be used as input and output for backward propagation. In case of an in-place operation, the original data will be overwritten.
When the primitive is fused with ReLU activation in training mode, forward propagation has one additional output, workspace, that should be passed during the backward propagation.
The operation supports the following combinations of data types:
Propagation | Source / Destination | Mean / Variance / ScaleShift |
---|---|---|
forward / backward | f32, bf16 | f32 |
forward | f16 | f32 |
forward | s8 | f32 |
The mean ( \(\mu\)) and variance ( \(\sigma^2\)) are separate 1D tensors of size \(C\).
The format of the corresponding memory object must be dnnl_x (dnnl_a).
If used, the scale ( \(\gamma\)) and shift ( \(\beta\)) are combined in a single 2D tensor of shape \(2 \times C\).
The format of the corresponding memory object must be dnnl_nc (dnnl_ab).
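To make the combined layout concrete, the sketch below packs separate scale and shift vectors into one dense, row-major \(2 \times C\) buffer. The row order (scale in row 0, shift in row 1) is an assumption that the text above does not state explicitly, and `pack_scale_shift` is a hypothetical helper.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Packs gamma and beta (each of length C) into a row-major 2 x C buffer.
// Assumed row order: row 0 holds gamma, row 1 holds beta.
std::vector<float> pack_scale_shift(const std::vector<float> &gamma,
        const std::vector<float> &beta) {
    if (gamma.size() != beta.size())
        throw std::invalid_argument("gamma and beta must have equal length");
    const std::size_t C = gamma.size();
    std::vector<float> packed(2 * C);
    for (std::size_t c = 0; c < C; ++c) {
        packed[c] = gamma[c];    // element (0, c)
        packed[C + c] = beta[c]; // element (1, c)
    }
    return packed;
}
```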
Like other CNN primitives, the batch normalization primitive expects data to be an \(N \times C \times SP_n \times \cdots \times SP_0\) tensor.
The batch normalization primitive is optimized for the following memory formats:
Spatial | Logical tensor | Implementations |
---|---|---|
0D | NC | dnnl_nc (dnnl_ab) |
2D | NCHW | dnnl_nchw (dnnl_abcd), dnnl_nhwc (dnnl_acdb), optimized^ |
3D | NCDHW | dnnl_ncdhw (dnnl_abcde), dnnl_ndhwc (dnnl_acdeb), optimized^ |
Here optimized^ means the format that comes out of any preceding compute-intensive primitive.
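The two plain 2D layouts differ only in which dimension varies fastest in memory. The following sketch computes element offsets for dense buffers without padding; `Dims`, `offset_nchw`, and `offset_nhwc` are hypothetical names used for illustration.

```cpp
#include <cstddef>

// Logical tensor sizes (hypothetical helper type).
struct Dims {
    std::size_t N, C, H, W;
};

// dnnl_nchw (dnnl_abcd): W varies fastest, then H, then C, then N.
std::size_t offset_nchw(const Dims &d, std::size_t n, std::size_t c,
        std::size_t h, std::size_t w) {
    return ((n * d.C + c) * d.H + h) * d.W + w;
}

// dnnl_nhwc (dnnl_acdb): C varies fastest, then W, then H, then N.
std::size_t offset_nhwc(const Dims &d, std::size_t n, std::size_t c,
        std::size_t h, std::size_t w) {
    return ((n * d.H + h) * d.W + w) * d.C + c;
}
```

In the nhwc layout all channels of one pixel are adjacent, whereas in nchw all pixels of one channel are adjacent, which is why per-channel statistics are accessed differently in each case.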
Post-ops and attributes enable you to modify the behavior of the batch normalization primitive by chaining certain operations after the batch normalization operation. The following post-ops are supported by batch normalization primitives:
Propagation | Type | Operation | Description |
---|---|---|---|
forward | post-op | eltwise | Applies an Eltwise operation to the result (currently only the dnnl_eltwise_relu algorithm is supported) |
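Post-ops should be used for inference only. For instance, using ReLU as a post-op would not produce the additional output workspace that is required to compute backward propagation correctly. Hence, in case of training one should use the dnnl_fuse_norm_relu flag directly.
For backward propagation, use the same memory format for src, diff_dst, and diff_src (the formats of diff_dst and diff_src are always the same because of the API). Different formats are functionally supported but lead to highly suboptimal performance.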