memory_format_propagation.cpp

This example demonstrates memory format propagation, which is critical for the performance of deep learning applications. Annotated version: Memory Format Propagation

/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/


#include <iostream>
#include <sstream>
#include <string>


#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

/// Demonstrates memory format propagation between a convolution and a
/// max-pooling primitive.
///
/// Both primitives are created with `memory::format_tag::any`, so the
/// library chooses the memory formats it prefers. User data lives in plain
/// `nchw` (activations) and `oihw` (weights) buffers. Reorders are inserted
/// only where a primitive's chosen format differs from the user format. The
/// convolution output is passed to pooling directly, with no reorder,
/// because pooling is created with the convolution's dst descriptor as its
/// source.
///
/// NOTE: the user source/weights buffers are never filled in this example;
/// the focus is on format selection and reorders, not on the computed values.
///
/// @param engine_kind Kind of engine to run on; engine index 0 of that kind
///        is used.
void memory_format_propagation_tutorial(engine::kind engine_kind) {
    // [Initialize engine and stream]
    engine eng(engine_kind, 0);
    stream s(eng);
    // [Initialize engine and stream]

    // [Create placeholder memory descriptors]
    // Tensor and kernel dimensions. We use the same 3x3 kernel with padding=1
    // for both convolution and pooling primitives, which means that the
    // activation tensor shapes do not change.
    const int N = 1, H = 14, W = 14, IC = 128, OC = 256, KH = 3, KW = 3;
    auto conv_src_md = memory::desc({N, IC, H, W}, memory::data_type::f32,
            memory::format_tag::any // let convolution choose memory format
    );
    auto conv_weights_md = memory::desc(
            {OC, IC, KH, KW}, memory::data_type::f32,
            memory::format_tag::any // let convolution choose memory format
    );
    auto conv_dst_md = memory::desc({N, OC, H, W}, memory::data_type::f32,
            memory::format_tag::any // let convolution choose memory format
    );
    auto pool_dst_md = conv_dst_md; // shape does not change
    // [Create placeholder memory descriptors]

    // [Create convolution and pooling primitive descriptors]
    auto conv_pd = convolution_forward::primitive_desc(
            {prop_kind::forward_inference, algorithm::convolution_auto,
                    conv_src_md, conv_weights_md,
                    conv_dst_md, // shape information
                    {1, 1}, // strides
                    {1, 1}, {1, 1}}, // left and right padding
            eng);
    // Pooling consumes the convolution output in whatever format the
    // convolution selected, so no reorder is needed between the two.
    auto pool_pd = pooling_forward::primitive_desc(
            {prop_kind::forward_inference, algorithm::pooling_max,
                    conv_pd.dst_desc(), pool_dst_md, // shape information
                    {1, 1}, {KH, KW}, // strides and kernel
                    {1, 1}, {1, 1}}, // left and right padding
            eng);
    // [Create convolution and pooling primitive descriptors]

    // [Create source and destination memory objects]
    auto src_mem = memory(
            {{N, IC, H, W}, memory::data_type::f32, memory::format_tag::nchw},
            eng);
    auto weights_mem = memory({{OC, IC, KH, KW}, memory::data_type::f32,
                                      memory::format_tag::oihw},
            eng);
    auto dst_mem = memory(
            {{N, OC, H, W}, memory::data_type::f32, memory::format_tag::nchw},
            eng);
    // [Create source and destination memory objects]

    // [Determine if source needs to be reordered]
    bool need_reorder_src = conv_pd.src_desc() != src_mem.get_desc();
    // [Determine if source needs to be reordered]

    // [Determine if weights and destination need to be reordered]
    bool need_reorder_weights
            = conv_pd.weights_desc() != weights_mem.get_desc();
    // The final result is produced by pooling, so the user destination must
    // be compared against the pooling dst format, not the convolution dst
    // format. pool_dst_mem below is allocated with pool_pd.dst_desc(), and
    // aliasing it to dst_mem is only valid when those two descriptors match.
    bool need_reorder_dst = pool_pd.dst_desc() != dst_mem.get_desc();
    // [Determine if weights and destination need to be reordered]

    // [Allocate intermediate buffers if necessary]
    // When no reorder is needed, the user buffer is used directly.
    auto conv_src_mem
            = need_reorder_src ? memory(conv_pd.src_desc(), eng) : src_mem;
    auto conv_weights_mem = need_reorder_weights
            ? memory(conv_pd.weights_desc(), eng)
            : weights_mem;
    // The convolution output is internal to the conv -> pool chain and is
    // never exposed to the user, so it always uses the convolution's format.
    auto conv_dst_mem = memory(conv_pd.dst_desc(), eng);
    auto pool_dst_mem
            = need_reorder_dst ? memory(pool_pd.dst_desc(), eng) : dst_mem;
    // [Allocate intermediate buffers if necessary]

    // [Perform reorders for source data if necessary]
    if (need_reorder_src) {
        auto reorder_src = reorder(src_mem, conv_src_mem);
        reorder_src.execute(
                s, {{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, conv_src_mem}});
        s.wait(); // wait for the reorder to complete
    }

    if (need_reorder_weights) {
        auto reorder_weights = reorder(weights_mem, conv_weights_mem);
        reorder_weights.execute(s,
                {{DNNL_ARG_FROM, weights_mem},
                        {DNNL_ARG_TO, conv_weights_mem}});
        s.wait(); // wait for the reorder to complete
    }
    // [Perform reorders for source data if necessary]

    // [Create and execute convolution and pooling primitives]
    // No primitive_attr is supplied, so both primitives use the default
    // (library-managed) scratchpad mode. No user scratchpad buffer or
    // DNNL_ARG_SCRATCHPAD argument is needed.
    auto conv = convolution_forward(conv_pd);
    conv.execute(s,
            {{DNNL_ARG_SRC, conv_src_mem}, {DNNL_ARG_WEIGHTS, conv_weights_mem},
                    {DNNL_ARG_DST, conv_dst_mem}});
    auto pool = pooling_forward(pool_pd);
    pool.execute(
            s, {{DNNL_ARG_SRC, conv_dst_mem}, {DNNL_ARG_DST, pool_dst_mem}});
    s.wait();
    // [Create and execute convolution and pooling primitives]

    // [Reorder destination data if necessary]
    if (need_reorder_dst) {
        auto reorder_dst = reorder(pool_dst_mem, dst_mem);
        reorder_dst.execute(
                s, {{DNNL_ARG_FROM, pool_dst_mem}, {DNNL_ARG_TO, dst_mem}});
        s.wait();
    }
    // [Reorder destination data if necessary]
}

int main(int argc, char **argv) {
    return handle_example_errors(
            memory_format_propagation_tutorial, parse_engine_kind(argc, argv));
}