.. index:: pair: page; Batch Normalization

.. _doxid-dev_guide_batch_normalization:

Batch Normalization
===================

:ref:`API Reference`

General
~~~~~~~

The batch normalization primitive performs a forward or backward batch normalization operation on tensors with two or more dimensions.

Forward
-------

The batch normalization operation is defined by the following formulas. We show formulas only for 2D spatial data, which are straightforward to generalize to cases of higher and lower dimensions. Variable names follow the standard :ref:`Naming Conventions`.

.. math::

   \dst(n, c, h, w) =
      \gamma(c) \cdot
      \frac{\src(n, c, h, w) - \mu(c)}{\sqrt{\sigma^2(c) + \varepsilon}}
      + \beta(c),

where

* :math:`\gamma(c), \beta(c)` are optional scale and shift for a channel (see the :ref:`dnnl_use_scaleshift`, :ref:`dnnl_use_scale`, and :ref:`dnnl_use_shift` flags),
* :math:`\mu(c), \sigma^2(c)` are mean and variance for a channel (see the :ref:`dnnl_use_global_stats` flag), and
* :math:`\varepsilon` is a constant to improve numerical stability.

Mean and variance are computed at runtime or provided by a user. When mean and variance are computed at runtime, the following formulas are used:

* :math:`\mu(c) = \frac{1}{NHW} \sum\limits_{nhw} \src(n, c, h, w)`,
* :math:`\sigma^2(c) = \frac{1}{NHW} \sum\limits_{nhw} (\src(n, c, h, w) - \mu(c))^2`.

The :math:`\gamma(c)` and :math:`\beta(c)` tensors are considered learnable.

In training mode, the primitive also optionally supports fusion with ReLU activation with zero negative slope applied to the result (see the :ref:`dnnl_fuse_norm_relu` flag).

.. note::

   * The batch normalization primitive computes population mean and variance and not the sample or unbiased versions that are typically used to compute running mean and variance.

   * Using the mean and variance computed by the batch normalization primitive, running mean and variance :math:`\hat\mu` and :math:`\hat\sigma^2` can be computed as

     .. math::

        \hat\mu := \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu, \\
        \hat\sigma^2 := \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2,

     where :math:`\alpha` is a smoothing factor chosen by the user.

Difference Between Forward Training and Forward Inference
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

* If mean and variance are computed at runtime (i.e., :ref:`dnnl_use_global_stats` is not set), they become outputs for the propagation kind :ref:`dnnl_forward_training` (because they are required during the backward propagation) and are not exposed for the propagation kind :ref:`dnnl_forward_inference`.

* If batch normalization is created with ReLU fusion (i.e., :ref:`dnnl_fuse_norm_relu` is set), for the propagation kind :ref:`dnnl_forward_training` the primitive produces a workspace memory as one extra output. This memory is required to compute the backward propagation. When the primitive is executed with propagation kind :ref:`dnnl_forward_inference`, the workspace is not produced, and the behavior is the same as creating a batch normalization primitive with ReLU as a post-op (see the Post-Ops and Attributes section below).

Backward
--------

The backward propagation computes :math:`\diffsrc(n, c, h, w)`, :math:`\diffgamma(c)^*`, and :math:`\diffbeta(c)^*` based on :math:`\diffdst(n, c, h, w)`, :math:`\src(n, c, h, w)`, :math:`\mu(c)`, :math:`\sigma^2(c)`, :math:`\gamma(c)^*`, and :math:`\beta(c)^*`.

The tensors marked with an asterisk are used only when the primitive is configured to use :math:`\gamma(c)` and :math:`\beta(c)` (i.e., :ref:`dnnl_use_scaleshift`, :ref:`dnnl_use_scale`, or :ref:`dnnl_use_shift` is set).

Execution Arguments
~~~~~~~~~~~~~~~~~~~

Depending on the :ref:`flags` and :ref:`propagation kind`, the batch normalization primitive requires different inputs and outputs. For clarity, a summary is shown below.

.. list-table::
   :header-rows: 1

   * - Flags
     - :ref:`dnnl_forward_inference`
     - :ref:`dnnl_forward_training`
     - :ref:`dnnl_backward`
     - :ref:`dnnl_backward_data`
   * - :ref:`dnnl_normalization_flags_none`
     - *Inputs*: :math:`\src`; *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`; *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`; *Outputs*: :math:`\diffsrc`
     - Same as for :ref:`dnnl_backward`
   * - :ref:`dnnl_use_global_stats`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`; *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`; *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`; *Outputs*: :math:`\diffsrc`
     - Same as for :ref:`dnnl_backward`
   * - :ref:`dnnl_use_scaleshift`
     - *Inputs*: :math:`\src`, :math:`\gamma`, :math:`\beta`; *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\gamma`, :math:`\beta`; *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`; *Outputs*: :math:`\diffsrc`, :math:`\diffgamma`, :math:`\diffbeta`
     - Not supported
   * - :ref:`dnnl_use_scale`
     - *Inputs*: :math:`\src`, :math:`\gamma`; *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\gamma`; *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`; *Outputs*: :math:`\diffsrc`, :math:`\diffgamma`
     - Not supported
   * - :ref:`dnnl_use_shift`
     - *Inputs*: :math:`\src`, :math:`\beta`; *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\beta`; *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\beta`; *Outputs*: :math:`\diffsrc`, :math:`\diffbeta`
     - Not supported
   * - :ref:`dnnl_use_global_stats` | :ref:`dnnl_use_scaleshift`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`; *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`; *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`; *Outputs*: :math:`\diffsrc`, :math:`\diffgamma`, :math:`\diffbeta`
     - Not supported
   * - ``flags`` | :ref:`dnnl_fuse_norm_relu`
     - *Inputs*: same as with ``flags``; *Outputs*: same as with ``flags``
     - *Inputs*: same as with ``flags``; *Outputs*: same as with ``flags``, :ref:`Workspace`
     - *Inputs*: same as with ``flags``, :ref:`Workspace`; *Outputs*: same as with ``flags``
     - Same as for :ref:`dnnl_backward` if ``flags`` do not contain :ref:`dnnl_use_scaleshift`; not supported otherwise

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

=============================== =============================
Primitive Input/Output           Execution Argument Index
=============================== =============================
:math:`\src`                     ``DNNL_ARG_SRC``
:math:`\gamma, \beta`            ``DNNL_ARG_SCALE_SHIFT``
:math:`\gamma`                   ``DNNL_ARG_SCALE``
:math:`\beta`                    ``DNNL_ARG_SHIFT``
mean (:math:`\mu`)               ``DNNL_ARG_MEAN``
variance (:math:`\sigma^2`)      ``DNNL_ARG_VARIANCE``
:math:`\dst`                     ``DNNL_ARG_DST``
workspace                        ``DNNL_ARG_WORKSPACE``
:math:`\diffdst`                 ``DNNL_ARG_DIFF_DST``
:math:`\diffsrc`                 ``DNNL_ARG_DIFF_SRC``
:math:`\diffgamma, \diffbeta`    ``DNNL_ARG_DIFF_SCALE_SHIFT``
:math:`\diffgamma`               ``DNNL_ARG_DIFF_SCALE``
:math:`\diffbeta`                ``DNNL_ARG_DIFF_SHIFT``
=============================== =============================

Implementation Details
~~~~~~~~~~~~~~~~~~~~~~

General Notes
-------------

#. The different flavors of the primitive are partially controlled by the ``flags`` parameter that is passed to the operation descriptor initialization function (e.g., :ref:`dnnl::batch_normalization_forward::desc::desc()`). Multiple flags can be combined using the bitwise OR operator (``|``). Flag :ref:`dnnl_use_scaleshift` cannot be combined with :ref:`dnnl_use_scale` or :ref:`dnnl_use_shift`.

#. For forward propagation, the mean and variance might be either computed at runtime (in which case they are outputs of the primitive) or provided by a user (in which case they are inputs). In the latter case, a user must set the :ref:`dnnl_use_global_stats` flag. For backward propagation, the mean and variance are always input parameters.

#. The memory format and data type for ``src`` and ``dst`` are assumed to be the same, and in the API they are typically referred to as ``data`` (e.g., see ``data_desc`` in :ref:`dnnl::batch_normalization_forward::desc::desc()`). The same is true for ``diff_src`` and ``diff_dst``. The corresponding memory descriptors are referred to as ``diff_data_desc``.

#. Both forward and backward propagation support in-place operations, meaning that :math:`\src` can be used as input and output for forward propagation, and :math:`\diffdst` can be used as input and output for backward propagation. In case of an in-place operation, the original data will be overwritten. Note, however, that backward propagation requires the original :math:`\src`, hence the corresponding forward propagation should not be performed in-place.

#. As mentioned above, the batch normalization primitive can be fused with ReLU activation even in training mode. In this case, on the forward propagation the primitive has one additional output, ``workspace``, that should be passed during the backward propagation.

Data Type Support
-----------------

The operation supports the following combinations of data types:

=================== ===================== =============================
Propagation         Source / Destination  Mean / Variance / ScaleShift
=================== ===================== =============================
forward / backward  f32, bf16             f32
forward             f16                   f32
forward             s8                    f32
=================== ===================== =============================

.. warning::

   There might be hardware- or implementation-specific restrictions. Check the :ref:`Implementation Limitations` section below.

Data Representation
-------------------

Mean and Variance
+++++++++++++++++

The mean (:math:`\mu`) and variance (:math:`\sigma^2`) are separate 1D tensors of size :math:`C`.

The format of the corresponding memory object must be :ref:`dnnl_x` (:ref:`dnnl_a`).
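
The forward formulas and the per-channel statistics of size :math:`C` can be made concrete with a small reference sketch. The C++ code below is a minimal, unoptimized illustration that assumes f32 data in a dense NCHW buffer; ``BnormStats``, ``bnorm_forward_nchw``, and ``update_running_stats`` are hypothetical helper names, not part of the oneDNN API, and the code does not describe how the library implements the primitive.

.. code-block:: cpp

   #include <cassert>
   #include <cmath>
   #include <cstddef>
   #include <vector>

   // Hypothetical helpers (not oneDNN API): reference, unoptimized forward
   // batch normalization over f32 data stored in a dense NCHW buffer.
   struct BnormStats {
       std::vector<float> mean;      // mu(c), one value per channel
       std::vector<float> variance;  // sigma^2(c), population variance per channel
   };

   // Computes dst from src and returns the per-channel statistics.
   // Empty gamma/beta behave like gamma(c) = 1 and beta(c) = 0.
   BnormStats bnorm_forward_nchw(const std::vector<float>& src,
                                 std::vector<float>& dst,
                                 std::size_t N, std::size_t C,
                                 std::size_t H, std::size_t W, float eps,
                                 const std::vector<float>& gamma = {},
                                 const std::vector<float>& beta = {}) {
       assert(src.size() == N * C * H * W);
       dst.resize(src.size());
       const std::size_t HW = H * W;
       const float count = static_cast<float>(N * HW);  // NHW
       BnormStats stats{std::vector<float>(C), std::vector<float>(C)};
       // Offset of element (n, c, hw) in a dense NCHW buffer.
       auto at = [&](std::size_t n, std::size_t c, std::size_t hw) {
           return (n * C + c) * HW + hw;
       };

       for (std::size_t c = 0; c < C; ++c) {
           // mu(c) = 1/(NHW) * sum over n, h, w of src(n, c, h, w)
           float sum = 0.f;
           for (std::size_t n = 0; n < N; ++n)
               for (std::size_t hw = 0; hw < HW; ++hw) sum += src[at(n, c, hw)];
           const float mu = sum / count;

           // sigma^2(c): divides by NHW (population), not NHW - 1 (sample)
           float sq = 0.f;
           for (std::size_t n = 0; n < N; ++n)
               for (std::size_t hw = 0; hw < HW; ++hw) {
                   const float d = src[at(n, c, hw)] - mu;
                   sq += d * d;
               }
           const float var = sq / count;

           // dst = gamma * (src - mu) / sqrt(sigma^2 + eps) + beta
           const float inv_std = 1.f / std::sqrt(var + eps);
           const float g = gamma.empty() ? 1.f : gamma[c];
           const float b = beta.empty() ? 0.f : beta[c];
           for (std::size_t n = 0; n < N; ++n)
               for (std::size_t hw = 0; hw < HW; ++hw) {
                   const std::size_t i = at(n, c, hw);
                   dst[i] = g * (src[i] - mu) * inv_std + b;
               }
           stats.mean[c] = mu;
           stats.variance[c] = var;
       }
       return stats;
   }

   // Running statistics as an exponential moving average:
   // hat := alpha * hat + (1 - alpha) * batch_value
   void update_running_stats(std::vector<float>& running,
                             const std::vector<float>& batch, float alpha) {
       assert(running.size() == batch.size());
       for (std::size_t c = 0; c < running.size(); ++c)
           running[c] = alpha * running[c] + (1.f - alpha) * batch[c];
   }

The sketch accumulates the statistics in two passes per channel so that each loop mirrors one formula; a single-pass variant that subtracts :math:`\mu(c)^2` from the mean of squares is faster but can lose precision when the mean is large relative to the spread of the data.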

Scale and Shift
+++++++++++++++

If :ref:`dnnl_use_scaleshift` is used, the scale (:math:`\gamma`) and shift (:math:`\beta`) are combined in a single 2D tensor of shape :math:`2 \times C`.

If :ref:`dnnl_use_scale` or :ref:`dnnl_use_shift` is used, the scale (:math:`\gamma`) and shift (:math:`\beta`) are separate 1D tensors of shape :math:`C`.

The format of the memory object holding the combined :math:`2 \times C` tensor must be :ref:`dnnl_nc` (:ref:`dnnl_ab`).

Source, Destination, and Their Gradients
++++++++++++++++++++++++++++++++++++++++++

Like other CNN primitives, the batch normalization primitive expects data to be an :math:`N \times C \times SP_n \times \cdots \times SP_0` tensor.

The batch normalization primitive is optimized for the following memory formats:

======== =============== ============================================
Spatial  Logical tensor  Implementations optimized for memory formats
======== =============== ============================================
0D       NC              :ref:`dnnl_nc` (:ref:`dnnl_ab`)
1D       NCW             :ref:`dnnl_ncw` (:ref:`dnnl_abc`), :ref:`dnnl_nwc` (:ref:`dnnl_acb`), *optimized^*
2D       NCHW            :ref:`dnnl_nchw` (:ref:`dnnl_abcd`), :ref:`dnnl_nhwc` (:ref:`dnnl_acdb`), *optimized^*
3D       NCDHW           :ref:`dnnl_ncdhw` (:ref:`dnnl_abcde`), :ref:`dnnl_ndhwc` (:ref:`dnnl_acdeb`), *optimized^*
======== =============== ============================================

Here *optimized^* means the format that :ref:`comes out` of any preceding compute-intensive primitive.

Post-Ops and Attributes
-----------------------

Post-ops and attributes enable you to modify the behavior of the batch normalization primitive by chaining certain operations after the batch normalization operation.

The following post-ops are supported by batch normalization primitives:

============ ======== ========== ============================================
Propagation  Type     Operation  Description
============ ======== ========== ============================================
forward      post-op  eltwise    Applies an :ref:`Eltwise` operation to the result (currently only the :ref:`dnnl_eltwise_relu` algorithm is supported)
============ ======== ========== ============================================

.. note::

   As mentioned in :ref:`Primitive Attributes`, the post-ops should be used for inference only. For instance, using ReLU as a post-op does not produce the additional ``workspace`` output that is required to compute backward propagation correctly. Hence, in case of training, use the :ref:`dnnl_fuse_norm_relu` flag directly.

.. _doxid-dev_guide_batch_normalization_1dg_bnorm_impl_limits:

Implementation Limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~

#. Refer to :ref:`Data Types` for limitations related to data type support.

#. For the data types that have forward propagation support only, mean and variance must be provided by a user (i.e., :ref:`dnnl_use_global_stats` is set).

#. **GPU**

   * The ReLU eltwise post-op does not support a non-zero :math:`\alpha` parameter.

Performance Tips
~~~~~~~~~~~~~~~~

#. For backward propagation, use the same memory format for ``src``, ``diff_dst``, and ``diff_src`` (the formats of ``diff_dst`` and ``diff_src`` are always the same because of the API). Different formats are functionally supported but lead to highly suboptimal performance.

#. Use in-place operations whenever possible (see the caveats in General Notes).

#. GPU implementations support an experimental algorithm with single-pass statistics calculations. Please review :ref:`experimental features` for more details.

Examples
~~~~~~~~

:ref:`Batch Normalization Primitive Example`

This C++ API example demonstrates how to create and execute a :ref:`Batch Normalization` primitive in forward training propagation mode.

Key optimizations included in this example:

* In-place primitive execution;
* Source memory format for an optimized primitive implementation;
* Fused post-ops via operation descriptor flags.
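
The Backward section lists which tensors the backward pass consumes and produces but does not spell out the gradients. The sketch below derives them from the forward formula under stated assumptions: f32 data in a dense NCHW buffer, and hypothetical helper names (``BnormGrads``, ``bnorm_backward_nchw``) that are not part of the oneDNN API. When the statistics were computed from :math:`\src` on the forward pass, the gradient also flows through :math:`\mu(c)` and :math:`\sigma^2(c)`. When they were provided by the user (:ref:`dnnl_use_global_stats`), this sketch treats them as constants, which mathematically drops the two correction terms. The ``stats_are_constant`` switch is an assumption of the sketch, not a oneDNN parameter.

.. code-block:: cpp

   #include <cassert>
   #include <cmath>
   #include <cstddef>
   #include <vector>

   // Hypothetical helper (not oneDNN API): reference backward batch
   // normalization for f32 data stored in a dense NCHW buffer.
   struct BnormGrads {
       std::vector<float> diff_src;    // same shape as src
       std::vector<float> diff_gamma;  // one value per channel
       std::vector<float> diff_beta;   // one value per channel
   };

   // gamma may be empty, which behaves like gamma(c) = 1.
   // stats_are_constant = true models user-provided mean/variance.
   BnormGrads bnorm_backward_nchw(const std::vector<float>& diff_dst,
                                  const std::vector<float>& src,
                                  const std::vector<float>& mean,
                                  const std::vector<float>& variance,
                                  const std::vector<float>& gamma,
                                  std::size_t N, std::size_t C,
                                  std::size_t H, std::size_t W, float eps,
                                  bool stats_are_constant) {
       assert(src.size() == N * C * H * W && diff_dst.size() == src.size());
       assert(mean.size() == C && variance.size() == C);
       const std::size_t HW = H * W;
       const float count = static_cast<float>(N * HW);  // NHW
       BnormGrads grads{std::vector<float>(src.size()),
                        std::vector<float>(C, 0.f), std::vector<float>(C, 0.f)};
       // Offset of element (n, c, hw) in a dense NCHW buffer.
       auto at = [&](std::size_t n, std::size_t c, std::size_t hw) {
           return (n * C + c) * HW + hw;
       };

       for (std::size_t c = 0; c < C; ++c) {
           const float inv_std = 1.f / std::sqrt(variance[c] + eps);
           // diff_beta(c) = sum of diff_dst
           // diff_gamma(c) = sum of diff_dst * x_hat, x_hat = (src - mu) * inv_std
           for (std::size_t n = 0; n < N; ++n)
               for (std::size_t hw = 0; hw < HW; ++hw) {
                   const std::size_t i = at(n, c, hw);
                   const float x_hat = (src[i] - mean[c]) * inv_std;
                   grads.diff_beta[c] += diff_dst[i];
                   grads.diff_gamma[c] += diff_dst[i] * x_hat;
               }
           const float scale = (gamma.empty() ? 1.f : gamma[c]) * inv_std;
           // diff_src = gamma * inv_std * (diff_dst - (diff_beta + x_hat * diff_gamma) / NHW)
           // The subtracted term is the gradient flowing through mu(c) and sigma^2(c).
           for (std::size_t n = 0; n < N; ++n)
               for (std::size_t hw = 0; hw < HW; ++hw) {
                   const std::size_t i = at(n, c, hw);
                   float g = diff_dst[i];
                   if (!stats_are_constant) {
                       const float x_hat = (src[i] - mean[c]) * inv_std;
                       g -= (grads.diff_beta[c] + x_hat * grads.diff_gamma[c]) / count;
                   }
                   grads.diff_src[i] = scale * g;
               }
       }
       return grads;
   }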