Deep Neural Network Library (DNNL)  1.1.3
Performance library for Deep Learning
dnnl.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
19 
20 #ifndef DNNL_HPP
21 #define DNNL_HPP
22 
23 #include "dnnl_config.h"
24 
26 #include <algorithm>
27 #include <cstdlib>
28 #include <iterator>
29 #include <memory>
30 #include <vector>
31 #include <unordered_map>
32 
33 #include "dnnl.h"
34 
35 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
36 #include <CL/cl.h>
37 #endif
38 
40 namespace dnnl {
41 
44 
47 
52 struct error : public std::exception {
53  dnnl_status_t status;
54  const char *message;
55 
60  error(dnnl_status_t astatus, const char *amessage)
61  : status(astatus), message(amessage) {}
62 
64  const char *what() const noexcept override { return message; }
65 
71  static void wrap_c_api(dnnl_status_t status, const char *message) {
72  if (status != dnnl_success) throw error(status, message);
73  }
74 };
75 
/// Primary template of the per-handle traits. It is intentionally empty:
/// each C handle type provides a specialization exposing a static
/// `destructor` member. Declared with `struct` so the class-key matches the
/// specializations in this file (avoids mismatched-tag warnings).
template <typename T>
struct handle_traits {};
79 
93 template <typename T, typename traits = handle_traits<T>>
94 class handle {
95 private:
96  static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
97 
98  std::shared_ptr<typename std::remove_pointer<T>::type> _data {0};
99 
100 protected:
101  bool operator==(const T other) const { return other == _data.get(); }
102  bool operator!=(const T other) const { return !(*this == other); }
103 
104 public:
114  handle() = default;
115  handle(const handle<T, traits> &) = default;
116  handle(handle<T, traits> &&) = default;
117  handle<T, traits> &operator=(handle<T, traits> &&) = default;
118  handle<T, traits> &operator=(const handle<T, traits> &) = default;
119 
123  explicit handle(T t, bool weak = false) { reset(t, weak); }
124 
128  void reset(T t, bool weak = false) {
129  _data.reset(t, weak ? &dummy_destructor : traits::destructor);
130  }
131 
133  T get(bool allow_emtpy = false) const {
134  T result = _data.get();
135 
136  if (allow_emtpy == false && result == nullptr)
138  "attempt to use uninitialized object");
139 
140  return result;
141  }
142 
143  explicit operator T() const { return get(true); }
144 
145  explicit operator bool() const { return get(true) != nullptr; }
146 
147  bool operator==(const handle &other) const {
148  return other._data.get() == _data.get();
149  }
150  bool operator!=(const handle &other) const { return !(*this == other); }
151 };
152 
154 template <>
155 struct handle_traits<dnnl_memory_t> {
156  static constexpr auto destructor = &dnnl_memory_destroy;
157 };
158 
159 template <>
160 struct handle_traits<dnnl_primitive_desc_t> {
161  static constexpr auto destructor = &dnnl_primitive_desc_destroy;
162 };
163 
164 template <>
165 struct handle_traits<dnnl_primitive_t> {
166  static constexpr auto destructor = &dnnl_primitive_destroy;
167 };
168 
169 template <>
170 struct handle_traits<dnnl_primitive_desc_iterator_t> {
171  static constexpr auto destructor = &dnnl_primitive_desc_iterator_destroy;
172 };
174 
175 struct stream;
176 struct error;
177 struct memory;
178 struct primitive_desc;
179 
181 class primitive : public handle<dnnl_primitive_t> {
182  friend struct error;
183  friend struct stream;
184  using handle::handle;
185 
186 public:
189  enum class kind {
199  sum = dnnl_sum,
211  lrn = dnnl_lrn,
219  rnn = dnnl_rnn,
222  };
223 
225  primitive(const primitive_desc &pd);
226 
229  // TODO: use the C++ API wrapper structure.
230 
231  void execute(
232  stream &astream, const std::unordered_map<int, memory> &args) const;
233 };
234 
235 inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) {
236  return static_cast<dnnl_primitive_kind_t>(akind);
237 }
238 
242  "could not get primitive descriptor by primitive");
243  return pd;
244 }
246 
251 
253 enum class scratchpad_mode {
258 };
259 
260 inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
261  return static_cast<dnnl_scratchpad_mode_t>(mode);
262 }
263 
265 enum class prop_kind {
289 };
290 
291 inline dnnl_prop_kind_t convert_to_c(prop_kind kind) {
292  return static_cast<dnnl_prop_kind_t>(kind);
293 }
294 
296 enum class algorithm {
297  undef = dnnl_alg_kind_undef,
342  pooling_avg = dnnl_pooling_avg,
365 };
366 
367 inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) {
368  return static_cast<dnnl_alg_kind_t>(aalgorithm);
369 }
370 
372 enum class normalization_flags : unsigned {
385 
400 
409 };
410 
411 inline dnnl_normalization_flags_t convert_to_c(normalization_flags aflag) {
412  return static_cast<dnnl_normalization_flags_t>(aflag);
413 }
414 
415 enum class rnn_flags : unsigned { undef = dnnl_rnn_flags_undef };
416 
417 inline dnnl_rnn_flags_t convert_to_c(rnn_flags aflag) {
418  return static_cast<dnnl_rnn_flags_t>(aflag);
419 }
420 
// Generates the bitwise operators |, &, ^, |=, &=, ^= and ~ for the scoped
// enum `enum_name`. Each operator casts its operands to unsigned, applies the
// built-in operator, and casts the result back to `enum_name`.
#define DNNL_DEFINE_BITMASK_OPS(enum_name) \
    inline enum_name operator|(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name operator&(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name operator^(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs); \
        lhs = static_cast<enum_name>(bits); \
        return lhs; \
    } \
\
    inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs); \
        lhs = static_cast<enum_name>(bits); \
        return lhs; \
    } \
\
    inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs); \
        lhs = static_cast<enum_name>(bits); \
        return lhs; \
    } \
\
    inline enum_name operator~(enum_name rhs) { \
        const unsigned inverted = ~static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(inverted); \
    }
458 
459 DNNL_DEFINE_BITMASK_OPS(normalization_flags)
460 DNNL_DEFINE_BITMASK_OPS(rnn_flags)
461 
462 #undef DNNL_DEFINE_BITMASK_OPS
463 
464 enum class rnn_direction {
465  unidirectional_left2right = dnnl_unidirectional_left2right,
466  unidirectional_right2left = dnnl_unidirectional_right2left,
467  unidirectional = dnnl_unidirectional,
468  bidirectional_concat = dnnl_bidirectional_concat,
469  bidirectional_sum = dnnl_bidirectional_sum,
470 };
471 
472 inline dnnl_rnn_direction_t convert_to_c(rnn_direction adir) {
473  return static_cast<dnnl_rnn_direction_t>(adir);
474 }
475 
483 enum class query {
486 
491 
496 
505 
510 
515 
518 
545 
562 };
563 
564 inline dnnl_query_t convert_to_c(query aquery) {
565  return static_cast<dnnl_query_t>(aquery);
566 }
567 
569 
575 
577 template <>
578 struct handle_traits<dnnl_post_ops_t> {
579  static constexpr auto destructor = &dnnl_post_ops_destroy;
580 };
582 
586 struct post_ops : public handle<dnnl_post_ops_t> {
588 
591  dnnl_post_ops_t result;
593  "could not create post operation sequence");
594  reset(result);
595  }
596 
598  int len() const { return dnnl_post_ops_len(get()); }
599 
601  primitive::kind kind(int index) const {
603  "post_ops index is out of range");
604  return static_cast<primitive::kind>(
605  dnnl_post_ops_get_kind(get(), index));
606  }
607 
628  void append_sum(float scale = 1.) {
630  dnnl_post_ops_append_sum(get(), scale), "could not append sum");
631  }
632 
635  void get_params_sum(int index, float &scale) const {
636  error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale),
637  "could not get sum params");
638  }
639 
648  void append_eltwise(float scale, algorithm alg, float alpha, float beta) {
650  get(), scale, convert_to_c(alg), alpha, beta),
651  "could not append eltwise");
652  }
653 
655  void get_params_eltwise(int index, float &scale, algorithm &alg,
656  float &alpha, float &beta) const {
657  dnnl_alg_kind_t c_alg;
659  get(), index, &scale, &c_alg, &alpha, &beta),
660  "could not get eltwise params");
661  alg = static_cast<algorithm>(c_alg);
662  }
663 };
664 
666 template <>
667 struct handle_traits<dnnl_primitive_attr_t> {
668  static constexpr auto destructor = &dnnl_primitive_attr_destroy;
669 };
671 
675 struct primitive_attr : public handle<dnnl_primitive_attr_t> {
677 
680  dnnl_primitive_attr_t result;
682  "could not create a primitive attr");
683  reset(result);
684  }
685 
690  : handle<dnnl_primitive_attr_t>(attr) {}
691 
694  dnnl_scratchpad_mode_t result;
697  "could not get scratchpad mode");
698  return scratchpad_mode(result);
699  }
700 
704  get(), dnnl::convert_to_c(mode)),
705  "could not set scratchpad mode");
706  }
707 
710  void get_output_scales(int &mask, std::vector<float> &scales) const {
711  dnnl_dim_t count;
712  int c_mask;
713  const float *c_scales;
715  get(), &count, &c_mask, &c_scales),
716  "could not get int output scales");
717  scales.resize(count);
718 
719  mask = c_mask;
720  for (dnnl_dim_t c = 0; c < count; ++c)
721  scales[c] = c_scales[c];
722  }
723 
739  void set_output_scales(int mask, const std::vector<float> &scales) {
741  (dnnl_dim_t)scales.size(), mask, &scales[0]),
742  "could not set int output scales");
743  }
744 
746  const post_ops get_post_ops() const {
747  post_ops result;
748  const_dnnl_post_ops_t c_result;
750  "could not get post operation sequence");
751  result.reset(const_cast<dnnl_post_ops_t>(c_result), true);
752  return result;
753  }
754 
756  void set_post_ops(post_ops ops) {
758  "could not set post operation sequence");
759  }
760 
769  void set_rnn_data_qparams(float scale, float shift) {
771  dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift),
772  "could not set rnn data int scale/shift");
773  }
774 
798  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
800  get(), (int)scales.size(), mask, &scales[0]),
801  "could not set rnn weights int scales");
802  }
803 };
804 
806 
812 
814 template <>
815 struct handle_traits<dnnl_engine_t> {
816  static constexpr auto destructor = &dnnl_engine_destroy;
817 };
819 
821 struct engine : public handle<dnnl_engine_t> {
822  friend class primitive;
823  friend struct reorder;
824 
826  enum class kind {
830  cpu = dnnl_cpu,
832  gpu = dnnl_gpu,
833  };
834 
835  engine() = default;
836 
840  static size_t get_count(kind akind) {
841  return dnnl_engine_get_count(convert_to_c(akind));
842  }
843 
850  engine(kind akind, size_t index) {
851  dnnl_engine_t aengine;
853  dnnl_engine_create(&aengine, convert_to_c(akind), index),
854  "could not create an engine");
855  reset(aengine);
856  }
857 
858 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
859  engine(kind akind, cl_device_id device, cl_context context) {
862  dnnl_engine_t aengine;
863  error::wrap_c_api(dnnl_engine_create_ocl(&aengine, convert_to_c(akind),
864  device, context),
865  "could not create an engine");
866  reset(aengine);
867  }
868 #endif
869 
871  explicit engine(const dnnl_engine_t &aengine) : handle(aengine, true) {}
872 
876  dnnl_engine_t engine_q;
879  dnnl::convert_to_c(dnnl::query::engine), 0, &engine_q),
880  "could not get engine from primitive_desc");
881  reset(engine_q, true);
882  }
883 
885  kind get_kind() const {
886  dnnl_engine_kind_t akind;
888  "could not get the engine kind");
889  return static_cast<engine::kind>(akind);
890  }
891 
892 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
893  cl_context get_ocl_context() const {
895  cl_context context = nullptr;
897  "could not get a context handle");
898  return context;
899  }
900 
902  cl_device_id get_ocl_device() const {
903  cl_device_id device = nullptr;
905  "could not get a device handle");
906  return device;
907  }
908 #endif
909 
910  template <class primitive_desc>
911  static engine query(const primitive_desc &pd) {
912  return query(pd, dnnl::query::engine);
913  }
914 
915 private:
916  static dnnl_engine_kind_t convert_to_c(kind akind) {
917  return static_cast<dnnl_engine_kind_t>(akind);
918  }
919 
920  template <class primitive_desc>
921  static engine query(const primitive_desc &pd, dnnl::query what) {
922  dnnl_engine_t engine_q;
924  dnnl::convert_to_c(what), 0, &engine_q),
925  "could not get engine from primitive_desc");
926 
927  return engine(engine_q);
928  }
929 };
930 
932 
938 
940 template <>
941 struct handle_traits<dnnl_stream_t> {
942  static constexpr auto destructor = &dnnl_stream_destroy;
943 };
945 
947 struct stream : public handle<dnnl_stream_t> {
948  using handle::handle;
949 
951  enum class flags : unsigned {
961  };
962 
963  stream() = default;
964 
966  stream(const engine &aengine, flags aflags = flags::default_flags) {
967  dnnl_stream_t astream;
968  error::wrap_c_api(dnnl_stream_create(&astream, aengine.get(),
969  static_cast<dnnl_stream_flags_t>(aflags)),
970  "could not create a stream");
971  reset(astream);
972  }
973 
974 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
975  stream(const engine &eng, cl_command_queue queue) {
978  dnnl_stream_t astream;
979  error::wrap_c_api(dnnl_stream_create_ocl(&astream, eng.get(), queue),
980  "could not create a stream");
981  reset(astream);
982  }
983 
985  cl_command_queue get_ocl_command_queue() const {
986  cl_command_queue queue = nullptr;
988  "could not get OpenCL command queue");
989  return queue;
990  }
991 #endif
992 
995  error::wrap_c_api(dnnl_stream_wait(get()), "could not wait a stream");
996  return *this;
997  }
998 };
999 
1000 inline stream::flags operator|(stream::flags lhs, stream::flags rhs) {
1001  return static_cast<stream::flags>(
1002  static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
1003 }
1004 
1005 inline stream::flags operator&(stream::flags lhs, stream::flags rhs) {
1006  return static_cast<stream::flags>(
1007  static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs));
1008 }
1009 
1010 inline stream::flags operator^(stream::flags lhs, stream::flags rhs) {
1011  return static_cast<stream::flags>(
1012  static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs));
1013 }
1014 
1015 inline stream::flags operator~(stream::flags rhs) {
1016  return static_cast<stream::flags>(~static_cast<unsigned>(rhs));
1017 }
1018 
1020 
1023 
1029 
1031 struct memory : public handle<dnnl_memory_t> {
1032  typedef dnnl_dim_t dim;
1033  typedef std::vector<dim> dims;
1034 
1035  template <typename T>
1036  static void validate_dims(const std::vector<T> &v) {
1037  if (v.size() > DNNL_MAX_NDIMS)
1038  throw error(dnnl_invalid_arguments, "invalid dimensions");
1039  }
1040 
1042  enum class data_type {
1046  f16 = dnnl_f16,
1048  bf16 = dnnl_bf16,
1050  f32 = dnnl_f32,
1052  s32 = dnnl_s32,
1054  s8 = dnnl_s8,
1056  u8 = dnnl_u8,
1057  };
1058 
1060  enum class format_kind {
1074  };
1075 
1078  enum class format_tag {
1084 
1085  // Semantic agnostic section
1086  // The physical order of dimensions is defined by the permutation of the
1087  // characters, assuming that ab..z defines the natural order.
1088 
1089  // Plain formats
1090 
1091  a = dnnl_a,
1092  ab = dnnl_ab,
1093  abc = dnnl_abc,
1094  abcd = dnnl_abcd,
1095  abcde = dnnl_abcde,
1096  abcdef = dnnl_abcdef,
1097 
1098  // Permuted plain formats
1099 
1100  abdec = dnnl_abdec,
1101  acb = dnnl_acb,
1102  acbde = dnnl_acbde,
1103  acdb = dnnl_acdb,
1104  acdeb = dnnl_acdeb,
1105  ba = dnnl_ba,
1106  bac = dnnl_bac,
1107  bacd = dnnl_bacd,
1108  bcda = dnnl_bcda,
1109  cba = dnnl_cba,
1110  cdba = dnnl_cdba,
1111  cdeba = dnnl_cdeba,
1112  decab = dnnl_decab,
1113 
1114  // Opaque blocked formats
1115 
1116  Abc16a = dnnl_Abc16a,
1117  ABc16a16b = dnnl_ABc16a16b,
1118  aBc16b = dnnl_aBc16b,
1119  ABc16b16a = dnnl_ABc16b16a,
1120  Abc4a = dnnl_Abc4a,
1121  aBc4b = dnnl_aBc4b,
1122  ABc4b16a4b = dnnl_ABc4b16a4b,
1123  ABc4b4a = dnnl_ABc4b4a,
1124  ABc8a16b2a = dnnl_ABc8a16b2a,
1125  ABc8a8b = dnnl_ABc8a8b,
1126  aBc8b = dnnl_aBc8b,
1127  ABc8b16a2b = dnnl_ABc8b16a2b,
1128  ABc8b8a = dnnl_ABc8b8a,
1129  Abcd16a = dnnl_Abcd16a,
1130  ABcd16a16b = dnnl_ABcd16a16b,
1131  aBcd16b = dnnl_aBcd16b,
1132  ABcd16b16a = dnnl_ABcd16b16a,
1133  aBCd16b16c = dnnl_aBCd16b16c,
1134  aBCd16c16b = dnnl_aBCd16c16b,
1135  Abcd4a = dnnl_Abcd4a,
1136  aBcd4b = dnnl_aBcd4b,
1137  ABcd4b16a4b = dnnl_ABcd4b16a4b,
1138  ABcd4b4a = dnnl_ABcd4b4a,
1139  aBCd4c16b4c = dnnl_aBCd4c16b4c,
1140  aBCd4c4b = dnnl_aBCd4c4b,
1141  ABcd8a16b2a = dnnl_ABcd8a16b2a,
1142  ABcd8a8b = dnnl_ABcd8a8b,
1144  aBcd8b = dnnl_aBcd8b,
1145  ABcd8b16a2b = dnnl_ABcd8b16a2b,
1146  aBCd8b16c2b = dnnl_aBCd8b16c2b,
1149  aBCd8b8c = dnnl_aBCd8b8c,
1150  aBCd8c16b2c = dnnl_aBCd8c16b2c,
1151  aBCd8c8b = dnnl_aBCd8c8b,
1152  Abcde16a = dnnl_Abcde16a,
1153  ABcde16a16b = dnnl_ABcde16a16b,
1154  aBcde16b = dnnl_aBcde16b,
1155  ABcde16b16a = dnnl_ABcde16b16a,
1156  aBCde16b16c = dnnl_aBCde16b16c,
1157  aBCde16c16b = dnnl_aBCde16c16b,
1158  aBCde2c8b4c = dnnl_aBCde2c8b4c,
1159  Abcde4a = dnnl_Abcde4a,
1160  aBcde4b = dnnl_aBcde4b,
1161  ABcde4b4a = dnnl_ABcde4b4a,
1162  aBCde4b4c = dnnl_aBCde4b4c,
1163  aBCde4c16b4c = dnnl_aBCde4c16b4c,
1164  aBCde4c4b = dnnl_aBCde4c4b,
1165  Abcde8a = dnnl_Abcde8a,
1166  ABcde8a8b = dnnl_ABcde8a8b,
1167  aBcde8b = dnnl_aBcde8b,
1168  ABcde8b16a2b = dnnl_ABcde8b16a2b,
1169  aBCde8b16c2b = dnnl_aBCde8b16c2b,
1170  ABcde8b8a = dnnl_ABcde8b8a,
1171  aBCde8b8c = dnnl_aBCde8b8c,
1172  ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
1173  ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
1174  aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
1175  aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
1176  aBCde8c16b2c = dnnl_aBCde8c16b2c,
1177  aBCde8c8b = dnnl_aBCde8c8b,
1178  aBcdef16b = dnnl_aBcdef16b,
1179  aBCdef16b16c = dnnl_aBCdef16b16c,
1180  aBCdef16c16b = dnnl_aBCdef16c16b,
1181  aBcdef4b = dnnl_aBcdef4b,
1182  aBCdef4c4b = dnnl_aBCdef4c4b,
1183  aBCdef8b8c = dnnl_aBCdef8b8c,
1184  aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
1185  aBCdef8c8b = dnnl_aBCdef8c8b,
1186  aBdc16b = dnnl_aBdc16b,
1187  aBdc4b = dnnl_aBdc4b,
1188  aBdc8b = dnnl_aBdc8b,
1189  aBdec16b = dnnl_aBdec16b,
1190  aBdec4b = dnnl_aBdec4b,
1191  aBdec8b = dnnl_aBdec8b,
1192  aBdefc16b = dnnl_aBdefc16b,
1193  aCBdef16c16b = dnnl_aCBdef16c16b,
1194  aBdefc4b = dnnl_aBdefc4b,
1195  aBdefc8b = dnnl_aBdefc8b,
1196  Acb16a = dnnl_Acb16a,
1197  Acb4a = dnnl_Acb4a,
1198  Acb8a = dnnl_Acb8a,
1199  aCBd16b16c = dnnl_aCBd16b16c,
1200  aCBd16c16b = dnnl_aCBd16c16b,
1201  aCBde16b16c = dnnl_aCBde16b16c,
1202  aCBde16c16b = dnnl_aCBde16c16b,
1203  Acdb16a = dnnl_Acdb16a,
1204  Acdb4a = dnnl_Acdb4a,
1205  Acdb8a = dnnl_Acdb8a,
1206  Acdeb16a = dnnl_Acdeb16a,
1207  Acdeb4a = dnnl_Acdeb4a,
1208  Acdeb8a = dnnl_Acdeb8a,
1209  BAc16a16b = dnnl_BAc16a16b,
1210  BAc16b16a = dnnl_BAc16b16a,
1211  BAcd16a16b = dnnl_BAcd16a16b,
1212  BAcd16b16a = dnnl_BAcd16b16a,
1213  ABcd32a32b = dnnl_ABcd32a32b,
1214  BAcde16b16 = dnnl_BAcde16b16a,
1215  aBdec32b = dnnl_aBdec32b,
1216  Abcdef16a = dnnl_Abcdef16a,
1217  Acdb32a = dnnl_Acdb32a,
1218  format_tag_last = dnnl_format_tag_last,
1219 
1220  x = dnnl_x,
1223  nc = dnnl_nc,
1224  cn = dnnl_cn,
1225  tn = dnnl_tn,
1226  nt = dnnl_nt,
1227  ncw = dnnl_ncw,
1228  nwc = dnnl_nwc,
1231  nchw = dnnl_nchw,
1234  nhwc = dnnl_nhwc,
1237  chwn = dnnl_chwn,
1238  ncdhw = dnnl_ncdhw,
1239  ndhwc = dnnl_ndhwc,
1240  oi = dnnl_oi,
1241  io = dnnl_io,
1242  oiw = dnnl_oiw,
1243  wio = dnnl_wio,
1244  oihw = dnnl_oihw,
1245  hwio = dnnl_hwio,
1246  ihwo = dnnl_ihwo,
1247  iohw = dnnl_iohw,
1248  oidhw = dnnl_oidhw,
1249  dhwio = dnnl_dhwio,
1250  goiw = dnnl_goiw,
1251  goihw = dnnl_goihw,
1252  hwigo = dnnl_hwigo,
1253  giohw = dnnl_giohw,
1254  goidhw = dnnl_goidhw,
1255  tnc = dnnl_tnc,
1256  ntc = dnnl_ntc,
1257  ldnc = dnnl_ldnc,
1258  ldigo = dnnl_ldigo,
1259  ldgoi = dnnl_ldgoi,
1260  ldgo = dnnl_ldgo,
1261  nCdhw16c = dnnl_nCdhw16c,
1262  nCdhw4c = dnnl_nCdhw4c,
1263  nCdhw8c = dnnl_nCdhw8c,
1264  nChw16c = dnnl_nChw16c,
1265  nChw4c = dnnl_nChw4c,
1266  nChw8c = dnnl_nChw8c,
1267  nCw16c = dnnl_nCw16c,
1268  nCw4c = dnnl_nCw4c,
1269  nCw8c = dnnl_nCw8c,
1270  NCw16n16c = dnnl_NCw16n16c,
1271  NChw16n16c = dnnl_NChw16n16c,
1272  NCdhw16n16c = dnnl_NCdhw16n16c,
1273  NChw32n32c = dnnl_NChw32n32c,
1274  IOhw16i16o = dnnl_IOhw16i16o,
1275  Ohwi32o = dnnl_Ohwi32o,
1276  IOdhw16i16o = dnnl_IOdhw16i16o,
1277  gIOhw16i16o = dnnl_gIOhw16i16o,
1278  gOhwi32o = dnnl_gOhwi32o,
1279  Goidhw16g = dnnl_Goidhw16g,
1280  IOw16o16i = dnnl_IOw16o16i,
1281  OIw16i16o = dnnl_OIw16i16o,
1282  IOw16i16o = dnnl_IOw16i16o,
1283  gIOw16i16o = dnnl_gIOw16i16o,
1284  OIw16o16i = dnnl_OIw16o16i,
1285  Oiw16o = dnnl_Oiw16o,
1286  OIw4i16o4i = dnnl_OIw4i16o4i,
1287  OIw4i4o = dnnl_OIw4i4o,
1288  Oiw4o = dnnl_Oiw4o,
1289  OIw8i16o2i = dnnl_OIw8i16o2i,
1290  OIw8i8o = dnnl_OIw8i8o,
1291  OIw8o16i2o = dnnl_OIw8o16i2o,
1292  OIw8o8i = dnnl_OIw8o8i,
1293  Owi16o = dnnl_Owi16o,
1294  Owi4o = dnnl_Owi4o,
1295  Owi8o = dnnl_Owi8o,
1296  IOhw16o16i = dnnl_IOhw16o16i,
1297  Ohwi16o = dnnl_Ohwi16o,
1298  Ohwi4o = dnnl_Ohwi4o,
1299  Ohwi8o = dnnl_Ohwi8o,
1300  OIhw16i16o = dnnl_OIhw16i16o,
1301  OIhw16o16i = dnnl_OIhw16o16i,
1302  Oihw16o = dnnl_Oihw16o,
1303  OIhw4i16o4i = dnnl_OIhw4i16o4i,
1304  OIhw4i4o = dnnl_OIhw4i4o,
1305  Oihw4o = dnnl_Oihw4o,
1306  OIhw8i16o2i = dnnl_OIhw8i16o2i,
1307  OIhw8i8o = dnnl_OIhw8i8o,
1308  OIhw8o16i2o = dnnl_OIhw8o16i2o,
1309  OIhw8o8i = dnnl_OIhw8o8i,
1310  Odhwi16o = dnnl_Odhwi16o,
1311  Odhwi4o = dnnl_Odhwi4o,
1312  Odhwi8o = dnnl_Odhwi8o,
1313  OIdhw16i16o = dnnl_OIdhw16i16o,
1314  OIdhw16o16i = dnnl_OIdhw16o16i,
1315  Oidhw16o = dnnl_Oidhw16o,
1316  OIdhw4i4o = dnnl_OIdhw4i4o,
1317  Oidhw4o = dnnl_Oidhw4o,
1318  OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
1319  OIdhw8i8o = dnnl_OIdhw8i8o,
1320  OIdhw8o8i = dnnl_OIdhw8o8i,
1321  gIOw16o16i = dnnl_gIOw16o16i,
1322  gOIw16i16o = dnnl_gOIw16i16o,
1323  gOIw16o16i = dnnl_gOIw16o16i,
1324  gOiw16o = dnnl_gOiw16o,
1325  gOIw4i16o4i = dnnl_gOIw4i16o4i,
1326  gOIw4i4o = dnnl_gOIw4i4o,
1327  gOiw4o = dnnl_gOiw4o,
1328  gOIw8i16o2i = dnnl_gOIw8i16o2i,
1329  gOIw8i8o = dnnl_gOIw8i8o,
1330  gOIw8o16i2o = dnnl_gOIw8o16i2o,
1331  gOIw8o8i = dnnl_gOIw8o8i,
1332  gOwi16o = dnnl_gOwi16o,
1333  gOwi4o = dnnl_gOwi4o,
1334  gOwi8o = dnnl_gOwi8o,
1335  gIOhw16o16i = dnnl_gIOhw16o16i,
1336  gOhwi16o = dnnl_gOhwi16o,
1337  gOhwi4o = dnnl_gOhwi4o,
1338  gOhwi8o = dnnl_gOhwi8o,
1339  Goihw16g = dnnl_Goihw16g,
1340  gOIhw16i16o = dnnl_gOIhw16i16o,
1341  gOIhw16o16i = dnnl_gOIhw16o16i,
1342  gOihw16o = dnnl_gOihw16o,
1343  gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
1344  gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
1345  gOIhw4i4o = dnnl_gOIhw4i4o,
1346  gOIhw4o4i = dnnl_gOIhw4o4i,
1347  gOihw4o = dnnl_gOihw4o,
1348  Goihw8g = dnnl_Goihw8g,
1349  gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
1350  gOIhw8i8o = dnnl_gOIhw8i8o,
1351  gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
1352  OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
1353  OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
1354  gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
1355  gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
1356  gOIhw8o8i = dnnl_gOIhw8o8i,
1357  gIOdhw16i16o = dnnl_gIOdhw16i16o,
1358  gOdhwi16o = dnnl_gOdhwi16o,
1359  gOdhwi4o = dnnl_gOdhwi4o,
1360  gOdhwi8o = dnnl_gOdhwi8o,
1361  gOIdhw16i16o = dnnl_gOIdhw16i16o,
1362  gOIdhw16o16i = dnnl_gOIdhw16o16i,
1363  gOidhw16o = dnnl_gOidhw16o,
1364  gOIdhw4i4o = dnnl_gOIdhw4i4o,
1365  gOidhw4o = dnnl_gOidhw4o,
1366  gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
1367  gOIdhw8i8o = dnnl_gOIdhw8i8o,
1368  gOIdhw8o8i = dnnl_gOIdhw8o8i,
1369  };
1370 
1372  struct desc {
1373  friend struct memory;
1376 
1378  desc() : data() {}
1379 
1385  desc(const dims &adims, data_type adata_type, format_tag aformat_tag) {
1386  validate_dims(adims);
1388  dnnl_memory_desc_init_by_tag(&data, (int)adims.size(),
1389  adims.size() == 0 ? nullptr : &adims[0],
1390  convert_to_c(adata_type),
1391  convert_to_c(aformat_tag)),
1392  "could not initialize a memory descriptor by tag");
1393  }
1394 
1400  desc(const dims &adims, data_type adata_type, const dims &astrides) {
1401  validate_dims(adims);
1403  dnnl_memory_desc_init_by_strides(&data, (int)adims.size(),
1404  adims.size() == 0 ? nullptr : &adims[0],
1405  convert_to_c(adata_type),
1406  astrides.size() == 0 ? nullptr : &astrides[0]),
1407  "could not initialize a memory descriptor by strides");
1408  }
1409 
1413  desc(const dnnl_memory_desc_t &adata) : data(adata) {}
1414 
1416  //
1419  desc submemory_desc(const dims &adims, const dims &offsets) {
1420  dnnl_memory_desc_t sub_md;
1422  &sub_md, &data, &adims[0], &offsets[0]),
1423  "could not initialize a sub-memory");
1424  return desc(sub_md);
1425  }
1426 
1428  desc reshape(const dims &adims) {
1429  dnnl_memory_desc_t out_md;
1431  (int)adims.size(), &adims[0]),
1432  "could not reshape a memory descriptor");
1433  return desc(out_md);
1434  }
1435 
1438  size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
1439 
1441  bool is_zero() const { return data.ndims == 0; }
1442 
1443  bool operator==(const desc &other) const {
1444  return dnnl_memory_desc_equal(&data, &other.data) != 0;
1445  }
1446 
1447  bool operator!=(const desc &other) const { return !operator==(other); }
1448  };
1449 
1450  memory() = default;
1451 
1457  memory(const desc &md, const engine &aengine, void *ahandle) {
1458  dnnl_memory_t result;
1460  dnnl_memory_create(&result, &md.data, aengine.get(), ahandle),
1461  "could not create a memory");
1462  reset(result);
1463  }
1464 
1469  memory(const desc &md, const engine &aengine)
1470  : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
1471 
1473  desc get_desc() const {
1474  const dnnl_memory_desc_t *cdesc;
1476  "could not get memory descriptor from a memory");
1477  return desc(*cdesc);
1478  }
1479 
1481  engine get_engine() const {
1482  dnnl_engine_t engine_q;
1483  error::wrap_c_api(dnnl_memory_get_engine(get(), &engine_q),
1484  "could not get engine from a memory");
1485  return engine(engine_q);
1486  }
1487 
1491  void *get_data_handle() const {
1492  void *handle;
1494  "could not get native handle");
1495  return handle;
1496  }
1497 
1498  void set_data_handle(void *handle) const {
1500  "could not set native handle");
1501  }
1502 
1518  template <typename T = void>
1519  T *map_data() const {
1520  void *mapped_ptr;
1521  error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
1522  "could not map the data");
1523  return static_cast<T *>(mapped_ptr);
1524  }
1525 
1534  void unmap_data(void *mapped_ptr) const {
1535  error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr),
1536  "could not unmap the data");
1537  }
1538 
1539 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
1540  cl_mem get_ocl_mem_object() const {
1542  cl_mem mem_object;
1544  "could not get OpenCL memory object");
1545  return mem_object;
1546  }
1547 
1549  void set_ocl_mem_object(cl_mem mem_object) {
1551  "could not set OpenCL memory object");
1552  }
1553 #endif
1554 
1555  // Must go away or be private:
1556  static dnnl_data_type_t convert_to_c(data_type adata_type) {
1557  return static_cast<dnnl_data_type_t>(adata_type);
1558  }
1559  static dnnl_format_tag_t convert_to_c(format_tag aformat) {
1560  return static_cast<dnnl_format_tag_t>(aformat);
1561  }
1562 };
1563 
1564 inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
1565  return a == memory::convert_to_c(b);
1566 }
1567 inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
1568  return !(a == b);
1569 }
1570 inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
1571  return b == a;
1572 }
1573 inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
1574  return !(a == b);
1575 }
1576 
1577 inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
1578  return a == memory::convert_to_c(b);
1579 }
1580 inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
1581  return !(a == b);
1582 }
1583 inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
1584  return b == a;
1585 }
1586 inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
1587  return !(a == b);
1588 }
1589 
1591 
1594 
1597 
1599 struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
1601 
1602  primitive_desc_base() = default;
1603 
1605  engine get_engine() const { return engine::query(*this); }
1606 
1608  const char *impl_info_str() const {
1609  const char *res;
1611  get(), dnnl_query_impl_info_str, 0, &res),
1612  "could not query implementation info string");
1613  return res;
1614  }
1615 
1617  memory::dim query_s64(query q) const {
1618  memory::dim res;
1620  get(), dnnl::convert_to_c(q), 0, &res);
1621  return status == dnnl_success ? res : 0;
1622  }
1623 
1625  memory::desc query_md(query what, int idx = 0) const {
1626  std::vector<query> valid_q {query::src_md, query::diff_src_md,
1629  if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
1630  [=](query q) { return what == q; }))
1631  throw error(dnnl_invalid_arguments, "invalid memory query");
1632 
1634  get(), dnnl::convert_to_c(what), idx);
1635  return memory::desc(*cdesc);
1636  }
1637 
1643  return query_md(query::scratchpad_md, 0);
1644  }
1645 
1648  dnnl_engine_t engine_q;
1650  dnnl::convert_to_c(query::scratchpad_engine),
1651  0, &engine_q),
1652  "could not get scratchpad engine from a primitive_desc");
1653 
1654  return engine(engine_q);
1655  }
1656 
1659  const_dnnl_primitive_attr_t const_cattr;
1660  error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_cattr),
1661  "could not get attributes");
1662  dnnl_primitive_attr_t cattr;
1663  error::wrap_c_api(dnnl_primitive_attr_clone(&cattr, const_cattr),
1664  "could not clone attributes");
1665 
1666  return primitive_attr(cattr);
1667  }
1668 
1669 protected:
1670  void reset_with_clone(const_dnnl_primitive_desc_t pd) {
1671  dnnl_primitive_desc_t new_pd;
1673  "could not clone primitive descriptor");
1674  reset(new_pd);
1675  }
1676 
1677  primitive_desc_base(
1679  : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
1680 
1681  primitive_desc_base(dnnl_primitive_desc_t pd,
1683  : primitive_desc_base(pd, prim_kind, prop_kind, prop_kind) {}
1684 
// Validating constructor: accepts `pd` only when its primitive kind equals
// `prim_kind` and its propagation kind equals `prop_kind1` or `prop_kind2`;
// on success a clone of `pd` is stored via reset_with_clone().
// NOTE(review): the first line of this signature (the one declaring `pd`) is
// not present in this listing; only the tail of the parameter list is shown.
1692  dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
1693  dnnl::prop_kind prop_kind2) {
1694  // It is OK to pass an empty primitive descriptor
1695  if (pd == nullptr) return;
1696 
1697  dnnl_status_t rc;
1698 
// Translate the C++ enums once so they can be compared with the values
// reported by the C API queries below.
1699  dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
1700  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
1701  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
1702 
1703  // Check that primitive kind matches
1704  dnnl_primitive_kind_t pd_kind;
// NOTE(review): the head of the statement that assigns `rc` from a
// primitive-kind query is missing from this listing; the next line is the
// tail of its argument list.
1706  pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
1707  error::wrap_c_api(rc,
1708  "could not get primitive kind from the primitive descriptor");
1709  if (pd_kind != c_prim_kind)
// NOTE(review): the `throw` head (with its status code) is missing from this
// listing; only the message argument survives.
1711  "primitive descriptor operation kind mismatch");
1712 
1713  // Check that propagation kind matches
1714  dnnl_prop_kind_t pd_prop_kind;
// NOTE(review): as above, the head of the `rc = ...` propagation-kind query
// is missing; the next line is the tail of its argument list.
1716  pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
1717 
// dnnl_unimplemented is tolerated here and handled by the acceptance test
// further down; any other failure status is treated as an error.
1718  // Something went wrong
1719  if (rc != dnnl_success && rc != dnnl_unimplemented)
// NOTE(review): the `throw` head is missing from this listing.
1721  "could not get propagation kind "
1722  "from the primitive descriptor");
1723 
// Accept when either (a) the query is unimplemented and the caller expects
// an undefined propagation kind, or (b) the query succeeded and the reported
// kind equals one of the two accepted kinds.
1724  // Everything is fine
1725  if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
1726  || (rc == dnnl_success
1727  && (pd_prop_kind == c_prop_kind1
1728  || pd_prop_kind == c_prop_kind2))) {
1729  reset_with_clone(pd);
1730  return;
1731  }
1732 
1733  // We could get the propagation kind but there is a mismatch
// NOTE(review): the `throw` head is missing from this listing.
1735  "primitive descriptor propagation kind mismatch");
1736  }
1737 };
1738 
1741 
1748 
1752 struct reorder : public primitive {
1753  struct primitive_desc : public primitive_desc_base {
1754  using primitive_desc_base::primitive_desc_base;
1755 
1756  primitive_desc() = default;
1757 
1758  primitive_desc(const engine &src_engine, const memory::desc &src_md,
1759  const engine &dst_engine, const memory::desc &dst_md,
1760  const primitive_attr &aattr = primitive_attr()) {
1761  dnnl_primitive_desc_t result;
1764  src_engine.get(), &dst_md.data, dst_engine.get(),
1765  aattr.get()),
1766  "could not create a reorder primitive descriptor");
1767  reset(result);
1768  }
1769 
1770  primitive_desc(const memory &src, const memory &dst,
1771  const primitive_attr &aattr = primitive_attr()) {
1772  dnnl_primitive_desc_t result;
1773  auto src_md = src.get_desc();
1774  auto dst_md = dst.get_desc();
1777  src.get_engine().get(), &dst_md.data,
1778  dst.get_engine().get(), aattr.get()),
1779  "could not create a reorder primitive descriptor");
1780  reset(result);
1781  }
1782 
1785  primitive_desc(dnnl_primitive_desc_t pd)
1787 
1788  engine get_src_engine() const {
1789  return engine::query(*this, dnnl::query::reorder_src_engine);
1790  }
1791 
1792  engine get_dst_engine() const {
1793  return engine::query(*this, dnnl::query::reorder_dst_engine);
1794  }
1795  };
1796 
1797  reorder() = default;
1798 
1799  reorder(const primitive_desc &pd) : primitive(pd.get()) {}
1800 
1801  reorder(const memory &src, const memory &dst)
1802  : primitive(primitive_desc(src, dst).get()) {}
1803 
1804  using primitive::execute;
1805 
1806  void execute(stream astream, memory &src, memory &dst) {
1807  primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
1808  }
1809 };
1810 
1812 
1819 
1821 inline std::vector<dnnl_memory_desc_t> convert_to_c(
1822  const std::vector<memory::desc> &mems) {
1823  std::vector<dnnl_memory_desc_t> c_api_mems;
1824  c_api_mems.reserve(mems.size());
1825  for (const auto &s : mems)
1826  c_api_mems.push_back(s.data);
1827  return c_api_mems;
1828 }
1830 
1838 struct concat : public primitive {
1839  struct primitive_desc : public primitive_desc_base {
1840  using primitive_desc_base::primitive_desc_base;
1841 
1842  primitive_desc(const memory::desc &dst, int concat_dimension,
1843  const std::vector<memory::desc> &srcs, const engine &aengine,
1844  const primitive_attr &aattr = primitive_attr()) {
1845  auto c_api_srcs = convert_to_c(srcs);
1846 
1847  dnnl_primitive_desc_t result;
1850  (int)c_api_srcs.size(), concat_dimension,
1851  &c_api_srcs[0], aattr.get(), aengine.get()),
1852  "could not create a concat primitive descriptor");
1853  reset(result);
1854  }
1855 
1856  primitive_desc(int concat_dimension,
1857  const std::vector<memory::desc> &srcs, const engine &aengine,
1858  const primitive_attr &aattr = primitive_attr()) {
1859  auto c_api_srcs = convert_to_c(srcs);
1860 
1861  dnnl_primitive_desc_t result;
1863  dnnl_concat_primitive_desc_create(&result, nullptr,
1864  (int)c_api_srcs.size(), concat_dimension,
1865  &c_api_srcs[0], aattr.get(), aengine.get()),
1866  "could not create a concat primitive descriptor");
1867  reset(result);
1868  }
1869 
1872  primitive_desc(dnnl_primitive_desc_t pd)
1874 
1876  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
1877  };
1878 
1879  concat() = default;
1880 
1881  concat(const primitive_desc &pd) : primitive(pd.get()) {}
1882 };
1883 
1885 
1892 
1898 struct sum : public primitive {
1899  struct primitive_desc : public primitive_desc_base {
1900  using primitive_desc_base::primitive_desc_base;
1901 
1902  primitive_desc() = default;
1903 
1904  primitive_desc(const memory::desc &dst,
1905  const std::vector<float> &scales,
1906  const std::vector<memory::desc> &srcs, const engine &aengine,
1907  const primitive_attr &aattr = primitive_attr()) {
1908  error::wrap_c_api(scales.size() == srcs.size()
1909  ? dnnl_success
1911  "number of scales not equal to number of srcs");
1912 
1913  auto c_api_srcs = convert_to_c(srcs);
1914 
1915  dnnl_primitive_desc_t result;
1917  dnnl_sum_primitive_desc_create(&result, &dst.data,
1918  (int)c_api_srcs.size(), &scales[0], &c_api_srcs[0],
1919  aattr.get(), aengine.get()),
1920  "could not create a sum primitive descriptor");
1921  reset(result);
1922  }
1923 
1924  primitive_desc(const std::vector<float> &scales,
1925  const std::vector<memory::desc> &srcs, const engine &aengine,
1926  const primitive_attr &aattr = primitive_attr()) {
1927  error::wrap_c_api(scales.size() == srcs.size()
1928  ? dnnl_success
1930  "number of scales not equal to number of srcs");
1931 
1932  auto c_api_srcs = convert_to_c(srcs);
1933  dnnl_primitive_desc_t result;
1935  dnnl_sum_primitive_desc_create(&result, nullptr,
1936  (int)c_api_srcs.size(), &scales[0], &c_api_srcs[0],
1937  aattr.get(), aengine.get()),
1938  "could not create a sum primitive descriptor");
1939  reset(result);
1940  }
1941 
1944  primitive_desc(dnnl_primitive_desc_t pd)
1946 
1948  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
1949  };
1950 
1951  sum() = default;
1952 
1953  sum(const primitive_desc &pd) : primitive(pd.get()) {}
1954 };
1955 
1957 
1959 
1962 
1965 
1969  using primitive_desc_base::primitive_desc_base;
1970 
1971  primitive_desc() = default;
1972 
1979  const engine &e, const_dnnl_primitive_desc_t hint_fwd_pd,
1980  bool allow_empty = false)
1981  : allow_empty(allow_empty) {
1982  dnnl_primitive_desc_iterator_t iterator = nullptr;
1984  desc, attr ? attr->get() : nullptr, e.get(), hint_fwd_pd);
1985  if (!allow_empty)
1987  status, "could not create a primitive descriptor iterator");
1988  pd_iterator.reset(iterator);
1989  fetch_impl();
1990  }
1991 
1998  bool next_impl() {
1999  dnnl_status_t status
2000  = dnnl_primitive_desc_iterator_next(pd_iterator.get());
2001  if (status == dnnl_iterator_ends) return false;
2002  error::wrap_c_api(status, "primitive descriptor iterator next failed");
2003 
2004  fetch_impl();
2005  return true;
2006  }
2007 
2008 private:
2009  bool allow_empty = false;
// Refreshes the handle held by this object from the current position of
// pd_iterator: a null result is reported as an error unless allow_empty is
// set, and the result is then passed to reset().
2011  void fetch_impl() {
// NOTE(review): the head of the statement that declares and assigns `pd`
// is missing from this listing; the next line is the tail of its argument
// list.
2013  pd_iterator.get(allow_empty));
// NOTE(review): the else-branch of this conditional (the failure status
// code) is missing from this listing.
2014  error::wrap_c_api(pd != nullptr || allow_empty ? dnnl_success
2016  "could not fetch a primitive descriptor from the iterator");
2017  reset(pd);
2018  }
2019 };
2020 
2022 
2030 
2036 
2038  struct desc {
2040 
2049  desc(prop_kind aprop_kind, algorithm aalgorithm,
2050  const memory::desc &src_desc, const memory::desc &weights_desc,
2051  const memory::desc &bias_desc, const memory::desc &dst_desc,
2052  const memory::dims &strides, const memory::dims &padding_l,
2053  const memory::dims &padding_r) {
2054  memory::validate_dims(strides);
2055  memory::validate_dims(padding_l);
2056  memory::validate_dims(padding_r);
2059  dnnl::convert_to_c(aprop_kind),
2060  convert_to_c(aalgorithm), &src_desc.data,
2061  &weights_desc.data, &bias_desc.data, &dst_desc.data,
2062  &strides[0], &padding_l[0], &padding_r[0]),
2063  "could not create a convolution forward descriptor");
2064  }
2065 
2074  desc(prop_kind aprop_kind, algorithm aalgorithm,
2075  const memory::desc &src_desc, const memory::desc &weights_desc,
2076  const memory::desc &dst_desc, const memory::dims &strides,
2077  const memory::dims &padding_l, const memory::dims &padding_r) {
2078  memory::validate_dims(strides);
2079  memory::validate_dims(padding_l);
2080  memory::validate_dims(padding_r);
2083  dnnl::convert_to_c(aprop_kind),
2084  convert_to_c(aalgorithm), &src_desc.data,
2085  &weights_desc.data, nullptr, &dst_desc.data,
2086  &strides[0], &padding_l[0], &padding_r[0]),
2087  "could not create a convolution forward descriptor");
2088  }
2089 
2098  desc(prop_kind aprop_kind, algorithm aalgorithm,
2099  const memory::desc &src_desc, const memory::desc &weights_desc,
2100  const memory::desc &bias_desc, const memory::desc &dst_desc,
2101  const memory::dims &strides, const memory::dims &dilates,
2102  const memory::dims &padding_l, const memory::dims &padding_r) {
2103  memory::validate_dims(strides);
2104  memory::validate_dims(dilates);
2105  memory::validate_dims(padding_l);
2106  memory::validate_dims(padding_r);
2108  dnnl::convert_to_c(aprop_kind),
2109  convert_to_c(aalgorithm), &src_desc.data,
2110  &weights_desc.data, &bias_desc.data,
2111  &dst_desc.data, &strides[0], &dilates[0],
2112  &padding_l[0], &padding_r[0]),
2113  "could not create a dilated convolution forward "
2114  "descriptor");
2115  }
2116 
2125  desc(prop_kind aprop_kind, algorithm aalgorithm,
2126  const memory::desc &src_desc, const memory::desc &weights_desc,
2127  const memory::desc &dst_desc, const memory::dims &strides,
2128  const memory::dims &dilates, const memory::dims &padding_l,
2129  const memory::dims &padding_r) {
2130  memory::validate_dims(strides);
2131  memory::validate_dims(dilates);
2132  memory::validate_dims(padding_l);
2133  memory::validate_dims(padding_r);
2135  dnnl::convert_to_c(aprop_kind),
2136  convert_to_c(aalgorithm), &src_desc.data,
2137  &weights_desc.data, nullptr,
2138  &dst_desc.data, &strides[0], &dilates[0],
2139  &padding_l[0], &padding_r[0]),
2140  "could not create a dilated convolution forward "
2141  "descriptor");
2142  }
2143  };
2144 
2147  primitive_desc() = default;
2148 
2152  const desc &desc, const engine &e, bool allow_empty = false)
2153  : dnnl::primitive_desc(
2154  &desc.data, nullptr, e, nullptr, allow_empty) {}
2155 
2159  const engine &e, bool allow_empty = false)
2160  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
2161  }
2162 
2166  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
2169 
2172 
2175  return query_md(query::weights_md, 0);
2176  }
2177 
2183  return query_md(query::weights_md, 1);
2184  }
2185 
2188  };
2189 
2190  convolution_forward() = default;
2191 
2195 };
2196 
2202 
2204  struct desc {
2206 
2213  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
2214  const memory::desc &weights_desc,
2215  const memory::desc &diff_dst_desc, const memory::dims &strides,
2216  const memory::dims &padding_l, const memory::dims &padding_r) {
2217  memory::validate_dims(strides);
2218  memory::validate_dims(padding_l);
2219  memory::validate_dims(padding_r);
2222  convert_to_c(aalgorithm), &diff_src_desc.data,
2223  &weights_desc.data, &diff_dst_desc.data,
2224  &strides[0], &padding_l[0], &padding_r[0]),
2225  "could not create a convolution backward data descriptor");
2226  }
2227 
2234  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
2235  const memory::desc &weights_desc,
2236  const memory::desc &diff_dst_desc, const memory::dims &strides,
2237  const memory::dims &dilates, const memory::dims &padding_l,
2238  const memory::dims &padding_r) {
2239  memory::validate_dims(strides);
2240  memory::validate_dims(dilates);
2241  memory::validate_dims(padding_l);
2242  memory::validate_dims(padding_r);
2245  convert_to_c(aalgorithm), &diff_src_desc.data,
2246  &weights_desc.data, &diff_dst_desc.data,
2247  &strides[0], &dilates[0], &padding_l[0],
2248  &padding_r[0]),
2249  "could not create a convolution backward data descriptor");
2250  }
2251  };
2252 
2255  primitive_desc() = default;
2256 
2259  primitive_desc(const desc &desc, const engine &e,
2260  const convolution_forward::primitive_desc &hint_fwd_pd,
2261  bool allow_empty = false)
2262  : dnnl::primitive_desc(
2263  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
2264 
2268  const engine &e,
2269  const convolution_forward::primitive_desc &hint_fwd_pd,
2270  bool allow_empty = false)
2271  : dnnl::primitive_desc(
2272  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
2273 
2277  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
2279 
2282  return query_md(query::diff_src_md, 0);
2283  }
2284 
2287  return query_md(query::weights_md, 0);
2288  }
2289 
2292  return query_md(query::diff_dst_md, 0);
2293  }
2294  };
2295 
2296  convolution_backward_data() = default;
2297 
2301 };
2302 
2308 
2310  struct desc {
2312 
2319  desc(algorithm aalgorithm, const memory::desc &src_desc,
2320  const memory::desc &diff_weights_desc,
2321  const memory::desc &diff_bias_desc,
2322  const memory::desc &diff_dst_desc, const memory::dims &strides,
2323  const memory::dims &padding_l, const memory::dims &padding_r) {
2324  memory::validate_dims(strides);
2325  memory::validate_dims(padding_l);
2326  memory::validate_dims(padding_r);
2329  convert_to_c(aalgorithm), &src_desc.data,
2330  &diff_weights_desc.data, &diff_bias_desc.data,
2331  &diff_dst_desc.data, &strides[0], &padding_l[0],
2332  &padding_r[0]),
2333  "could not create a convolution backward weights "
2334  "descriptor");
2335  }
2336 
2343  desc(algorithm aalgorithm, const memory::desc &src_desc,
2344  const memory::desc &diff_weights_desc,
2345  const memory::desc &diff_dst_desc, const memory::dims &strides,
2346  const memory::dims &padding_l, const memory::dims &padding_r) {
2347  memory::validate_dims(strides);
2348  memory::validate_dims(padding_l);
2349  memory::validate_dims(padding_r);
2351  convert_to_c(aalgorithm), &src_desc.data,
2352  &diff_weights_desc.data, nullptr,
2353  &diff_dst_desc.data, &strides[0],
2354  &padding_l[0], &padding_r[0]),
2355  "could not create a convolution backward weights "
2356  "descriptor");
2357  }
2358 
2365  desc(algorithm aalgorithm, const memory::desc &src_desc,
2366  const memory::desc &diff_weights_desc,
2367  const memory::desc &diff_bias_desc,
2368  const memory::desc &diff_dst_desc, const memory::dims &strides,
2369  const memory::dims &dilates, const memory::dims &padding_l,
2370  const memory::dims &padding_r) {
2371  memory::validate_dims(strides);
2372  memory::validate_dims(dilates);
2373  memory::validate_dims(padding_l);
2374  memory::validate_dims(padding_r);
2377  convert_to_c(aalgorithm), &src_desc.data,
2378  &diff_weights_desc.data, &diff_bias_desc.data,
2379  &diff_dst_desc.data, &strides[0], &dilates[0],
2380  &padding_l[0], &padding_r[0]),
2381  "could not create a convolution backward weights "
2382  "descriptor");
2383  }
2384 
2391  desc(algorithm aalgorithm, const memory::desc &src_desc,
2392  const memory::desc &diff_weights_desc,
2393  const memory::desc &diff_dst_desc, const memory::dims &strides,
2394  const memory::dims &dilates, const memory::dims &padding_l,
2395  const memory::dims &padding_r) {
2396  memory::validate_dims(strides);
2397  memory::validate_dims(dilates);
2398  memory::validate_dims(padding_l);
2399  memory::validate_dims(padding_r);
2402  convert_to_c(aalgorithm), &src_desc.data,
2403  &diff_weights_desc.data, nullptr,
2404  &diff_dst_desc.data, &strides[0], &dilates[0],
2405  &padding_l[0], &padding_r[0]),
2406  "could not create a convolution backward weights "
2407  "descriptor");
2408  }
2409  };
2410 
2413  primitive_desc() = default;
2414 
2416  primitive_desc(const desc &desc, const engine &e,
2417  const convolution_forward::primitive_desc &hint_fwd_pd,
2418  bool allow_empty = false)
2419  : dnnl::primitive_desc(
2420  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
2421 
2425  const engine &e,
2426  const convolution_forward::primitive_desc &hint_fwd_pd,
2427  bool allow_empty = false)
2428  : dnnl::primitive_desc(
2429  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
2430 
2434  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
2436 
2439 
2442  return query_md(query::diff_weights_md, 0);
2443  }
2444 
2447  return query_md(query::diff_weights_md, 1);
2448  }
2449 
2452  return query_md(query::diff_dst_md, 0);
2453  }
2454  };
2455 
2456  convolution_backward_weights() = default;
2457 
2461 };
2462 
2464 //
2470 
2476 
2478  struct desc {
2480 
2489  desc(prop_kind aprop_kind, algorithm aalgorithm,
2490  const memory::desc &src_desc, const memory::desc &weights_desc,
2491  const memory::desc &bias_desc, const memory::desc &dst_desc,
2492  const memory::dims &strides, const memory::dims &padding_l,
2493  const memory::dims &padding_r) {
2494  memory::validate_dims(strides);
2495  memory::validate_dims(padding_l);
2496  memory::validate_dims(padding_r);
2499  dnnl::convert_to_c(aprop_kind),
2500  convert_to_c(aalgorithm), &src_desc.data,
2501  &weights_desc.data, &bias_desc.data, &dst_desc.data,
2502  &strides[0], &padding_l[0], &padding_r[0]),
2503  "could not create a deconvolution forward descriptor");
2504  }
2505 
2514  desc(prop_kind aprop_kind, algorithm aalgorithm,
2515  const memory::desc &src_desc, const memory::desc &weights_desc,
2516  const memory::desc &dst_desc, const memory::dims &strides,
2517  const memory::dims &padding_l, const memory::dims &padding_r) {
2518  memory::validate_dims(strides);
2519  memory::validate_dims(padding_l);
2520  memory::validate_dims(padding_r);
2523  dnnl::convert_to_c(aprop_kind),
2524  convert_to_c(aalgorithm), &src_desc.data,
2525  &weights_desc.data, nullptr, &dst_desc.data,
2526  &strides[0], &padding_l[0], &padding_r[0]),
2527  "could not create a deconvolution forward descriptor");
2528  }
2529 
2538  desc(prop_kind aprop_kind, algorithm aalgorithm,
2539  const memory::desc &src_desc, const memory::desc &weights_desc,
2540  const memory::desc &bias_desc, const memory::desc &dst_desc,
2541  const memory::dims &strides, const memory::dims &dilates,
2542  const memory::dims &padding_l, const memory::dims &padding_r) {
2543  memory::validate_dims(strides);
2544  memory::validate_dims(dilates);
2545  memory::validate_dims(padding_l);
2546  memory::validate_dims(padding_r);
2548  &data, dnnl::convert_to_c(aprop_kind),
2549  convert_to_c(aalgorithm), &src_desc.data,
2550  &weights_desc.data, &bias_desc.data,
2551  &dst_desc.data, &strides[0], &dilates[0],
2552  &padding_l[0], &padding_r[0]),
2553  "could not create a dilated deconvolution forward "
2554  "descriptor");
2555  }
2556 
2565  desc(prop_kind aprop_kind, algorithm aalgorithm,
2566  const memory::desc &src_desc, const memory::desc &weights_desc,
2567  const memory::desc &dst_desc, const memory::dims &strides,
2568  const memory::dims &dilates, const memory::dims &padding_l,
2569  const memory::dims &padding_r) {
2570  memory::validate_dims(strides);
2571  memory::validate_dims(dilates);
2572  memory::validate_dims(padding_l);
2573  memory::validate_dims(padding_r);
2575  &data, dnnl::convert_to_c(aprop_kind),
2576  convert_to_c(aalgorithm), &src_desc.data,
2577  &weights_desc.data, nullptr,
2578  &dst_desc.data, &strides[0], &dilates[0],
2579  &padding_l[0], &padding_r[0]),
2580  "could not create a dilated deconvolution forward "
2581  "descriptor");
2582  }
2583  };
2584 
2587  primitive_desc() = default;
2588 
2592  const desc &desc, const engine &e, bool allow_empty = false)
2593  : dnnl::primitive_desc(
2594  &desc.data, nullptr, e, nullptr, allow_empty) {}
2595 
2599  const engine &e, bool allow_empty = false)
2600  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
2601  }
2602 
2606  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
2609 
2612 
2615  return query_md(query::weights_md, 0);
2616  }
2617 
2623  return query_md(query::weights_md, 1);
2624  }
2625 
2628  };
2629 
2630  deconvolution_forward() = default;
2631 
2635 };
2636 
2642 
2644  struct desc {
2646 
2653  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
2654  const memory::desc &weights_desc,
2655  const memory::desc &diff_dst_desc, const memory::dims &strides,
2656  const memory::dims &padding_l, const memory::dims &padding_r) {
2657  memory::validate_dims(strides);
2658  memory::validate_dims(padding_l);
2659  memory::validate_dims(padding_r);
2662  convert_to_c(aalgorithm), &diff_src_desc.data,
2663  &weights_desc.data, &diff_dst_desc.data,
2664  &strides[0], &padding_l[0], &padding_r[0]),
2665  "could not create a deconvolution backward data "
2666  "descriptor");
2667  }
2668 
2675  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
2676  const memory::desc &weights_desc,
2677  const memory::desc &diff_dst_desc, const memory::dims &strides,
2678  const memory::dims &dilates, const memory::dims &padding_l,
2679  const memory::dims &padding_r) {
2680  memory::validate_dims(strides);
2681  memory::validate_dims(dilates);
2682  memory::validate_dims(padding_l);
2683  memory::validate_dims(padding_r);
2686  convert_to_c(aalgorithm), &diff_src_desc.data,
2687  &weights_desc.data, &diff_dst_desc.data,
2688  &strides[0], &dilates[0], &padding_l[0],
2689  &padding_r[0]),
2690  "could not create a dilated deconvolution backward data "
2691  "descriptor");
2692  }
2693  };
2694 
2697  primitive_desc() = default;
2698 
2701  primitive_desc(const desc &desc, const engine &e,
2702  const deconvolution_forward::primitive_desc &hint_fwd_pd,
2703  bool allow_empty = false)
2704  : dnnl::primitive_desc(
2705  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
2706 
2710  const engine &e,
2711  const deconvolution_forward::primitive_desc &hint_fwd_pd,
2712  bool allow_empty = false)
2713  : dnnl::primitive_desc(
2714  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
2715 
2719  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
2721 
2724  return query_md(query::diff_src_md, 0);
2725  }
2726 
2729  return query_md(query::weights_md, 0);
2730  }
2731 
2734  return query_md(query::diff_dst_md, 0);
2735  }
2736  };
2737 
2738  deconvolution_backward_data() = default;
2739 
2743 };
2744 
2750 
2752  struct desc {
2754 
2761  desc(algorithm aalgorithm, const memory::desc &src_desc,
2762  const memory::desc &diff_weights_desc,
2763  const memory::desc &diff_bias_desc,
2764  const memory::desc &diff_dst_desc, const memory::dims &strides,
2765  const memory::dims &padding_l, const memory::dims &padding_r) {
2766  memory::validate_dims(strides);
2767  memory::validate_dims(padding_l);
2768  memory::validate_dims(padding_r);
2771  convert_to_c(aalgorithm), &src_desc.data,
2772  &diff_weights_desc.data, &diff_bias_desc.data,
2773  &diff_dst_desc.data, &strides[0], &padding_l[0],
2774  &padding_r[0]),
2775  "could not create a deconvolution backward weights "
2776  "descriptor");
2777  }
2778 
2785  desc(algorithm aalgorithm, const memory::desc &src_desc,
2786  const memory::desc &diff_weights_desc,
2787  const memory::desc &diff_dst_desc, const memory::dims &strides,
2788  const memory::dims &padding_l, const memory::dims &padding_r) {
2789  memory::validate_dims(strides);
2790  memory::validate_dims(padding_l);
2791  memory::validate_dims(padding_r);
2793  &data, convert_to_c(aalgorithm),
2794  &src_desc.data, &diff_weights_desc.data,
2795  nullptr, &diff_dst_desc.data, &strides[0],
2796  &padding_l[0], &padding_r[0]),
2797  "could not create a deconvolution backward weights "
2798  "descriptor");
2799  }
2800 
2807  desc(algorithm aalgorithm, const memory::desc &src_desc,
2808  const memory::desc &diff_weights_desc,
2809  const memory::desc &diff_bias_desc,
2810  const memory::desc &diff_dst_desc, const memory::dims &strides,
2811  const memory::dims &dilates, const memory::dims &padding_l,
2812  const memory::dims &padding_r) {
2813  memory::validate_dims(strides);
2814  memory::validate_dims(dilates);
2815  memory::validate_dims(padding_l);
2816  memory::validate_dims(padding_r);
2819  convert_to_c(aalgorithm), &src_desc.data,
2820  &diff_weights_desc.data, &diff_bias_desc.data,
2821  &diff_dst_desc.data, &strides[0], &dilates[0],
2822  &padding_l[0], &padding_r[0]),
2823  "could not create a dilated deconvolution backward "
2824  "weights descriptor");
2825  }
2826 
2833  desc(algorithm aalgorithm, const memory::desc &src_desc,
2834  const memory::desc &diff_weights_desc,
2835  const memory::desc &diff_dst_desc, const memory::dims &strides,
2836  const memory::dims &dilates, const memory::dims &padding_l,
2837  const memory::dims &padding_r) {
2838  memory::validate_dims(strides);
2839  memory::validate_dims(dilates);
2840  memory::validate_dims(padding_l);
2841  memory::validate_dims(padding_r);
2844  convert_to_c(aalgorithm), &src_desc.data,
2845  &diff_weights_desc.data, nullptr,
2846  &diff_dst_desc.data, &strides[0], &dilates[0],
2847  &padding_l[0], &padding_r[0]),
2848  "could not create a dilated deconvolution backward weights "
2849  "descriptor");
2850  }
2851  };
2852 
2855  primitive_desc() = default;
2856 
2858  primitive_desc(const desc &desc, const engine &e,
2859  const deconvolution_forward::primitive_desc &hint_fwd_pd,
2860  bool allow_empty = false)
2861  : dnnl::primitive_desc(
2862  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
2863 
2867  const engine &e,
2868  const deconvolution_forward::primitive_desc &hint_fwd_pd,
2869  bool allow_empty = false)
2870  : dnnl::primitive_desc(
2871  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
2872 
2876  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
2878 
2881 
2884  return query_md(query::diff_weights_md, 0);
2885  }
2886 
2889  return query_md(query::diff_weights_md, 1);
2890  }
2891 
2894  return query_md(query::diff_dst_md, 0);
2895  }
2896  };
2897 
2898  deconvolution_backward_weights() = default;
2899 
2903 };
2904 
2906 
2914 
2917 struct lrn_forward : public primitive {
2918 
2920  struct desc {
2921  dnnl_lrn_desc_t data;
2922 
2928  desc(prop_kind aprop_kind, algorithm aalgorithm,
2929  const memory::desc &src_desc, memory::dim local_size,
2930  float alpha, float beta, float k = 1.f) {
2932  dnnl::convert_to_c(aprop_kind),
2933  convert_to_c(aalgorithm), &src_desc.data,
2934  local_size, alpha, beta, k),
2935  "could not create a lrn forward descriptor");
2936  }
2937  };
2938 
2942  primitive_desc() = default;
2943 
2945  const desc &desc, const engine &e, bool allow_empty = false)
2947  &desc.data, nullptr, e, nullptr, allow_empty) {}
2948 
2949  primitive_desc(const desc &desc, const primitive_attr &attr,
2950  const engine &e, bool allow_empty = false)
2951  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
2952  }
2953 
2958  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
2961 
2964 
2967 
2972  return query_md(query::workspace_md, 0);
2973  }
2974  };
2975 
2976  lrn_forward() = default;
2977 
2978  lrn_forward(const primitive_desc &pd) : primitive(pd) {}
2979 };
2980 
2983 struct lrn_backward : public primitive {
2984 
2986  struct desc {
2987  dnnl_lrn_desc_t data;
2988 
2993  desc(algorithm aalgorithm, const memory::desc &data_desc,
2994  const memory::desc &diff_data_desc, memory::dim local_size,
2995  float alpha, float beta, float k = 1.f) {
2997  dnnl_lrn_backward_desc_init(&data, convert_to_c(aalgorithm),
2998  &diff_data_desc.data, &data_desc.data, local_size,
2999  alpha, beta, k),
3000  "could not create a lrn backward descriptor");
3001  }
3002  };
3003 
3007  primitive_desc() = default;
3008 
3009  primitive_desc(const desc &desc, const engine &e,
3010  const lrn_forward::primitive_desc &hint_fwd_pd,
3011  bool allow_empty = false)
3013  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3014 
3015  primitive_desc(const desc &desc, const primitive_attr &attr,
3016  const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd,
3017  bool allow_empty = false)
3019  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3020 
3025  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
3027 
3030  return query_md(query::diff_src_md, 0);
3031  }
3032 
3035  return query_md(query::diff_dst_md, 0);
3036  }
3037 
3042  return query_md(query::workspace_md, 0);
3043  }
3044  };
3045 
3046  lrn_backward() = default;
3047 
3048  lrn_backward(const primitive_desc &pd) : primitive(pd) {}
3049 };
3050 
3052 
3059 
3062 struct pooling_forward : public primitive {
3063 
3065  struct desc {
3066  dnnl_pooling_desc_t data;
3067 
3073  desc(prop_kind aprop_kind, algorithm aalgorithm,
3074  const memory::desc &src_desc, const memory::desc &dst_desc,
3075  const memory::dims &strides, const memory::dims &kernel,
3076  const memory::dims &padding_l, const memory::dims &padding_r) {
3077  memory::validate_dims(strides);
3078  memory::validate_dims(kernel);
3079  memory::validate_dims(padding_l);
3080  memory::validate_dims(padding_r);
3082  dnnl::convert_to_c(aprop_kind),
3083  convert_to_c(aalgorithm), &src_desc.data,
3084  &dst_desc.data, &strides[0], &kernel[0],
3085  &padding_l[0], &padding_r[0]),
3086  "could not init a forward pooling descriptor");
3087  }
3088  };
3089 
3092  primitive_desc() = default;
3093 
3095  const desc &desc, const engine &e, bool allow_empty = false)
3097  &desc.data, nullptr, e, nullptr, allow_empty) {}
3098 
3099  primitive_desc(const desc &desc, const primitive_attr &attr,
3100  const engine &e, bool allow_empty = false)
3101  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3102  }
3103 
3107  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
3110 
3113 
3116 
3121  return query_md(query::workspace_md, 0);
3122  }
3123  };
3124 
3125  pooling_forward() = default;
3126 
3127  pooling_forward(const primitive_desc &pd) : primitive(pd) {}
3128 };
3129 
3130 struct pooling_backward : public primitive {
3131 
3133  struct desc {
3134  dnnl_pooling_desc_t data;
3135 
3139  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
3140  const memory::desc &diff_dst_desc, const memory::dims &strides,
3141  const memory::dims &kernel, const memory::dims &padding_l,
3142  const memory::dims &padding_r) {
3143  memory::validate_dims(strides);
3144  memory::validate_dims(kernel);
3145  memory::validate_dims(padding_l);
3146  memory::validate_dims(padding_r);
3149  convert_to_c(aalgorithm), &diff_src_desc.data,
3150  &diff_dst_desc.data, &strides[0], &kernel[0],
3151  &padding_l[0], &padding_r[0]),
3152  "could not init a backward pooling descriptor");
3153  }
3154  };
3155 
3158  primitive_desc() = default;
3159 
3160  primitive_desc(const desc &desc, const engine &e,
3161  const pooling_forward::primitive_desc &hint_fwd_pd,
3162  bool allow_empty = false)
3164  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3165 
3166  primitive_desc(const desc &desc, const primitive_attr &attr,
3167  const engine &e,
3168  const pooling_forward::primitive_desc &hint_fwd_pd,
3169  bool allow_empty = false)
3171  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3172 
3176  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
3178 
3181  return query_md(query::diff_src_md, 0);
3182  }
3183 
3186  return query_md(query::diff_dst_md, 0);
3187  }
3188 
3193  return query_md(query::workspace_md, 0);
3194  }
3195  };
3196 
3197  pooling_backward() = default;
3198 
3199  pooling_backward(const primitive_desc &pd) : primitive(pd) {}
3200 };
3201 
3203 
3221 
3224 struct eltwise_forward : public primitive {
3225 
3230  struct desc {
3231  dnnl_eltwise_desc_t data;
3232  desc(prop_kind aprop_kind, algorithm aalgorithm,
3233  const memory::desc &src_desc, float alpha = 0, float beta = 0) {
3235  dnnl::convert_to_c(aprop_kind),
3236  dnnl::convert_to_c(aalgorithm),
3237  &src_desc.data, alpha, beta),
3238  "could not create a eltwise forward descriptor");
3239  }
3240  };
3241 
3244  primitive_desc() = default;
3245 
3247  const desc &desc, const engine &e, bool allow_empty = false)
3249  &desc.data, nullptr, e, nullptr, allow_empty) {}
3250 
3251  primitive_desc(const desc &desc, const primitive_attr &attr,
3252  const engine &e, bool allow_empty = false)
3253  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3254  }
3255 
3259  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
3262 
3265 
3268  };
3269 
3270  eltwise_forward() = default;
3271 
3272  eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
3273 };
3274 
/// Element-wise backward primitive: derives from `primitive` and bundles an
/// operation descriptor (`desc`) with primitive-descriptor constructors and
/// memory-descriptor queries.
// NOTE(review): numbered listing with absent lines (embedded numbers skip);
// the nested primitive-descriptor struct header and accessor signatures are
// not visible. Confirm against the original header before editing.
3277 struct eltwise_backward : public primitive {
3278 
/// Operation descriptor; stores the C structure dnnl_eltwise_desc_t in `data`.
3282  struct desc {
3283  dnnl_eltwise_desc_t data;
3284 
/// Builds the descriptor from the algorithm, the diff-data and data memory
/// descriptors, and alpha/beta (both default to 0).
3285  desc(algorithm aalgorithm, const memory::desc &diff_data_desc,
3286  const memory::desc &data_desc, float alpha = 0,
3287  float beta = 0) {
// NOTE(review): the lines opening this call (listing 3288-3289) are not
// visible; presumably error::wrap_c_api(<init function>(&data, ... -- confirm.
3290  dnnl::convert_to_c(aalgorithm),
3291  &diff_data_desc.data, &data_desc.data, alpha, beta),
3292  "could not create a eltwise backward descriptor");
3293  }
3294  };
3295 
/// Defaulted default constructor of the primitive descriptor.
3298  primitive_desc() = default;
3299 
/// Takes the forward primitive descriptor @p hint_fwd_pd and forwards
/// hint_fwd_pd.get() as the fourth base argument; attributes are nullptr.
3300  primitive_desc(const desc &desc, const engine &e,
3301  const eltwise_forward::primitive_desc &hint_fwd_pd,
3302  bool allow_empty = false)
// NOTE(review): the line naming the base initializer (listing 3303) is absent.
3304  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3305 
/// Same as above but forwards the caller-supplied attributes @p attr.
3306  primitive_desc(const desc &desc, const primitive_attr &attr,
3307  const engine &e,
3308  const eltwise_forward::primitive_desc &hint_fwd_pd,
3309  bool allow_empty = false)
// NOTE(review): the line naming the base initializer (listing 3310) is absent.
3311  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3312 
// Fragment of a constructor that passes `pd` and kind::eltwise to the base;
// signature and remaining arguments are not visible.
3316  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
3318 
3321 
// Accessor bodies whose signature lines are absent from this listing: the
// first queries query::diff_src_md, the second query::diff_dst_md (index 0).
3324  return query_md(query::diff_src_md, 0);
3325  }
3326 
3329  return query_md(query::diff_dst_md, 0);
3330  }
3331  };
3332 
/// Defaulted default constructor.
3333  eltwise_backward() = default;
3334 
/// Constructs the primitive by forwarding @p pd to the `primitive` base.
3335  eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
3336 };
3337 
3339 
3346 
/// Softmax forward primitive: derives from `primitive` and bundles an
/// operation descriptor (`desc`) with primitive-descriptor constructors.
// NOTE(review): numbered listing with absent lines (embedded numbers skip);
// the nested primitive-descriptor struct header and accessor definitions are
// not visible. Confirm against the original header before editing.
3349 struct softmax_forward : public primitive {
3350 
/// Operation descriptor; stores the C structure dnnl_softmax_desc_t in `data`.
3352  struct desc {
3353  dnnl_softmax_desc_t data;
3354 
/// Builds the descriptor from the propagation kind, the data memory
/// descriptor and the integer @p softmax_axis.
3358  desc(prop_kind aprop_kind, const memory::desc &data_desc,
3359  int softmax_axis) {
// NOTE(review): the line opening this call (listing 3360) is not visible;
// presumably error::wrap_c_api(<init function>(&data, ... -- confirm.
3361  dnnl::convert_to_c(aprop_kind),
3362  &data_desc.data, softmax_axis),
3363  "could not create a softmax forward descriptor");
3364  }
3365  };
3366 
/// Defaulted default constructor of the primitive descriptor.
3369  primitive_desc() = default;
3370 
// Tail of a constructor taking (desc, engine, allow_empty); it passes nullptr
// for the attribute and fourth arguments. Its first lines are not visible.
3372  const desc &desc, const engine &e, bool allow_empty = false)
3374  &desc.data, nullptr, e, nullptr, allow_empty) {}
3375 
/// Same as above but forwards the caller-supplied attributes @p attr.
3376  primitive_desc(const desc &desc, const primitive_attr &attr,
3377  const engine &e, bool allow_empty = false)
3378  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3379  }
3380 
// Fragment of a constructor that passes `pd` and kind::softmax to the base;
// signature and remaining arguments are not visible.
3384  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
3387 
3390 
3393  };
3394 
/// Defaulted default constructor.
3395  softmax_forward() = default;
3396 
/// Constructs the primitive by forwarding @p pd to the `primitive` base.
3397  softmax_forward(const primitive_desc &pd) : primitive(pd) {}
3398 };
3399 
/// Softmax backward primitive: derives from `primitive` and bundles an
/// operation descriptor (`desc`) with primitive-descriptor constructors and
/// memory-descriptor queries.
// NOTE(review): numbered listing with absent lines (embedded numbers skip);
// the nested primitive-descriptor struct header and accessor signatures are
// not visible. Confirm against the original header before editing.
3402 struct softmax_backward : public primitive {
3403 
/// Operation descriptor; stores the C structure dnnl_softmax_desc_t in `data`.
3405  struct desc {
3406  dnnl_softmax_desc_t data;
3407 
/// Initializes `data` through dnnl_softmax_backward_desc_init using the diff
/// and data memory descriptors and the integer @p softmax_axis.
3410  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
3411  int softmax_axis) {
// NOTE(review): the line wrapping this call (listing 3412) is not visible;
// presumably error::wrap_c_api( -- confirm.
3413  dnnl_softmax_backward_desc_init(&data, &diff_desc.data,
3414  &data_desc.data, softmax_axis),
3415  "could not init a backward softmax descriptor");
3416  }
3417  };
3418 
/// Defaulted default constructor of the primitive descriptor.
3421  primitive_desc() = default;
3422 
/// Takes the forward primitive descriptor @p hint_fwd_pd and forwards
/// hint_fwd_pd.get() as the fourth base argument; attributes are nullptr.
3423  primitive_desc(const desc &desc, const engine &e,
3424  const softmax_forward::primitive_desc &hint_fwd_pd,
3425  bool allow_empty = false)
// NOTE(review): the line naming the base initializer (listing 3426) is absent.
3427  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3428 
/// Same as above but forwards the caller-supplied attributes @p attr.
3429  primitive_desc(const desc &desc, const primitive_attr &attr,
3430  const engine &e,
3431  const softmax_forward::primitive_desc &hint_fwd_pd,
3432  bool allow_empty = false)
// NOTE(review): the line naming the base initializer (listing 3433) is absent.
3434  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3435 
// Fragment of a constructor that passes `pd` and kind::softmax to the base;
// signature and remaining arguments are not visible.
3439  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
3441 
3444 
// Accessor bodies whose signature lines are absent from this listing: the
// first queries query::diff_src_md, the second query::diff_dst_md (index 0).
3447  return query_md(query::diff_src_md, 0);
3448  }
3449 
3452  return query_md(query::diff_dst_md, 0);
3453  }
3454  };
3455 
/// Defaulted default constructor.
3456  softmax_backward() = default;
3457 
/// Constructs the primitive by forwarding @p pd to the `primitive` base.
3458  softmax_backward(const primitive_desc &pd) : primitive(pd) {}
3459 };
3460 
3462 
3480 
3484 
/// Operation descriptor for batch normalization forward propagation (per the
/// error message below).
// NOTE(review): the `data` member line (listing 3487/3488) and the lines that
// open the init call (3499-3500) are not visible in this listing -- confirm.
3486  struct desc {
3488 
/// Builds the descriptor from the propagation kind, source memory descriptor,
/// @p epsilon and normalization @p flags (converted with convert_to_c).
3497  desc(prop_kind aprop_kind, const memory::desc &src_desc, float epsilon,
3498  normalization_flags flags) {
3501  dnnl::convert_to_c(aprop_kind), &src_desc.data,
3502  epsilon, convert_to_c(flags)),
3503  "could not create a batch normalization forward "
3504  "descriptor");
3505  }
3506  };
3507 
/// Defaulted default constructor of the primitive descriptor.
3510  primitive_desc() = default;
3511 
3513  const desc &desc, const engine &e, bool allow_empty = false)
3515  &desc.data, nullptr, e, nullptr, allow_empty) {}
3516 
/// Constructs the primitive descriptor by forwarding &desc.data, &attr, the
/// engine @p e, a nullptr fourth argument and @p allow_empty (default false)
/// to dnnl::primitive_desc.
3517  primitive_desc(const desc &desc, const primitive_attr &attr,
3518  const engine &e, bool allow_empty = false)
3519  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3520  }
3521 
3525  : dnnl::primitive_desc(pd,
3526  dnnl::primitive::kind::batch_normalization,
3529 
3532 
3535  return query_md(query::weights_md, 0);
3536  }
3537 
3540 
3545  return query_md(query::workspace_md, 0);
3546  }
3547 
/// Returns the memory descriptor produced by stat_desc(mean).
3549  memory::desc mean_desc() const { return stat_desc(mean); }
3550 
/// Returns the memory descriptor produced by stat_desc(var).
3552  memory::desc variance_desc() const { return stat_desc(var); }
3553 
3554  private:
// Index values passed to stat_desc() by mean_desc() / variance_desc().
3555  enum {
3556  mean = 1,
3557  var = 2,
3558  };
/// Queries the statistics memory descriptor at index @p kind: query::src_md is
/// used when the queried descriptor's flags contain dnnl_use_global_stats,
/// query::dst_md otherwise.
3559  memory::desc stat_desc(int kind) const {
// NOTE(review): listing lines 3560-3562 are not visible (declaration of `p`
// and the opening of the descriptor query call) -- confirm against the header.
3563  dnnl::convert_to_c(query::batch_normalization_d), 0,
3564  &p),
3565  "could not get a batch-normalization descriptor");
3566  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
3567  : query::dst_md,
3568  kind);
3569  }
3570  };
3571 
/// Defaulted default constructor.
3572  batch_normalization_forward() = default;
3573 
/// Constructs the primitive by forwarding @p pd to primitive(pd).
3574  batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
3575 };
3576 
3580 
/// Operation descriptor for batch normalization backward propagation (per the
/// error message below).
// NOTE(review): the `data` member line and the line that opens the init call
// (listing 3596) are not visible in this listing -- confirm.
3582  struct desc {
3584 
/// Builds the descriptor from the propagation kind, the diff-data and data
/// memory descriptors, @p epsilon and normalization @p flags.
3593  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
3594  const memory::desc &data_desc, float epsilon,
3595  normalization_flags flags) {
3597  dnnl::convert_to_c(aprop_kind),
3598  &diff_data_desc.data, &data_desc.data,
3599  epsilon, convert_to_c(flags)),
3600  "could not create a batch normalization backward "
3601  "descriptor");
3602  }
3603  };
3604 
/// Defaulted default constructor of the primitive descriptor.
3607  primitive_desc() = default;
3608 
/// Constructs the primitive descriptor with nullptr attributes, forwarding
/// hint_fwd_pd.get() as the fourth argument.
// NOTE(review): the `hint_fwd_pd` parameter line (listing 3610) and the line
// naming the base initializer (3612) are not visible here -- confirm.
3609  primitive_desc(const desc &desc, const engine &e,
3611  bool allow_empty = false)
3613  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3614 
/// Constructs the primitive descriptor with caller-supplied attributes
/// @p attr, forwarding hint_fwd_pd.get() as the fourth argument.
// NOTE(review): the `hint_fwd_pd` parameter line (listing 3617) and the line
// naming the base initializer (3619) are not visible here -- confirm.
3615  primitive_desc(const desc &desc, const primitive_attr &attr,
3616  const engine &e,
3618  bool allow_empty = false)
3620  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3621 
3625  : dnnl::primitive_desc(pd,
3626  dnnl::primitive::kind::batch_normalization,
3628  }
3629 
3632 
3635 
3638  return query_md(query::src_md, 2);
3639  }
3640 
3643  return query_md(query::weights_md, 0);
3644  }
3645 
3648 
3651  return query_md(query::diff_dst_md, 0);
3652  }
3653 
3658  return query_md(query::workspace_md, 0);
3659  }
3660 
3663  return query_md(query::diff_src_md, 0);
3664  }
3665 
3668  return query_md(query::diff_weights_md, 0);
3669  }
3670  };
3671 
/// Defaulted default constructor.
3672  batch_normalization_backward() = default;
3673 
3675 };
3676 
3678 
3697 
3701 
/// Operation descriptor for layer normalization forward propagation (per the
/// error messages below).
// NOTE(review): the `data` member line and the lines that open each init call
// (listing 3717-3718 and 3727-3728) are not visible here -- confirm.
3703  struct desc {
3705 
/// Builds the descriptor from the propagation kind, source and statistics
/// memory descriptors, @p epsilon and normalization @p flags.
3714  desc(prop_kind aprop_kind, const memory::desc &src_desc,
3715  const memory::desc &stat_desc, float epsilon,
3716  normalization_flags flags) {
3719  dnnl::convert_to_c(aprop_kind), &src_desc.data,
3720  &stat_desc.data, epsilon, convert_to_c(flags)),
3721  "could not create a layer normalization forward "
3722  "descriptor");
3723  }
3724 
/// Overload without a statistics descriptor: passes nullptr in its place.
3725  desc(prop_kind aprop_kind, const memory::desc &src_desc, float epsilon,
3726  normalization_flags flags) {
3729  dnnl::convert_to_c(aprop_kind), &src_desc.data,
3730  nullptr, epsilon, convert_to_c(flags)),
3731  "could not create a layer normalization forward "
3732  "descriptor");
3733  }
3734  };
3735 
/// Defaulted default constructor of the primitive descriptor.
3738  primitive_desc() = default;
3739 
3741  const desc &desc, const engine &e, bool allow_empty = false)
3743  &desc.data, nullptr, e, nullptr, allow_empty) {}
3744 
/// Constructs the primitive descriptor by forwarding &desc.data, &attr, the
/// engine @p e, a nullptr fourth argument and @p allow_empty (default false)
/// to dnnl::primitive_desc.
3745  primitive_desc(const desc &desc, const primitive_attr &attr,
3746  const engine &e, bool allow_empty = false)
3747  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3748  }
3749 
3753  : dnnl::primitive_desc(pd,
3754  dnnl::primitive::kind::layer_normalization,
3757 
3760 
3763  return query_md(query::weights_md, 0);
3764  }
3765 
3768 
/// Returns the memory descriptor produced by stat_desc(mean).
3770  memory::desc mean_desc() const { return stat_desc(mean); }
3771 
/// Returns the memory descriptor produced by stat_desc(var).
3773  memory::desc variance_desc() const { return stat_desc(var); }
3774 
3779  return query_md(query::workspace_md, 0);
3780  }
3781 
3782  private:
// Index values passed to stat_desc() by mean_desc() / variance_desc().
3783  enum {
3784  mean = 1,
3785  var = 2,
3786  };
/// Queries the statistics memory descriptor at index @p kind: query::src_md is
/// used when the queried descriptor's flags contain dnnl_use_global_stats,
/// query::dst_md otherwise.
3787  memory::desc stat_desc(int kind) const {
// NOTE(review): listing lines 3788-3790 are not visible (declaration of `p`
// and the opening of the descriptor query call) -- confirm against the header.
3791  dnnl::convert_to_c(query::layer_normalization_d), 0,
3792  &p),
3793  "could not get a layer-normalization descriptor");
3794  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
3795  : query::dst_md,
3796  kind);
3797  }
3798  };
3799 
/// Defaulted default constructor.
3800  layer_normalization_forward() = default;
3801 
/// Constructs the primitive by forwarding @p pd to primitive(pd).
3802  layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
3803 };
3804 
3808 
/// Operation descriptor for layer normalization backward propagation (per the
/// error messages below).
// NOTE(review): the `data` member line and the lines that open each init call
// (listing 3824-3825 and 3836) are not visible here -- confirm.
3810  struct desc {
3812 
/// Builds the descriptor from the propagation kind, the diff-data, data and
/// statistics memory descriptors, @p epsilon and normalization @p flags.
3821  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
3822  const memory::desc &data_desc, const memory::desc &stat_desc,
3823  float epsilon, normalization_flags flags) {
3826  dnnl::convert_to_c(aprop_kind),
3827  &diff_data_desc.data, &data_desc.data,
3828  &stat_desc.data, epsilon, convert_to_c(flags)),
3829  "could not create a layer normalization backward "
3830  "descriptor");
3831  }
3832 
/// Overload without a statistics descriptor: passes nullptr in its place.
3833  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
3834  const memory::desc &data_desc, float epsilon,
3835  normalization_flags flags) {
3837  dnnl::convert_to_c(aprop_kind),
3838  &diff_data_desc.data, &data_desc.data,
3839  nullptr, epsilon, convert_to_c(flags)),
3840  "could not create a layer normalization backward "
3841  "descriptor");
3842  }
3843  };
3844 
/// Defaulted default constructor of the primitive descriptor.
3847  primitive_desc() = default;
3848 
/// Constructs the primitive descriptor with nullptr attributes, forwarding
/// hint_fwd_pd.get() as the fourth argument.
// NOTE(review): the `hint_fwd_pd` parameter line (listing 3850) and the line
// naming the base initializer (3852) are not visible here -- confirm.
3849  primitive_desc(const desc &desc, const engine &e,
3851  bool allow_empty = false)
3853  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3854 
/// Constructs the primitive descriptor with caller-supplied attributes
/// @p attr, forwarding hint_fwd_pd.get() as the fourth argument.
// NOTE(review): the `hint_fwd_pd` parameter line (listing 3857) and the line
// naming the base initializer (3859) are not visible here -- confirm.
3855  primitive_desc(const desc &desc, const primitive_attr &attr,
3856  const engine &e,
3858  bool allow_empty = false)
3860  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3861 
3865  : dnnl::primitive_desc(pd,
3866  dnnl::primitive::kind::layer_normalization,
3868  }
3869 
3872 
3875 
3878  return query_md(query::src_md, 2);
3879  }
3880 
3883  return query_md(query::weights_md, 0);
3884  }
3885 
3888 
3891  return query_md(query::diff_dst_md, 0);
3892  }
3893 
3896  return query_md(query::diff_src_md, 0);
3897  }
3898 
3901  return query_md(query::diff_weights_md, 0);
3902  }
3903 
3908  return query_md(query::workspace_md, 0);
3909  }
3910  };
3911 
/// Defaulted default constructor.
3912  layer_normalization_backward() = default;
3913 
3915 };
3916 
3918 
3925 
3929 
/// Operation descriptor for inner product forward propagation (per the error
/// messages below).
// NOTE(review): the `data` member line (listing 3940) and the lines that open
// each init call (3944 and 3954-3955) are not visible here -- confirm.
3939  struct desc {
/// Builds the descriptor from the propagation kind and the source, weights,
/// bias and destination memory descriptors.
3941  desc(prop_kind aprop_kind, const memory::desc &src_desc,
3942  const memory::desc &weights_desc, const memory::desc &bias_desc,
3943  const memory::desc &dst_desc) {
3945  dnnl::convert_to_c(aprop_kind),
3946  &src_desc.data, &weights_desc.data,
3947  &bias_desc.data, &dst_desc.data),
3948  "could not create a inner product forward descriptor");
3949  }
3950 
/// Overload without bias: passes nullptr where the bias descriptor would go.
3951  desc(prop_kind aprop_kind, const memory::desc &src_desc,
3952  const memory::desc &weights_desc,
3953  const memory::desc &dst_desc) {
3956  dnnl::convert_to_c(aprop_kind), &src_desc.data,
3957  &weights_desc.data, nullptr, &dst_desc.data),
3958  "could not create a inner product forward descriptor");
3959  }
3960  };
3961 
/// Defaulted default constructor of the primitive descriptor.
3964  primitive_desc() = default;
3965 
3967  const desc &desc, const engine &e, bool allow_empty = false)
3969  &desc.data, nullptr, e, nullptr, allow_empty) {}
3970 
/// Constructs the primitive descriptor by forwarding &desc.data, &attr, the
/// engine @p e, a nullptr fourth argument and @p allow_empty (default false)
/// to dnnl::primitive_desc.
3971  primitive_desc(const desc &desc, const primitive_attr &attr,
3972  const engine &e, bool allow_empty = false)
3973  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3974  }
3975 
3979  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
3982 
3985 
3988  return query_md(query::weights_md, 0);
3989  }
3990 
3996  return query_md(query::weights_md, 1);
3997  }
3998 
4001  };
4002 
/// Defaulted default constructor.
4003  inner_product_forward() = default;
4004 
/// Constructs the primitive by forwarding @p pd to primitive(pd).
4005  inner_product_forward(const primitive_desc &pd) : primitive(pd) {}
4006 };
4007 
4011 
/// Operation descriptor for inner product backward propagation with respect
/// to data (per the error message below).
// NOTE(review): the `data` member line (listing 4018) and the line that opens
// the init call (4022) are not visible here -- confirm.
4017  struct desc {
/// Builds the descriptor from the diff-source, weights and diff-destination
/// memory descriptors.
4019  desc(const memory::desc &diff_src_desc,
4020  const memory::desc &weights_desc,
4021  const memory::desc &diff_dst_desc) {
4023  &diff_src_desc.data, &weights_desc.data,
4024  &diff_dst_desc.data),
4025  "could not create a inner product backward data "
4026  "descriptor");
4027  }
4028  };
4029 
/// Defaulted default constructor of the primitive descriptor.
4033  primitive_desc() = default;
4034 
/// Constructs the primitive descriptor with nullptr attributes, forwarding the
/// forward primitive descriptor hint via hint_fwd_pd.get().
// NOTE(review): the line naming the base initializer (listing 4038) is not
// visible here -- confirm.
4035  primitive_desc(const desc &desc, const engine &e,
4036  const inner_product_forward::primitive_desc &hint_fwd_pd,
4037  bool allow_empty = false)
4039  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
4040 
/// Constructs the primitive descriptor with caller-supplied attributes
/// @p attr, forwarding the forward hint via hint_fwd_pd.get().
// NOTE(review): the line naming the base initializer (listing 4045) is not
// visible here -- confirm.
4041  primitive_desc(const desc &desc, const primitive_attr &attr,
4042  const engine &e,
4043  const inner_product_forward::primitive_desc &hint_fwd_pd,
4044  bool allow_empty = false)
4046  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
4047 
4051  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
4053 
4056  return query_md(query::diff_src_md, 0);
4057  }
4058 
4061  return query_md(query::weights_md, 0);
4062  }
4063 
4066  return query_md(query::diff_dst_md, 0);
4067  }
4068  };
4069 
/// Defaulted default constructor.
4070  inner_product_backward_data() = default;
4071 
4073 };
4074 
4078 
/// Operation descriptor for inner product backward propagation with respect
/// to weights (per the error messages below).
// NOTE(review): the `data` member line (listing 4085) and the lines that open
// each init call (4090-4091 and 4100-4101) are not visible here -- confirm.
4084  struct desc {
/// Builds the descriptor from the source, diff-weights, diff-bias and
/// diff-destination memory descriptors.
4086  desc(const memory::desc &src_desc,
4087  const memory::desc &diff_weights_desc,
4088  const memory::desc &diff_bias_desc,
4089  const memory::desc &diff_dst_desc) {
4092  &src_desc.data, &diff_weights_desc.data,
4093  &diff_bias_desc.data, &diff_dst_desc.data),
4094  "could not create a inner product backward weights "
4095  "descriptor");
4096  }
/// Overload without diff-bias: passes nullptr where that descriptor would go.
4097  desc(const memory::desc &src_desc,
4098  const memory::desc &diff_weights_desc,
4099  const memory::desc &diff_dst_desc) {
4102  &src_desc.data, &diff_weights_desc.data, nullptr,
4103  &diff_dst_desc.data),
4104  "could not create a inner product backward weights "
4105  "descriptor");
4106  }
4107  };
4108 
/// Defaulted default constructor of the primitive descriptor.
4112  primitive_desc() = default;
4113 
/// Constructs the primitive descriptor with nullptr attributes, forwarding the
/// forward primitive descriptor hint via hint_fwd_pd.get().
// NOTE(review): the line naming the base initializer (listing 4117) is not
// visible here -- confirm.
4114  primitive_desc(const desc &desc, const engine &e,
4115  const inner_product_forward::primitive_desc &hint_fwd_pd,
4116  bool allow_empty = false)
4118  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
4119 
/// Constructs the primitive descriptor with caller-supplied attributes
/// @p attr, forwarding the forward hint via hint_fwd_pd.get().
// NOTE(review): the line naming the base initializer (listing 4124) is not
// visible here -- confirm.
4120  primitive_desc(const desc &desc, const primitive_attr &attr,
4121  const engine &e,
4122  const inner_product_forward::primitive_desc &hint_fwd_pd,
4123  bool allow_empty = false)
4125  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
4126 
4130  : dnnl::primitive_desc(cpd, dnnl::primitive::kind::inner_product,
4132 
4135 
4138  return query_md(query::diff_weights_md, 0);
4139  }
4140 
4143  return query_md(query::diff_weights_md, 1);
4144  }
4145 
4148  return query_md(query::diff_dst_md, 0);
4149  }
4150  };
4151 
/// Defaulted default constructor.
4152  inner_product_backward_weights() = default;
4153 
4155 };
4156 
4158 
4165 
/// Common base for the RNN primitive descriptors in this file. It inherits the
/// constructors of `primitive_desc` and adds protected constructors that
/// validate a C primitive descriptor before adopting it.
4166 struct rnn_primitive_desc_base : public primitive_desc {
4167  using primitive_desc::primitive_desc;
4168 
/// Defaulted default constructor.
4169  rnn_primitive_desc_base() = default;
4170 
4171 protected:
4172  // Constructs an RNN primitive descriptor from a C counterpart while
4173  // checking that it actually describes the expected primitive.
// The check accepts either of the two propagation kinds and requires the
// given cell kind; on mismatch it throws error(dnnl_invalid_arguments, ...).
4174  rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
4175  dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
4176  dnnl::algorithm cell_kind) {
// NOTE(review): listing lines 4177 and 4179-4180 are not visible here
// (presumably the declaration of `rnn_d` and the query that sets `rc`) --
// confirm against the original header.
4178  dnnl_status_t rc;
4181  rc, "could not retrieve rnn_desc from a primitive descriptor");
4182 
4183  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
4184  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
4185  dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
4186 
4187  bool ok = rnn_d->primitive_kind == dnnl_rnn
4188  && (rnn_d->prop_kind == c_prop_kind1
4189  || rnn_d->prop_kind == c_prop_kind2)
4190  && rnn_d->cell_kind == c_cell_kind;
4191 
4192  if (!ok) throw error(dnnl_invalid_arguments, "rnn descriptor mismatch");
4193 
// Adopt the descriptor only after validation succeeded.
4194  reset_with_clone(pd);
4195  }
4196 
4197  // Constructs an RNN primitive descriptor from a C counterpart while
4198  // checking that it actually describes the expected primitive.
// Single-kind convenience overload: delegates with prop_kind used for both
// accepted propagation kinds.
4199  rnn_primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::prop_kind prop_kind,
4200  dnnl::algorithm cell_kind)
4201  : rnn_primitive_desc_base(pd, prop_kind, prop_kind, cell_kind) {}
4202 };
4203 
4208 
/// Operation descriptor for the vanilla RNN forward primitive; stores the C
/// structure dnnl_rnn_desc_t in `data`.
4210  struct desc {
4211  dnnl_rnn_desc_t data;
4212 
/// Builds the descriptor from the propagation kind, activation algorithm,
/// direction, the source/weights/bias/destination memory descriptors and the
/// optional @p flags (default rnn_flags::undef), @p alpha and @p beta
/// (default 0.0f).
4231  desc(prop_kind aprop_kind, algorithm activation,
4232  rnn_direction direction, const memory::desc &src_layer_desc,
4233  const memory::desc &src_iter_desc,
4234  const memory::desc &weights_layer_desc,
4235  const memory::desc &weights_iter_desc,
4236  const memory::desc &bias_desc,
4237  const memory::desc &dst_layer_desc,
4238  const memory::desc &dst_iter_desc,
4239  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
4240  float beta = 0.0f) {
// NOTE(review): the lines opening this call (listing 4241-4242) are not
// visible; presumably error::wrap_c_api(<init function>(&data, ... -- confirm.
4243  dnnl::convert_to_c(aprop_kind),
4244  dnnl::convert_to_c(activation),
4245  dnnl::convert_to_c(direction), &src_layer_desc.data,
4246  &src_iter_desc.data, &weights_layer_desc.data,
4247  &weights_iter_desc.data, &bias_desc.data,
4248  &dst_layer_desc.data, &dst_iter_desc.data,
4249  dnnl::convert_to_c(flags), alpha, beta),
4250  "could not create an RNN forward descriptor");
4251  }
4252  };
4253 
/// Primitive descriptor for the vanilla RNN forward primitive, built on
/// rnn_primitive_desc_base.
// NOTE(review): numbered listing with absent lines; the signature lines of the
// accessors below and parts of two constructors are not visible -- confirm
// against the original header before editing.
4255  struct primitive_desc : public rnn_primitive_desc_base {
/// Defaulted default constructor.
4256  primitive_desc() = default;
4257 
// Constructor taking (desc, engine, allow_empty); its first line (listing
// 4258) is not visible. Attributes and the fourth argument are nullptr.
4259  const desc &desc, const engine &e, bool allow_empty = false)
4260  : rnn_primitive_desc_base(
4261  &desc.data, nullptr, e, nullptr, allow_empty) {}
4262 
/// Same as above but forwards the caller-supplied attributes @p attr.
4263  primitive_desc(const desc &desc, const primitive_attr &attr,
4264  const engine &e, bool allow_empty = false)
4265  : rnn_primitive_desc_base(
4266  &desc.data, &attr, e, nullptr, allow_empty) {}
4267 
// Fragment of a constructor from a C descriptor `pd`: it delegates to the
// validating base constructor with forward_training (a second propagation
// kind on listing line 4272 is not visible) and algorithm::vanilla_rnn.
4271  : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
4273  dnnl::algorithm::vanilla_rnn) {}
4274 
// Accessor bodies follow; each returns query_md(<query kind>, <index>).
4277  return query_md(query::src_md, 0);
4278  }
4279 
4285  return query_md(query::src_md, 1);
4286  }
4287 
4290  return query_md(query::weights_md, 0);
4291  }
4292 
4295  return query_md(query::weights_md, 1);
4296  }
4297 
4303  return query_md(query::weights_md, 2);
4304  }
4305 
4308  return query_md(query::dst_md, 0);
4309  }
4310 
4316  return query_md(query::dst_md, 1);
4317  }
4318 
4323  return query_md(query::workspace_md, 0);
4324  }
4325  };
4326 
/// Defaulted default constructor.
4327  vanilla_rnn_forward() = default;
4328 
/// Constructs the primitive by forwarding @p pd to primitive(pd).
4329  vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
4330 };
4331 
4336 
/// Operation descriptor for the vanilla RNN backward primitive; stores the C
/// structure dnnl_rnn_desc_t in `data`.
4338  struct desc {
4339  dnnl_rnn_desc_t data;
4340 
/// Builds the descriptor from the propagation kind, activation algorithm,
/// direction, the seven forward memory descriptors, their seven diff_*
/// counterparts, and the optional @p flags (default rnn_flags::undef),
/// @p alpha and @p beta (default 0.0f).
4358  desc(prop_kind aprop_kind, algorithm activation,
4359  rnn_direction direction, const memory::desc &src_layer_desc,
4360  const memory::desc &src_iter_desc,
4361  const memory::desc &weights_layer_desc,
4362  const memory::desc &weights_iter_desc,
4363  const memory::desc &bias_desc,
4364  const memory::desc &dst_layer_desc,
4365  const memory::desc &dst_iter_desc,
4366  const memory::desc &diff_src_layer_desc,
4367  const memory::desc &diff_src_iter_desc,
4368  const memory::desc &diff_weights_layer_desc,
4369  const memory::desc &diff_weights_iter_desc,
4370  const memory::desc &diff_bias_desc,
4371  const memory::desc &diff_dst_layer_desc,
4372  const memory::desc &diff_dst_iter_desc,
4373  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
4374  float beta = 0.0f) {
// NOTE(review): the lines opening this call (listing 4375-4376) are not
// visible; presumably error::wrap_c_api(<init function>(&data, ... -- confirm.
4377  dnnl::convert_to_c(aprop_kind),
4378  dnnl::convert_to_c(activation),
4379  dnnl::convert_to_c(direction), &src_layer_desc.data,
4380  &src_iter_desc.data, &weights_layer_desc.data,
4381  &weights_iter_desc.data, &bias_desc.data,
4382  &dst_layer_desc.data, &dst_iter_desc.data,
4383  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
4384  &diff_weights_layer_desc.data,
4385  &diff_weights_iter_desc.data, &diff_bias_desc.data,
4386  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
4387  dnnl::convert_to_c(flags), alpha, beta),
4388  "could not create an RNN backward descriptor");
4389  }
4390  };
4391 
/// Primitive descriptor for the vanilla RNN backward primitive, built on
/// rnn_primitive_desc_base.
// NOTE(review): numbered listing with absent lines; the signature lines of the
// accessors below and of the constructor from a C descriptor are not visible
// -- confirm against the original header before editing.
4393  struct primitive_desc : public rnn_primitive_desc_base {
/// Defaulted default constructor.
4394  primitive_desc() = default;
4395 
/// Takes the forward primitive descriptor @p hint_fwd_pd and forwards
/// hint_fwd_pd.get() as the fourth base argument; attributes are nullptr.
4396  primitive_desc(const desc &desc, const engine &e,
4397  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
4398  bool allow_empty = false)
4399  : rnn_primitive_desc_base(
4400  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
4401 
/// Same as above but forwards the caller-supplied attributes @p attr.
4402  primitive_desc(const desc &desc, const primitive_attr &attr,
4403  const engine &e,
4404  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
4405  bool allow_empty = false)
4406  : rnn_primitive_desc_base(
4407  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
4408 
// Fragment of a constructor from a C descriptor `pd`: it delegates to the
// validating base constructor with prop_kind::backward and vanilla_rnn.
4412  : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
4413  dnnl::algorithm::vanilla_rnn) {}
4414 
// Accessor bodies follow; each returns query_md(<query kind>, <index>).
4417  return query_md(query::src_md, 0);
4418  }
4419 
4425  return query_md(query::src_md, 1);
4426  }
4427 
4430  return query_md(query::weights_md, 0);
4431  }
4432 
4435  return query_md(query::weights_md, 1);
4436  }
4437 
4443  return query_md(query::weights_md, 2);
4444  }
4445 
4448  return query_md(query::dst_md, 0);
4449  }
4450 
4456  return query_md(query::dst_md, 1);
4457  }
4458 
4463  return query_md(query::workspace_md, 0);
4464  }
4465 
4468  return query_md(query::diff_src_md, 0);
4469  }
4470 
4476  return query_md(query::diff_src_md, 1);
4477  }
4478 
4481  return query_md(query::diff_weights_md, 0);
4482  }
4483 
4486  return query_md(query::diff_weights_md, 1);
4487  }
4488 
4491  return query_md(query::diff_weights_md, 2);
4492  }
4493 
4496  return query_md(query::diff_dst_md, 0);
4497  }
4498 
4504  return query_md(query::diff_dst_md, 1);
4505  }
4506  };
4507 
/// Defaulted default constructor.
4508  vanilla_rnn_backward() = default;
4509 
/// Constructs the primitive by forwarding @p pd to primitive(pd).
4510  vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
4511 };
4512 
/// LSTM forward primitive: derives from `primitive` and bundles an operation
/// descriptor (`desc`) with a primitive descriptor built on
/// rnn_primitive_desc_base.
// NOTE(review): numbered listing with absent lines (embedded numbers skip);
// accessor signatures and parts of two constructors are not visible --
// confirm against the original header before editing.
4516 struct lstm_forward : public primitive {
4517 
/// Operation descriptor; stores the C structure dnnl_rnn_desc_t in `data`.
4519  struct desc {
4520  dnnl_rnn_desc_t data;
4521 
/// Builds the descriptor from the propagation kind, direction, the source
/// (layer, iter, iter_c), weights (layer, iter), bias and destination (layer,
/// iter, iter_c) memory descriptors, and @p flags (default rnn_flags::undef).
4539  desc(prop_kind aprop_kind, rnn_direction direction,
4540  const memory::desc &src_layer_desc,
4541  const memory::desc &src_iter_desc,
4542  const memory::desc &src_iter_c_desc,
4543  const memory::desc &weights_layer_desc,
4544  const memory::desc &weights_iter_desc,
4545  const memory::desc &bias_desc,
4546  const memory::desc &dst_layer_desc,
4547  const memory::desc &dst_iter_desc,
4548  const memory::desc &dst_iter_c_desc,
4549  rnn_flags flags = rnn_flags::undef) {
// NOTE(review): the lines opening this call (listing 4550-4551) are not
// visible; presumably error::wrap_c_api(<init function>(&data, ... -- confirm.
4552  dnnl::convert_to_c(aprop_kind),
4553  dnnl::convert_to_c(direction), &src_layer_desc.data,
4554  &src_iter_desc.data, &src_iter_c_desc.data,
4555  &weights_layer_desc.data, &weights_iter_desc.data,
4556  &bias_desc.data, &dst_layer_desc.data,
4557  &dst_iter_desc.data, &dst_iter_c_desc.data,
4558  dnnl::convert_to_c(flags)),
4559  "could not create an LSTM forward descriptor");
4560  }
4561  };
4562 
/// Primitive descriptor for the LSTM forward primitive.
4564  struct primitive_desc : public rnn_primitive_desc_base {
/// Defaulted default constructor.
4565  primitive_desc() = default;
4566 
// Constructor taking (desc, engine, allow_empty); its first line (listing
// 4567) is not visible. Attributes and the fourth argument are nullptr.
4568  const desc &desc, const engine &e, bool allow_empty = false)
4569  : rnn_primitive_desc_base(
4570  &desc.data, nullptr, e, nullptr, allow_empty) {}
4571 
/// Same as above but forwards the caller-supplied attributes @p attr.
4572  primitive_desc(const desc &desc, const primitive_attr &attr,
4573  const engine &e, bool allow_empty = false)
4574  : rnn_primitive_desc_base(
4575  &desc.data, &attr, e, nullptr, allow_empty) {}
4576 
// Fragment of a constructor from a C descriptor `pd`; the signature and the
// remaining base arguments (listing 4581-4582) are not visible.
4580  : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
4583 
// Accessor bodies follow; each returns query_md(<query kind>, <index>).
4586  return query_md(query::src_md, 0);
4587  }
4588 
4594  return query_md(query::src_md, 1);
4595  }
4596 
4599  return query_md(query::src_md, 2);
4600  }
4601 
4604  return query_md(query::weights_md, 0);
4605  }
4606 
4609  return query_md(query::weights_md, 1);
4610  }
4611 
4617  return query_md(query::weights_md, 2);
4618  }
4619 
4622  return query_md(query::dst_md, 0);
4623  }
4624 
4630  return query_md(query::dst_md, 1);
4631  }
4632 
4635  return query_md(query::dst_md, 2);
4636  }
4637 
4642  return query_md(query::workspace_md, 0);
4643  }
4644  };
4645 
/// Defaulted default constructor.
4646  lstm_forward() = default;
4647 
/// Constructs the primitive by forwarding @p pd to the `primitive` base.
4648  lstm_forward(const primitive_desc &pd) : primitive(pd) {}
4649 };
4650 
/// LSTM backward primitive: derives from `primitive` and bundles an operation
/// descriptor (`desc`) with a primitive descriptor built on
/// rnn_primitive_desc_base.
// NOTE(review): numbered listing with absent lines (embedded numbers skip);
// accessor signatures and part of one constructor are not visible -- confirm
// against the original header before editing.
4654 struct lstm_backward : public primitive {
4655 
/// Operation descriptor; stores the C structure dnnl_rnn_desc_t in `data`.
4657  struct desc {
4658  dnnl_rnn_desc_t data;
4659 
/// Builds the descriptor from the propagation kind, direction, the nine
/// forward memory descriptors (src layer/iter/iter_c, weights layer/iter,
/// bias, dst layer/iter/iter_c), their nine diff_* counterparts, and
/// @p flags (default rnn_flags::undef).
4678  desc(prop_kind aprop_kind, rnn_direction direction,
4679  const memory::desc &src_layer_desc,
4680  const memory::desc &src_iter_desc,
4681  const memory::desc &src_iter_c_desc,
4682  const memory::desc &weights_layer_desc,
4683  const memory::desc &weights_iter_desc,
4684  const memory::desc &bias_desc,
4685  const memory::desc &dst_layer_desc,
4686  const memory::desc &dst_iter_desc,
4687  const memory::desc &dst_iter_c_desc,
4688  const memory::desc &diff_src_layer_desc,
4689  const memory::desc &diff_src_iter_desc,
4690  const memory::desc &diff_src_iter_c_desc,
4691  const memory::desc &diff_weights_layer_desc,
4692  const memory::desc &diff_weights_iter_desc,
4693  const memory::desc &diff_bias_desc,
4694  const memory::desc &diff_dst_layer_desc,
4695  const memory::desc &diff_dst_iter_desc,
4696  const memory::desc &diff_dst_iter_c_desc,
4697  rnn_flags flags = rnn_flags::undef) {
// NOTE(review): the lines opening this call (listing 4698-4699) are not
// visible; presumably error::wrap_c_api(<init function>(&data, ... -- confirm.
4700  dnnl::convert_to_c(aprop_kind),
4701  dnnl::convert_to_c(direction), &src_layer_desc.data,
4702  &src_iter_desc.data, &src_iter_c_desc.data,
4703  &weights_layer_desc.data, &weights_iter_desc.data,
4704  &bias_desc.data, &dst_layer_desc.data,
4705  &dst_iter_desc.data, &dst_iter_c_desc.data,
4706  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
4707  &diff_src_iter_c_desc.data,
4708  &diff_weights_layer_desc.data,
4709  &diff_weights_iter_desc.data, &diff_bias_desc.data,
4710  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
4711  &diff_dst_iter_c_desc.data,
4712  dnnl::convert_to_c(flags)),
4713  "could not create an LSTM backward descriptor");
4714  }
4715  };
4716 
/// Primitive descriptor for the LSTM backward primitive.
4718  struct primitive_desc : public rnn_primitive_desc_base {
/// Defaulted default constructor.
4719  primitive_desc() = default;
4720 
/// Takes the forward primitive descriptor @p hint_fwd_pd and forwards
/// hint_fwd_pd.get() as the fourth base argument; attributes are nullptr.
4721  primitive_desc(const desc &desc, const engine &e,
4722  const lstm_forward::primitive_desc &hint_fwd_pd,
4723  bool allow_empty = false)
4724  : rnn_primitive_desc_base(
4725  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
4726 
/// Same as above but forwards the caller-supplied attributes @p attr.
4727  primitive_desc(const desc &desc, const primitive_attr &attr,
4728  const engine &e,
4729  const lstm_forward::primitive_desc &hint_fwd_pd,
4730  bool allow_empty = false)
4731  : rnn_primitive_desc_base(
4732  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
4733 
// Fragment of a constructor from a C descriptor `pd`; the signature and the
// remaining base argument (listing 4738) are not visible.
4737  : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
4739 
// Accessor bodies follow; each returns query_md(<query kind>, <index>).
4742  return query_md(query::src_md, 0);
4743  }
4744 
4750  return query_md(query::src_md, 1);
4751  }
4752 
4755  return query_md(query::src_md, 2);
4756  }
4757 
4760  return query_md(query::weights_md, 0);
4761  }
4762 
4765  return query_md(query::weights_md, 1);
4766  }
4767 
4773  return query_md(query::weights_md, 2);
4774  }
4775 
4778  return query_md(query::dst_md, 0);
4779  }
4780 
4786  return query_md(query::dst_md, 1);
4787  }
4788 
4791  return query_md(query::dst_md, 2);
4792  }
4793 
4798  return query_md(query::workspace_md, 0);
4799  }
4800 
4803  return query_md(query::diff_src_md, 0);
4804  }
4805 
4811  return query_md(query::diff_src_md, 1);
4812  }
4813 
4816  return query_md(query::diff_src_md, 2);
4817  }
4818 
4821  return query_md(query::diff_weights_md, 0);
4822  }
4823 
4826  return query_md(query::diff_weights_md, 1);
4827  }
4828 
4831  return query_md(query::diff_weights_md, 2);
4832  }
4833 
4836  return query_md(query::diff_dst_md, 0);
4837  }
4838 
4844  return query_md(query::diff_dst_md, 1);
4845  }
4846 
4849  return query_md(query::diff_dst_md, 2);
4850  }
4851  };
4852 
/// Defaulted default constructor.
4853  lstm_backward() = default;
4854 
4855  // With last iteration (with and without input src_iter)
// Constructs the primitive by forwarding pd to the `primitive` base.
4856  lstm_backward(const primitive_desc &pd) : primitive(pd) {}
4857 };
4858 
/// GRU forward primitive: derives from `primitive` and bundles an operation
/// descriptor (`desc`) with a primitive descriptor built on
/// rnn_primitive_desc_base.
// NOTE(review): numbered listing with absent lines (embedded numbers skip);
// accessor signatures and parts of two constructors are not visible --
// confirm against the original header before editing.
4862 struct gru_forward : public primitive {
4863 
/// Operation descriptor; stores the C structure dnnl_rnn_desc_t in `data`.
4865  struct desc {
4866  dnnl_rnn_desc_t data;
4867 
/// Builds the descriptor from the propagation kind, direction, the source
/// (layer, iter), weights (layer, iter), bias and destination (layer, iter)
/// memory descriptors, and @p flags (default rnn_flags::undef).
4885  desc(prop_kind aprop_kind, rnn_direction direction,
4886  const memory::desc &src_layer_desc,
4887  const memory::desc &src_iter_desc,
4888  const memory::desc &weights_layer_desc,
4889  const memory::desc &weights_iter_desc,
4890  const memory::desc &bias_desc,
4891  const memory::desc &dst_layer_desc,
4892  const memory::desc &dst_iter_desc,
4893  rnn_flags flags = rnn_flags::undef) {
// NOTE(review): the lines opening this call (listing 4894-4895) are not
// visible; presumably error::wrap_c_api(<init function>(&data, ... -- confirm.
4896  dnnl::convert_to_c(aprop_kind),
4897  dnnl::convert_to_c(direction), &src_layer_desc.data,
4898  &src_iter_desc.data, &weights_layer_desc.data,
4899  &weights_iter_desc.data, &bias_desc.data,
4900  &dst_layer_desc.data, &dst_iter_desc.data,
4901  dnnl::convert_to_c(flags)),
4902  "could not create a GRU forward descriptor");
4903  }
4904  };
4905 
/// Primitive descriptor for the GRU forward primitive.
4907  struct primitive_desc : public rnn_primitive_desc_base {
/// Defaulted default constructor.
4908  primitive_desc() = default;
4909 
// Constructor taking (desc, engine, allow_empty); its first line (listing
// 4910) is not visible. Attributes and the fourth argument are nullptr.
4911  const desc &desc, const engine &e, bool allow_empty = false)
4912  : rnn_primitive_desc_base(
4913  &desc.data, nullptr, e, nullptr, allow_empty) {}
4914 
/// Same as above but forwards the caller-supplied attributes @p attr.
4915  primitive_desc(const desc &desc, const primitive_attr &attr,
4916  const engine &e, bool allow_empty = false)
4917  : rnn_primitive_desc_base(
4918  &desc.data, &attr, e, nullptr, allow_empty) {}
4919 
// Fragment of a constructor from a C descriptor `pd`: it delegates to the
// validating base constructor with forward_training (a second propagation
// kind on listing line 4924 is not visible) and algorithm::vanilla_gru.
4923  : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
4925  dnnl::algorithm::vanilla_gru) {}
4926 
// Accessor bodies follow; each returns query_md(<query kind>, <index>).
4929  return query_md(query::src_md, 0);
4930  }
4931 
4937  return query_md(query::src_md, 1);
4938  }
4939 
4942  return query_md(query::weights_md, 0);
4943  }
4944 
4947  return query_md(query::weights_md, 1);
4948  }
4949 
4955  return query_md(query::weights_md, 2);
4956  }
4957 
4960  return query_md(query::dst_md, 0);
4961  }
4962 
4968  return query_md(query::dst_md, 1);
4969  }
4970 
4975  return query_md(query::workspace_md, 0);
4976  }
4977  };
4978 
/// Defaulted default constructor.
4979  gru_forward() = default;
4980 
/// Constructs the primitive by forwarding @p pd to the `primitive` base.
4981  gru_forward(const primitive_desc &pd) : primitive(pd) {}
4982 };
4983 
/// GRU backward primitive: derives from `primitive` and bundles an operation
/// descriptor (`desc`) with a primitive descriptor built on
/// rnn_primitive_desc_base.
// NOTE(review): numbered listing with absent lines (embedded numbers skip);
// accessor signatures and the signature of the constructor from a C
// descriptor are not visible -- confirm against the original header.
4987 struct gru_backward : public primitive {
4988 
/// Operation descriptor; stores the C structure dnnl_rnn_desc_t in `data`.
4990  struct desc {
4991  dnnl_rnn_desc_t data;
4992 
/// Builds the descriptor from the propagation kind, direction, the seven
/// forward memory descriptors (src layer/iter, weights layer/iter, bias, dst
/// layer/iter), their seven diff_* counterparts, and @p flags (default
/// rnn_flags::undef).
5008  desc(prop_kind aprop_kind, rnn_direction direction,
5009  const memory::desc &src_layer_desc,
5010  const memory::desc &src_iter_desc,
5011  const memory::desc &weights_layer_desc,
5012  const memory::desc &weights_iter_desc,
5013  const memory::desc &bias_desc,
5014  const memory::desc &dst_layer_desc,
5015  const memory::desc &dst_iter_desc,
5016  const memory::desc &diff_src_layer_desc,
5017  const memory::desc &diff_src_iter_desc,
5018  const memory::desc &diff_weights_layer_desc,
5019  const memory::desc &diff_weights_iter_desc,
5020  const memory::desc &diff_bias_desc,
5021  const memory::desc &diff_dst_layer_desc,
5022  const memory::desc &diff_dst_iter_desc,
5023  rnn_flags flags = rnn_flags::undef) {
// NOTE(review): the lines opening this call (listing 5024-5025) are not
// visible; presumably error::wrap_c_api(<init function>(&data, ... -- confirm.
5026  dnnl::convert_to_c(aprop_kind),
5027  dnnl::convert_to_c(direction), &src_layer_desc.data,
5028  &src_iter_desc.data, &weights_layer_desc.data,
5029  &weights_iter_desc.data, &bias_desc.data,
5030  &dst_layer_desc.data, &dst_iter_desc.data,
5031  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
5032  &diff_weights_layer_desc.data,
5033  &diff_weights_iter_desc.data, &diff_bias_desc.data,
5034  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
5035  dnnl::convert_to_c(flags)),
5036  "could not create an GRU backward descriptor");
5037  }
5038  };
5039 
/// Primitive descriptor for the GRU backward primitive.
5041  struct primitive_desc : public rnn_primitive_desc_base {
/// Defaulted default constructor.
5042  primitive_desc() = default;
5043 
/// Takes the forward primitive descriptor @p hint_fwd_pd and forwards
/// hint_fwd_pd.get() as the fourth base argument; attributes are nullptr.
5044  primitive_desc(const desc &desc, const engine &e,
5045  const gru_forward::primitive_desc &hint_fwd_pd,
5046  bool allow_empty = false)
5047  : rnn_primitive_desc_base(
5048  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
5049 
/// Same as above but forwards the caller-supplied attributes @p attr.
5050  primitive_desc(const desc &desc, const primitive_attr &attr,
5051  const engine &e, const gru_forward::primitive_desc &hint_fwd_pd,
5052  bool allow_empty = false)
5053  : rnn_primitive_desc_base(
5054  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
5055 
// Fragment of a constructor from a C descriptor `pd`: it delegates to the
// validating base constructor with prop_kind::backward and vanilla_gru.
5059  : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
5060  dnnl::algorithm::vanilla_gru) {}
5061 
// Accessor bodies follow; each returns query_md(<query kind>, <index>).
5064  return query_md(query::src_md, 0);
5065  }
5066 
5072  return query_md(query::src_md, 1);
5073  }
5074 
5077  return query_md(query::weights_md, 0);
5078  }
5079 
5082  return query_md(query::weights_md, 1);
5083  }
5084 
5090  return query_md(query::weights_md, 2);
5091  }
5092 
5095  return query_md(query::dst_md, 0);
5096  }
5097 
5103  return query_md(query::dst_md, 1);
5104  }
5105 
5110  return query_md(query::workspace_md, 0);
5111  }
5112 
5115  return query_md(query::diff_src_md, 0);
5116  }
5117 
5123  return query_md(query::diff_src_md, 1);
5124  }
5125 
5128  return query_md(query::diff_weights_md, 0);
5129  }
5130 
5133  return query_md(query::diff_weights_md, 1);
5134  }
5135 
5138  return query_md(query::diff_weights_md, 2);
5139  }
5140 
5143  return query_md(query::diff_dst_md, 0);
5144  }
5145 
5151  return query_md(query::diff_dst_md, 1);
5152  }
5153  };
5154 
/// Defaulted default constructor.
5155  gru_backward() = default;
5156 
5157  // With last iteration (with and without input src_iter)
// Constructs the primitive by forwarding pd to the `primitive` base.
5158  gru_backward(const primitive_desc &pd) : primitive(pd) {}
5159 };
5160 
/// Linear-before-reset GRU (LBR GRU) primitive for forward propagation.
///
/// NOTE(review): this is a rendered, line-numbered listing. Some original
/// lines are absent from it (5196-5197, 5213, 5223-5225, 5227 and every
/// accessor signature), so the text is not compilable as shown; consult
/// dnnl.hpp for the complete code.
5164 struct lbr_gru_forward : public primitive {
5165 
      /// Operation descriptor; holds the C API RNN descriptor.
5167  struct desc {
      /// C API RNN descriptor (dnnl_rnn_desc_t) filled in by the constructor.
5168  dnnl_rnn_desc_t data;
5169 
      /// Builds the descriptor from a propagation kind, a direction, the
      /// source/weights/bias/destination memory descriptors and optional
      /// RNN flags (default rnn_flags::undef).
5187  desc(prop_kind aprop_kind, rnn_direction direction,
5188  const memory::desc &src_layer_desc,
5189  const memory::desc &src_iter_desc,
5190  const memory::desc &weights_layer_desc,
5191  const memory::desc &weights_iter_desc,
5192  const memory::desc &bias_desc,
5193  const memory::desc &dst_layer_desc,
5194  const memory::desc &dst_iter_desc,
5195  rnn_flags flags = rnn_flags::undef) {
      // Listing lines 5196-5197 (the opening of this call) are missing.
      // Presumably error::wrap_c_api(<C API desc init>(&data, ...), msg),
      // which would throw dnnl::error on a non-success status -- TODO
      // confirm against dnnl.hpp. Enum arguments go through
      // dnnl::convert_to_c; memory descriptors are passed as pointers to
      // their C `data` members.
5198  dnnl::convert_to_c(aprop_kind),
5199  dnnl::convert_to_c(direction), &src_layer_desc.data,
5200  &src_iter_desc.data, &weights_layer_desc.data,
5201  &weights_iter_desc.data, &bias_desc.data,
5202  &dst_layer_desc.data, &dst_iter_desc.data,
5203  dnnl::convert_to_c(flags)),
5204  "could not create a Linear-before-reset GRU forward "
5205  "descriptor");
5206  }
5207  };
5208 
      /// Primitive descriptor for LBR GRU forward propagation; all
      /// constructors delegate to rnn_primitive_desc_base.
5210  struct primitive_desc : public rnn_primitive_desc_base {
5211  primitive_desc() = default;
5212 
      // Constructor from an operation descriptor and an engine (its first
      // line, 5213, is missing). Passes nullptr for both the attributes and
      // the forward hint.
5214  const desc &desc, const engine &e, bool allow_empty = false)
5215  : rnn_primitive_desc_base(
5216  &desc.data, nullptr, e, nullptr, allow_empty) {}
5217 
      // Same as above, but forwards the address of the caller's attr.
5218  primitive_desc(const desc &desc, const primitive_attr &attr,
5219  const engine &e, bool allow_empty = false)
5220  : rnn_primitive_desc_base(
5221  &desc.data, &attr, e, nullptr, allow_empty) {}
5222 
      // Initializer list of a constructor whose signature lines (5223-5225)
      // are missing; it takes `pd`. Listing line 5227, which sits between
      // the two arguments shown, is also missing.
5226  : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
5228  dnnl::algorithm::lbr_gru) {}
5229 
      // The accessor bodies below have lost their signature lines in this
      // listing. Each returns query_md(<query kind>, <index>).

      // Source memory descriptor, index 0.
5232  return query_md(query::src_md, 0);
5233  }
5234 
      // Source memory descriptor, index 1.
5240  return query_md(query::src_md, 1);
5241  }
5242 
      // Weights memory descriptor, index 0.
5245  return query_md(query::weights_md, 0);
5246  }
5247 
      // weights_iter_desc() (listing line 5249): weights iteration memory
      // descriptor.
5250  return query_md(query::weights_md, 1);
5251  }
5252 
      // bias_desc() (listing line 5257): bias memory descriptor, queried as
      // weights index 2.
5258  return query_md(query::weights_md, 2);
5259  }
5260 
      // Destination memory descriptor, index 0.
5263  return query_md(query::dst_md, 0);
5264  }
5265 
      // dst_iter_desc() (listing line 5270): destination iteration memory
      // descriptor.
5271  return query_md(query::dst_md, 1);
5272  }
5273 
      // Workspace memory descriptor.
5278  return query_md(query::workspace_md, 0);
5279  }
5280  };
5281 
5282  lbr_gru_forward() = default;
5283 
      /// Creates the primitive from its primitive descriptor by delegating
      /// to primitive(pd).
5284  lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}
5285 };
5286 
/// LBR_GRU for backward propagation.
///
/// NOTE(review): this is a rendered, line-numbered listing. Some original
/// lines are absent from it (5327-5328, 5360-5362, 5364 and every accessor
/// signature), so the text is not compilable as shown; consult dnnl.hpp
/// for the complete code.
5290 struct lbr_gru_backward : public primitive {
5291 
      /// Operation descriptor; holds the C API RNN descriptor.
5293  struct desc {
      /// C API RNN descriptor (dnnl_rnn_desc_t) filled in by the constructor.
5294  dnnl_rnn_desc_t data;
5295 
      /// Builds the descriptor from a propagation kind, a direction, the
      /// forward memory descriptors (src/weights/bias/dst), the matching
      /// diff_* memory descriptors and optional RNN flags (default
      /// rnn_flags::undef).
5311  desc(prop_kind aprop_kind, rnn_direction direction,
5312  const memory::desc &src_layer_desc,
5313  const memory::desc &src_iter_desc,
5314  const memory::desc &weights_layer_desc,
5315  const memory::desc &weights_iter_desc,
5316  const memory::desc &bias_desc,
5317  const memory::desc &dst_layer_desc,
5318  const memory::desc &dst_iter_desc,
5319  const memory::desc &diff_src_layer_desc,
5320  const memory::desc &diff_src_iter_desc,
5321  const memory::desc &diff_weights_layer_desc,
5322  const memory::desc &diff_weights_iter_desc,
5323  const memory::desc &diff_bias_desc,
5324  const memory::desc &diff_dst_layer_desc,
5325  const memory::desc &diff_dst_iter_desc,
5326  rnn_flags flags = rnn_flags::undef) {
      // Listing lines 5327-5328 (the opening of this call) are missing.
      // Presumably error::wrap_c_api(<C API desc init>(&data, ...), msg),
      // which would throw dnnl::error on a non-success status -- TODO
      // confirm against dnnl.hpp. Enum arguments go through
      // dnnl::convert_to_c; memory descriptors are passed as pointers to
      // their C `data` members.
5329  dnnl::convert_to_c(aprop_kind),
5330  dnnl::convert_to_c(direction), &src_layer_desc.data,
5331  &src_iter_desc.data, &weights_layer_desc.data,
5332  &weights_iter_desc.data, &bias_desc.data,
5333  &dst_layer_desc.data, &dst_iter_desc.data,
5334  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
5335  &diff_weights_layer_desc.data,
5336  &diff_weights_iter_desc.data, &diff_bias_desc.data,
5337  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
5338  dnnl::convert_to_c(flags)),
5339  "could not create an LBR_GRU backward descriptor");
5340  }
5341  };
5342 
      /// Primitive descriptor for LBR GRU backward propagation; all
      /// constructors delegate to rnn_primitive_desc_base.
5344  struct primitive_desc : public rnn_primitive_desc_base {
5345  primitive_desc() = default;
5346 
      // Constructor from an operation descriptor, an engine and a forward
      // primitive descriptor used as a hint (hint_fwd_pd.get() is
      // forwarded). Passes nullptr for the attributes.
5347  primitive_desc(const desc &desc, const engine &e,
5348  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
5349  bool allow_empty = false)
5350  : rnn_primitive_desc_base(
5351  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
5352 
      // Same as above, but forwards the address of the caller's attr.
5353  primitive_desc(const desc &desc, const primitive_attr &attr,
5354  const engine &e,
5355  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
5356  bool allow_empty = false)
5357  : rnn_primitive_desc_base(
5358  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
5359 
      // Start of the initializer list of a constructor whose signature
      // (5360-5362) and base-constructor arguments (5364) are missing from
      // this listing.
5363  : rnn_primitive_desc_base(
5365 
      // The accessor bodies below have lost their signature lines in this
      // listing. Each returns query_md(<query kind>, <index>).

      // Source memory descriptor, index 0.
5368  return query_md(query::src_md, 0);
5369  }
5370 
      // src_iter_desc() (listing line 5375): source iteration memory
      // descriptor.
5376  return query_md(query::src_md, 1);
5377  }
5378 
      // Weights memory descriptor, index 0.
5381  return query_md(query::weights_md, 0);
5382  }
5383 
      // weights_iter_desc() (listing line 5385): weights iteration memory
      // descriptor.
5386  return query_md(query::weights_md, 1);
5387  }
5388 
      // Weights memory descriptor, index 2.
5394  return query_md(query::weights_md, 2);
5395  }
5396 
      // Destination memory descriptor, index 0.
5399  return query_md(query::dst_md, 0);
5400  }
5401 
      // Destination memory descriptor, index 1.
5407  return query_md(query::dst_md, 1);
5408  }
5409 
      // Workspace memory descriptor.
5414  return query_md(query::workspace_md, 0);
5415  }
5416 
      // diff_src_layer_desc() (listing line 5418): diff source layer memory
      // descriptor.
5419  return query_md(query::diff_src_md, 0);
5420  }
5421 
      // Diff source memory descriptor, index 1.
5427  return query_md(query::diff_src_md, 1);
5428  }
5429 
      // diff_weights_layer_desc() (listing line 5431): diff weights layer
      // memory descriptor.
5432  return query_md(query::diff_weights_md, 0);
5433  }
5434 
      // Diff weights memory descriptor, index 1.
5437  return query_md(query::diff_weights_md, 1);
5438  }
5439 
      // Diff weights memory descriptor, index 2.
5442  return query_md(query::diff_weights_md, 2);
5443  }
5444 
      // Diff destination memory descriptor, index 0.
5447  return query_md(query::diff_dst_md, 0);
5448  }
5449 
      // Diff destination memory descriptor, index 1.
5455  return query_md(query::diff_dst_md, 1);
5456  }
5457  };
5458 
5459  lbr_gru_backward() = default;
5460 
      /// Creates the primitive from its primitive descriptor by delegating
      /// to primitive(pd).
5461  lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}
5462 };
5463 
5465 
5472 
/// Shuffle primitive for forward propagation.
///
/// NOTE(review): this is a rendered, line-numbered listing. Some original
/// lines are absent from it (5486, 5494, 5500, 5503-5505, 5507-5508,
/// 5510-5511 and 5513-5514), so the text is not compilable as shown;
/// consult dnnl.hpp for the complete code.
5475 struct shuffle_forward : public primitive {
5476 
      /// Operation descriptor; holds the C API shuffle descriptor.
5478  struct desc {
      /// C API shuffle descriptor filled in by the constructor.
5479  dnnl_shuffle_desc_t data;
5480 
      /// Builds the descriptor from a propagation kind, a data memory
      /// descriptor, an axis and a group size.
5484  desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis,
5485  int group_size) {
      // Listing line 5486 (the opening of this call) is missing. Presumably
      // error::wrap_c_api(<C API desc init>(&data, ...), msg), which would
      // throw dnnl::error on a non-success status -- TODO confirm against
      // dnnl.hpp.
5487  dnnl::convert_to_c(aprop_kind),
5488  &data_desc.data, axis, group_size),
5489  "could not create a shuffle forward descriptor");
5490  }
5491  };
5492 
      // Primitive descriptor for shuffle forward propagation. Its opening
      // line (5494) is missing from this listing; the members below up to
      // the closing brace on 5515 belong to it.
5495  primitive_desc() = default;
5496 
      // Constructor from an operation descriptor and an engine; attributes
      // default to a default-constructed primitive_attr. The base
      // initializer's first line (5500) is missing; no forward hint is
      // passed (nullptr).
5497  primitive_desc(const desc &desc, const engine &e,
5498  const primitive_attr &aattr = primitive_attr(),
5499  bool allow_empty = false)
5501  &desc.data, &aattr, e, nullptr, allow_empty) {}
5502 
      // Initializer list of a constructor taking `pd`; its signature
      // (5503-5505) and remaining arguments (5507-5508) are missing.
5506  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
5509 
      // src_desc() (listing line 5511, missing here): queries the source
      // memory descriptor.
5512 
      // dst_desc() (listing line 5514, missing here): queries the
      // destination memory descriptor.
5515  };
5516 
5517  shuffle_forward() = default;
5518 
      /// Creates the primitive from its primitive descriptor by delegating
      /// to primitive(pd).
5519  shuffle_forward(const primitive_desc &pd) : primitive(pd) {}
5520 };
5521 
/// Shuffle for backward propagation.
///
/// NOTE(review): this is a rendered, line-numbered listing. Some original
/// lines are absent from it (5533, 5547 and 5553-5554), so the text is not
/// compilable as shown; consult dnnl.hpp for the complete code.
5524 struct shuffle_backward : public primitive {
5525 
5526  // Descriptor for shuffle backward propagation.
5527  struct desc {
      /// C API shuffle descriptor filled in by the constructor.
5528  dnnl_shuffle_desc_t data;
5529 
      /// Builds the descriptor from a diff data memory descriptor, an axis
      /// and a group size. Unlike the forward descriptor there is no
      /// propagation-kind parameter.
5532  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
      // Listing line 5533 (the opening of this call) is missing. Presumably
      // error::wrap_c_api(<C API desc init>(&data, ...), msg), which would
      // throw dnnl::error on a non-success status -- TODO confirm against
      // dnnl.hpp.
5534  &diff_data_desc.data, axis, group_size),
5535  "could not create a shuffle backward descriptor");
5536  }
5537  };
5538 
5539  // Primitive descriptor for shuffle backward propagation.
5540  struct primitive_desc : public dnnl::primitive_desc {
5541  primitive_desc() = default;
5542 
      // Constructor from an operation descriptor, an engine and a forward
      // primitive descriptor used as a hint (hint_fwd_pd.get() is
      // forwarded); attributes default to a default-constructed
      // primitive_attr. The base initializer's first line (5547) is
      // missing.
5543  primitive_desc(const desc &desc, const engine &e,
5544  const shuffle_forward::primitive_desc &hint_fwd_pd,
5545  const primitive_attr &aattr = primitive_attr(),
5546  bool allow_empty = false)
5548  &desc.data, &aattr, e, hint_fwd_pd.get(), allow_empty) {}
5549 
      // Constructor from a C primitive descriptor; its initializer list
      // (5553-5554) is missing from this listing.
5552  primitive_desc(dnnl_primitive_desc_t pd)
5555 
      /// Queries the diff source memory descriptor (index 0).
5557  memory::desc diff_src_desc() const {
5558  return query_md(query::diff_src_md, 0);
5559  }
5560 
      /// Queries the diff destination memory descriptor (index 0).
5562  memory::desc diff_dst_desc() const {
5563  return query_md(query::diff_dst_md, 0);
5564  }
5565  };
5566 
5567  shuffle_backward() = default;
5568 
      /// Creates the primitive from its primitive descriptor by delegating
      /// to primitive(pd).
5569  shuffle_backward(const primitive_desc &pd) : primitive(pd) {}
5570 };
5571 
5573 
5580 
/// Binary primitive: takes two source memory descriptors and one
/// destination memory descriptor plus an algorithm kind.
///
/// NOTE(review): this is a rendered, line-numbered listing. Some original
/// lines are absent from it (5593, 5604 and 5615-5617), so the text is not
/// compilable as shown; consult dnnl.hpp for the complete code.
5583 struct binary : public primitive {
5584 
      /// Operation descriptor; holds the C API binary descriptor.
5586  struct desc {
      /// C API binary descriptor filled in by the constructor.
5587  dnnl_binary_desc_t data;
5588 
      /// Initializes the descriptor from aalgorithm and the memory
      /// descriptors src0, src1 and dst.
5591  desc(algorithm aalgorithm, const memory::desc &src0,
5592  const memory::desc &src1, const memory::desc &dst) {
      // Listing line 5593 (the opening of the enclosing call) is missing.
      // Presumably error::wrap_c_api(..., msg), which would throw
      // dnnl::error when dnnl_binary_desc_init does not return
      // dnnl_success -- TODO confirm against dnnl.hpp.
5594  dnnl_binary_desc_init(&data, dnnl::convert_to_c(aalgorithm),
5595  &src0.data, &src1.data, &dst.data),
5596  "could not create a binary descriptor");
5597  }
5598  };
5599 
      /// Primitive descriptor for the binary primitive; all constructors
      /// delegate to dnnl::primitive_desc.
5600  struct primitive_desc : public dnnl::primitive_desc {
5601  primitive_desc() = default;
5602 
      // Constructor from an operation descriptor and an engine (its first
      // line, 5604, is missing). Passes nullptr for both the attributes and
      // the forward hint.
5605  const desc &desc, const engine &e, bool allow_empty = false)
5606  : dnnl::primitive_desc(
5607  &desc.data, nullptr, e, nullptr, allow_empty) {}
5608 
      // Constructor with attributes. Note that this overload has no
      // allow_empty parameter and forwards only four arguments to the base.
5611  primitive_desc(
5612  const desc &desc, const primitive_attr &attr, const engine &e)
5613  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr) {}
5614 
      // Initializer list of a constructor taking `pd`; its signature
      // (5615-5617) is missing. The base receives primitive kind `binary`.
5618  : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {}
5619 
      /// Queries the first source memory descriptor (src_md index 0).
5621  memory::desc src0_desc() const { return query_md(query::src_md, 0); }
5622 
      /// Queries the second source memory descriptor (src_md index 1).
5624  memory::desc src1_desc() const { return query_md(query::src_md, 1); }
5625 
      /// Queries the destination memory descriptor (dst_md index 0).
5627  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
5628  };
5629 
5630  binary() = default;
5631 
      /// Creates the primitive from its primitive descriptor by delegating
      /// to primitive(pd).
5632  binary(const primitive_desc &pd) : primitive(pd) {}
5633 };
5634 
5636 
5638 
5640 
5641 // implementation section
5642 
/// Constructs a primitive from a C API primitive descriptor c_pd and stores
/// the resulting handle through reset().
///
/// NOTE(review): listing line 5646 (the opening of the call that fills
/// `result` and carries the error message below) is missing from this
/// rendering. Presumably an error::wrap_c_api(...) call that throws
/// dnnl::error on failure -- TODO confirm against dnnl.hpp.
5644 inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
5645  dnnl_primitive_t result;
5647  "could not create a primitive");
      // Hand the newly created C handle to this object.
5648  reset(result);
5649 }
5650 
/// Constructs a primitive from a C++ primitive descriptor by delegating to
/// the constructor that takes the descriptor's C handle (pd.get()).
5651 inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
5652 
/// Executes this primitive on the stream astream.
///
/// @param astream Stream whose C handle is passed to dnnl_primitive_execute.
/// @param args Map from argument index to memory object; each entry becomes
///     one dnnl_exec_arg_t.
/// @throws dnnl::error (via error::wrap_c_api) when dnnl_primitive_execute
///     returns a status other than dnnl_success.
5653 inline void primitive::execute(
5654  stream &astream, const std::unordered_map<int, memory> &args) const {
      // Flatten the map into the contiguous array the C API expects. The
      // element order follows the unordered_map iteration order, which is
      // unspecified.
5655  std::vector<dnnl_exec_arg_t> c_args;
5656  c_args.reserve(args.size());
5657  for (const auto &a : args)
      // {argument index, C memory handle}
5658  c_args.push_back({a.first, a.second.get()});
5659 
      // The C API takes the argument count as int, hence the narrowing cast.
5660  error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
5661  (int)c_args.size(), c_args.data()),
5662  "could not execute a primitive");
5663 }
5665 
5666 } // namespace dnnl
5667 
5668 #endif
Layer normalization for forward propagation.
Definition: dnnl.hpp:3700
cl_device_id get_ocl_device() const
Returns the OpenCL device associated with the engine.
Definition: dnnl.hpp:902
deconvolution_forward(const primitive_desc &pd)
Creates a deconvolution forward propagation primitive from the corresponding primitive descriptor...
Definition: dnnl.hpp:2634
4D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcd8b ...
Definition: dnnl_types.h:454
1D tensor, an alias to dnnl_a
Definition: dnnl_types.h:338
void * get_data_handle() const
Returns a handle of the data contained in the memory.
Definition: dnnl.hpp:1491
flags
Stream flags.
Definition: dnnl.hpp:951
memory::desc diff_weights_desc() const
Queries diff weights memory descriptor.
Definition: dnnl.hpp:2883
Primitive descriptor for shuffle forward propagation.
Definition: dnnl.hpp:5494
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated convolution forward propagation without bias using aprop_kind (possi...
Definition: dnnl.hpp:2125
memory(const desc &md, const engine &aengine, void *ahandle)
Constructs a memory.
Definition: dnnl.hpp:1457
desc reshape(const dims &adims)
Constructs a memory descriptor by reshaping existing one.
Definition: dnnl.hpp:1428
2D RNN statistics tensor, an alias to dnnl_ba
Definition: dnnl_types.h:346
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4307
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are dnnl_forward_...
memory::desc src_iter_c_desc() const
Queries source recurrent cell state memory descriptor.
Definition: dnnl.hpp:4754
dnnl_data_type_t
Data type specification.
Definition: dnnl_types.h:68
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3446
32-bit signed integer.
Definition: dnnl_types.h:78
Deconvolution forward propagation.
Definition: dnnl.hpp:2475
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:5249
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:2963
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr)
Deletes an attr.
memory::desc diff_dst_iter_desc() const
Queries diff destination recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4843
Primitive descriptor for pooling forward propagation.
Definition: dnnl.hpp:3091
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for convolution backward propagation from a C primitive descriptor...
Definition: dnnl.hpp:2276
RNN descriptor for backward propagation.
Definition: dnnl.hpp:4338
Max pooling.
Definition: dnnl_types.h:695
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:5514
3D CNN activations tensor, an alias to dnnl_abc
Definition: dnnl_types.h:348
5D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcde8b ...
Definition: dnnl_types.h:445
A descriptor for an RNN operation.
Definition: dnnl_types.h:1235
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4764
The operation failed because requested functionality is not implemented.
Definition: dnnl_types.h:58
Non-standard 16-bit floating point (bfloat16 with a 7-bit mantissa).
Definition: dnnl_types.h:74
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Initializes a descriptor for backward propagation using aalgorithm, memory descriptors data_desc and ...
Definition: dnnl.hpp:2993
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a pooling descriptor for forward propagation using aprop_kind (possible values are dnnl::...
Definition: dnnl.hpp:3073
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for batch normalization backward propagation from a C primitive de...
Definition: dnnl.hpp:3624
Primitive descriptor for convolution backward propagation.
Definition: dnnl.hpp:2254
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:2966
Primitive descriptor for pooling backward propagation.
Definition: dnnl.hpp:3157
Forward data propagation, alias for dnnl::prop_kind::forward_training.
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Initializes an RNN descriptor for forward propagation using prop_kind, activation, direction, and memory descriptors.
Definition: dnnl.hpp:4231
memory::desc dst_iter_desc() const
Queries destination recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4629
2D CNN activations tensor, an alias to dnnl_ab
Definition: dnnl_types.h:340
3D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:223
5D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcde4b ...
Definition: dnnl_types.h:442
memory::desc diff_weights_desc() const
Queries diff weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3900
Average pooling exclude padding.
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:4065
plain 1D tensor
Definition: dnnl_types.h:183
A reorder primitive.
Definition: dnnl_types.h:619
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
void set_scratchpad_mode(scratchpad_mode mode)
Sets scratchpad mode.
Definition: dnnl.hpp:702
pooling descriptor
Definition: dnnl_types.h:1591
Eltwise: soft_relu.
Definition: dnnl_types.h:682
An opaque structure to describe a primitive.
deconvolution_backward_weights(const primitive_desc &pd)
Creates a deconvolution weight update primitive from the corresponding primitive descriptor.
Definition: dnnl.hpp:2902
Definition: dnnl.hpp:40
Descriptor for convolution forward propagation.
Definition: dnnl.hpp:2038
permuted 5D tensor
Definition: dnnl_types.h:205
An inner product primitive.
Definition: dnnl_types.h:643
memory::desc diff_weights_desc() const
Queries diff weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3667
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for deconvolution backward propagation using aalgorithm, memory descriptors...
Definition: dnnl.hpp:2653
memory::desc src_iter_desc() const
Queries source recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4749
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a layer normalization descriptor lnrm_desc for backward propagation with respect to data ...
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:4424
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:2182
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Appends eltwise post operation.
Definition: dnnl.hpp:648
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:5385
desc(const dims &adims, data_type adata_type, const dims &astrides)
Constructs a memory descriptor by strides.
Definition: dnnl.hpp:1400
permuted 3D tensor
Definition: dnnl_types.h:193
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3267
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:5375
Primitive descriptor for convolution weight update.
Definition: dnnl.hpp:2412
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: dnnl.hpp:840
Forward data propagation (training mode).
Definition: dnnl_types.h:594
dnnl_query_t
Primitive descriptor query specification.
Definition: dnnl_types.h:1558
Winograd convolution.
Definition: dnnl_types.h:658
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Initializes a GRU descriptor for backward propagation using prop_kind, direction, and memory descriptors.
Definition: dnnl.hpp:5008
Primitive iterator passed over last primitive descriptor.
Definition: dnnl_types.h:60
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:5109
4D CNN weights tensor, an alias to dnnl_abcd
Definition: dnnl_types.h:375
Eltwise: exponent.
Definition: dnnl_types.h:686
5D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:276
Shuffle for backward propagation.
Definition: dnnl.hpp:5524
Eltwise: abs.
Definition: dnnl_types.h:674
dnnl_status_t DNNL_API dnnl_engine_create_ocl(dnnl_engine_t *engine, dnnl_engine_kind_t kind, cl_device_id device, cl_context context)
Creates an engine of particular kind associated with a given OpenCL device and context objects...
An opaque structure for primitive descriptor attributes.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Initializes a primitive descriptor for deconvolution weight update with attributes defined by attr...
Definition: dnnl.hpp:2866
cl_context get_ocl_context() const
Returns the OpenCL context associated with the engine.
Definition: dnnl.hpp:894
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:5270
3D CNN weights tensor, an alias to dnnl_cba
Definition: dnnl_types.h:371
Deconvolution weight update.
Definition: dnnl.hpp:2749
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:4147
memory::desc diff_src_layer_desc() const
Queries diff source layer memory descriptor.
Definition: dnnl.hpp:4802
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:5511
Packed weights format used in RNN.
void get_output_scales(int &mask, std::vector< float > &scales) const
Gets the corresponding scale mask and the constant floating-point vector of output scales previously set b...
Definition: dnnl.hpp:710
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for convolution backward propagation using aalgorithm, memory descriptors...
Definition: dnnl.hpp:2213
memory::desc src_iter_desc() const
Queries source recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4593
int64_t dnnl_dim_t
A type to describe tensor dimension.
Definition: dnnl_types.h:778
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:2614
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for convolution forward propagation without bias using aprop_kind (possible valu...
Definition: dnnl.hpp:2074
5D CNN activations tensor, an alias to dnnl_acdeb
Definition: dnnl_types.h:360
softmax descriptor
2D CNN activations tensor, an alias to dnnl_ba
Definition: dnnl_types.h:342
A class that provides the destructor for a DNNL C handle.
Definition: dnnl.hpp:78
Local response normalization for forward propagation.
Definition: dnnl.hpp:2917
A sum primitive.
Definition: dnnl_types.h:625
An opaque structure to describe a memory.
Initializes an inner product descriptor for backward propagation with respect to data using memory de...
Definition: dnnl.hpp:4017
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for convolution forward propagation with bias using aprop_kind (possible ...
Definition: dnnl.hpp:2049
source memory desc
Definition: dnnl_types.h:1602
memory consumption (bytes)
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3890
dnnl_memory_desc_t data
The underlying C API data structure.
Definition: dnnl.hpp:1375
Undefined primitive.
Definition: dnnl_types.h:617
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args)
Executes a primitive using a stream, and nargs arguments args.
LBR_GRU for backward propagation.
Definition: dnnl.hpp:5290
Descriptor for layer normalization forward propagation.
Definition: dnnl.hpp:3703
softmax descriptor
Definition: dnnl_types.h:1590
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3328
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4289
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:2171
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Initializes a primitive descriptor for deconvolution backward propagation with attributes defined by ...
Definition: dnnl.hpp:2709
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for softmax forward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:3383
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:4967
An LRN primitive.
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: dnnl_types.h:95
desc(algorithm aalgorithm, const memory::desc &src0, const memory::desc &src1, const memory::desc &dst)
Initializes a binary descriptor using aalgorithm and the memory descriptors src0, src1 and dst...
Definition: dnnl.hpp:5591
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
permuted 5D tensor
Definition: dnnl_types.h:206
Convolution algorithm (either direct or Winograd) is chosen just in time.
Bidirectional execution of RNN primitive with summation of the results.
Definition: dnnl_types.h:1230
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4322
primitive::kind kind(int index) const
Returns the kind of post operation with index index.
Definition: dnnl.hpp:601
dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream)
Destroys an execution stream.
permuted 5D tensor
Definition: dnnl_types.h:192
A pooling primitive.
Definition: dnnl_types.h:635
A user shall query and provide the scratchpad memory to primitives.
Forward data propagation (training mode).
3D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBc16b ...
Definition: dnnl_types.h:457
Initializes an inner product descriptor for backward propagation with respect to weights using memory...
Definition: dnnl.hpp:4084
number of inputs expected
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode)
Returns the scratchpad mode set in the attribute attr.
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3115
implementation name
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4928
weights memory desc
Definition: dnnl_types.h:1604
memory::desc diff_weights_iter_desc() const
Queries diff weights iteration memory descriptor.
Definition: dnnl.hpp:4825
4D CNN activations tensor, an alias to dnnl::memory::format_tag::abcd
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:5257
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:3987
4D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcd4b ...
Definition: dnnl_types.h:451
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
plain 4D tensor
Definition: dnnl_types.h:186
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(dnnl_memory_desc_t *memory_desc, const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims, const dnnl_dims_t offsets)
Initializes a memory_desc for a given parent_memory_desc, with dims sizes and offsets.
The operation was successful.
Definition: dnnl_types.h:52
Descriptor for batch normalization backward propagation.
Definition: dnnl.hpp:3582
plain 6D tensor
Definition: dnnl_types.h:188
memory::desc diff_dst_iter_desc() const
Queries diff destination iteration memory descriptor.
Definition: dnnl.hpp:4503
memory::desc diff_weights_layer_desc() const
Queries diff weights layer memory descriptor.
Definition: dnnl.hpp:4820
A descriptor of a pooling operation.
Definition: dnnl_types.h:1069
A (out-of-place) concat primitive.
Definition: dnnl_types.h:623
Eltwise: parametric exponential linear unit (elu)
Definition: dnnl_types.h:670
kind get_kind() const
Returns the kind of the engine.
Definition: dnnl.hpp:885
A memory descriptor.
Definition: dnnl.hpp:1372
destination grad. memory desc
Definition: dnnl_types.h:1607
C API.
source gradient memory desc
Definition: dnnl_types.h:1603
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:2187
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:5094
desc(const dims &adims, data_type adata_type, format_tag aformat_tag)
Constructs a memory descriptor.
Definition: dnnl.hpp:1385
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Initializes a GRU descriptor for forward propagation using prop_kind, direction, and memory descripto...
Definition: dnnl.hpp:4885
8-bit signed integer.
Definition: dnnl_types.h:80
memory::desc diff_dst_iter_desc() const
Queries diff destination iteration memory descriptor.
Definition: dnnl.hpp:5150
16-bit/half-precision floating point.
Backward bias propagation.
Definition: dnnl_types.h:610
4D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:245
desc(prop_kind aprop_kind, const memory::desc &src_desc, float epsilon, normalization_flags flags)
Initializes a batch normalization descriptor for forward propagation using prop_kind (possible values...
Definition: dnnl.hpp:3497
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4621
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a GRU descriptor rnn_desc for backward propagation using prop_kind, direction, and memory descriptors.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(dnnl_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
Bidirectional execution of RNN primitive with concatenation of the results.
Definition: dnnl_types.h:1227
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:2880
const dnnl_memory_desc_t DNNL_API * dnnl_primitive_desc_query_md(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index)
Queries primitive descriptor for memory descriptor.
memory::desc diff_weights_layer_desc() const
Queries diff weights layer memory descriptor.
Definition: dnnl.hpp:5431
permuted 2D tensor
Out-of-order execution.
Definition: dnnl_types.h:1625
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
Primitive descriptor for local response normalization forward propagation.
Definition: dnnl.hpp:2941
Descriptor for LSTM forward propagation.
Definition: dnnl.hpp:4519
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:5102
An opaque structure to describe a primitive descriptor iterator.
A tensor in a generic format described by the stride and blocking values in each dimension.
workspace memory desc
Definition: dnnl_types.h:1608
memory::desc diff_src_layer_desc() const
Queries diff source layer memory descriptor.
Definition: dnnl.hpp:5418
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated convolution forward propagation with bias using prop_kind (po...
Definition: dnnl.hpp:2098
memory(const desc &md, const engine &aengine)
Constructs a memory.
Definition: dnnl.hpp:1469
An execution engine.
Definition: dnnl.hpp:821
Backward data propagation.
Definition: dnnl_types.h:606
no query
Definition: dnnl_types.h:1559
Average pooling include padding.
GRU for forward propagation.
Definition: dnnl.hpp:4862
Primitive descriptor for layer normalization forward propagation.
Definition: dnnl.hpp:3737
convolution_backward_data(const primitive_desc &pd)
Creates a convolution backward propagation primitive from the corresponding primitive descriptor...
Definition: dnnl.hpp:2300
const char * impl_info_str() const
Returns implementation name.
Definition: dnnl.hpp:1608
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:5380
void unmap_data(void *mapped_ptr) const
Unmaps the previously mapped data for the memory.
Definition: dnnl.hpp:1534
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4954
scratchpad memory desc
Definition: dnnl_types.h:1609
Descriptor for local response normalization backward propagation.
Definition: dnnl.hpp:2986
Primitive descriptor for inner product backward propagation with respect to weights.
Definition: dnnl.hpp:4111
4D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:237
A softmax primitive.
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4974
deconvolution descriptor
Definition: dnnl_types.h:1587
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are dnnl_forward_tra...
GRU cell with linear before reset.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Backward propagation (with respect to all parameters).
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:5244
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Initializes an LSTM descriptor for forward propagation using prop_kind, direction, and memory descriptors.
Definition: dnnl.hpp:4539
Descriptor for local response normalization forward propagation.
Definition: dnnl.hpp:2920
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:4134
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for convolution forward propagation from a C primitive descriptor ...
Definition: dnnl.hpp:2165
LSTM for backward propagation.
Definition: dnnl.hpp:4654
engine(kind akind, size_t index)
Constructs an engine.
Definition: dnnl.hpp:850
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
memory::dim query_s64(query q) const
Queries the memory::dim value (same as int64_t).
Definition: dnnl.hpp:1617
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc)
Creates a primitive using a primitive_desc descriptor.
int DNNL_API dnnl_memory_desc_equal(const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs)
Compares two memory descriptors.
memory::desc diff_weights_desc() const
Queries diff weights memory descriptor.
Definition: dnnl.hpp:2441
memory::desc mean_desc() const
Queries mean memory descriptor.
Definition: dnnl.hpp:3770
16-bit/half-precision floating point.
Definition: dnnl_types.h:72
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(const_dnnl_post_ops_t post_ops, int index)
Returns the kind of post operation with index index in given post_ops.
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Gets the eltwise parameters of the post operation with index index.
Definition: dnnl.hpp:655
op descriptor
Definition: dnnl_types.h:1585
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(const_dnnl_primitive_desc_t primitive_desc, const_dnnl_primitive_attr_t *attr)
Returns a constant reference to the attribute of a primitive_desc.
Undefined memory format tag.
Definition: dnnl_types.h:172
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:5398
propagation kind
Definition: dnnl_types.h:1581
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a layer normalization descriptor lnrm_desc for forward propagation using prop_kind (possi...
memory::desc diff_src_desc() const
Queries diff source gradient memory descriptor.
Definition: dnnl.hpp:2281
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:4490
memory::desc weights_desc() const
Queries weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3882
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
dnnl_normalization_flags_t
Flags for batch normalization primitive.
Definition: dnnl_types.h:726
A reorder primitive.
Inner product for forward propagation.
Definition: dnnl.hpp:3928
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3185
A descriptor of a binary operation.
Definition: dnnl_types.h:1302
Eltwise: bounded_relu.
Definition: dnnl_types.h:680
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:4284
engine get_engine() const
Returns the engine of the primitive descriptor.
Definition: dnnl.hpp:1605
inner product descriptor
Definition: dnnl_types.h:1595
4D CNN weights tensor, an alias to dnnl_bcda
Definition: dnnl_types.h:381
Element-wise operations for backward propagation.
Definition: dnnl.hpp:3277
2D CNN activations tensor, an alias to dnnl::memory::format_tag::ab
deconvolution_backward_data(const primitive_desc &pd)
Creates a deconvolution backward propagation primitive from the corresponding primitive descriptor...
Definition: dnnl.hpp:2742
primitive kind
Definition: dnnl_types.h:1562
void set_rnn_data_qparams(float scale, float shift)
Sets quantization scale and shift for RNN data tensors.
Definition: dnnl.hpp:769
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4741
An element-wise primitive.
Packed weights format used in RNN.
Definition: dnnl_types.h:99
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:4142
32-bit signed integer.
Initializes an eltwise descriptor for forward propagation using prop_kind (possible values are dnnl::...
Definition: dnnl.hpp:3230
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:5413
T get(bool allow_emtpy=false) const
Returns the value of the underlying C handle.
Definition: dnnl.hpp:133
Inner product for backward propagation with respect to weights.
Definition: dnnl.hpp:4077
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, bool allow_empty=false)
Initializes primitive descriptor for deconvolution forward propagation with attributes defined by att...
Definition: dnnl.hpp:2598
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:2888
Use scale and shift parameters.
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:5393
A descriptor of a convolution operation.
Definition: dnnl_types.h:957
3D CNN activations tensor, an alias to dnnl_acb
Definition: dnnl_types.h:350
Default order execution.
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory)
Deletes a memory.
dnnl_status_t DNNL_API dnnl_engine_get_ocl_device(dnnl_engine_t engine, cl_device_id *device)
Returns an OpenCL device associated with an engine.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle(dnnl_memory_t memory, void *handle)
For a memory, sets the data handle.
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for local response normalization forward propagation from a C prim...
Definition: dnnl.hpp:2957
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:5071
dnnl_primitive_kind_t
Kinds of primitives.
Definition: dnnl_types.h:615
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for GRU forward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:4922
permuted 5D tensor
Definition: dnnl_types.h:194
Default stream configuration.
Definition: dnnl_types.h:1627
format_kind
Memory format kind.
Definition: dnnl.hpp:1060
desc(const dnnl_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: dnnl.hpp:1413
Base class for all computational primitives.
Definition: dnnl.hpp:181
Backward weights propagation.
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3984
data_type
Data type specification.
Definition: dnnl.hpp:1042
rnn descriptor
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, const dnnl_dims_t strides)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and strides.
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4641
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4759
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Initializes a softmax descriptor for backward propagation using memory descriptors diff_desc and data...
Definition: dnnl.hpp:3410
Binary add.
Definition: dnnl_types.h:720
Creates an out-of-place sum primitive descriptor for sum of n inputs multiplied by the scale with res...
Definition: dnnl.hpp:1898
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for deconvolution forward propagation from a C primitive descripto...
Definition: dnnl.hpp:2605
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:5076
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:5406
batch normalization descriptor
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for shuffle forward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:5505
5D CNN weights tensor, an alias to dnnl_abcde
Definition: dnnl_types.h:385
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:4936
Undefined primitive.
Descriptor for LBR GRU forward propagation.
Definition: dnnl.hpp:5167
A pooling primitive.
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated deconvolution backward propagation using aalgorithm, memory descriptors, strides, dilates, padding_l, and padding_r.
Definition: dnnl.hpp:2675
memory::desc diff_src_desc() const
Queries diff source gradient memory descriptor.
Definition: dnnl.hpp:4055
Batch normalization for forward propagation.
Definition: dnnl.hpp:3483
Convolution forward propagation.
Definition: dnnl.hpp:2035
Primitive descriptor for deconvolution weight update.
Definition: dnnl.hpp:2854
memory::desc mean_desc() const
Queries mean memory descriptor.
Definition: dnnl.hpp:3634
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3192
source gradient memory desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3650
Initializes an eltwise descriptor for backward propagation using aalgorithm algorithm, memory descript...
Definition: dnnl.hpp:3282
shuffle descriptor
Definition: dnnl_types.h:1588
memory::desc diff_src_iter_desc() const
Queries diff source iteration memory descriptor.
Definition: dnnl.hpp:5426
cl_command_queue get_ocl_command_queue() const
Returns the OpenCL command queue associated with the stream.
Definition: dnnl.hpp:985
binary descriptor
Definition: dnnl_types.h:1598
memory::desc scratchpad_desc() const
Queries scratchpad memory descriptor.
Definition: dnnl.hpp:1642
Element-wise operations for forward propagation.
Definition: dnnl.hpp:3224
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4946
Local response normalization for backward propagation.
Definition: dnnl.hpp:2983
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4777
Eltwise: square root.
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:5137
runtime estimation (seconds)
Definition: dnnl_types.h:1567
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode)
Sets scratchpad mode.
dnnl_status_t DNNL_API dnnl_stream_get_ocl_command_queue(dnnl_stream_t stream, cl_command_queue *queue)
Returns the OpenCL command queue associated with an execution stream.
memory::desc diff_src_iter_desc() const
Queries diff source recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4810
An opaque structure to describe an engine.
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3264
primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr, const engine &e, const_dnnl_primitive_desc_t hint_fwd_pd, bool allow_empty=false)
Creates a primitive descriptor from given op_desc, attr, engine, and optionally a hint primitive desc...
Definition: dnnl.hpp:1978
A (out-of-place) concat primitive.
Vanilla RNN for forward propagation.
Definition: dnnl.hpp:4207
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:2438
inner product descriptor
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns count, corresponding scale mask, and a pointer to a constant floating point array of output ...
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:2971
6D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:293
5D CNN weights tensor (incl. groups), an alias to dnnl_acbde
Definition: dnnl_types.h:400
memory::desc query_md(query what, int idx=0) const
Queries and returns requested memory descriptor.
Definition: dnnl.hpp:1625
memory::desc variance_desc() const
Queries variance memory descriptor.
Definition: dnnl.hpp:3877
DNNL exception class.
Definition: dnnl.hpp:52
dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init(dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...
Eltwise: linear.
Definition: dnnl_types.h:678
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4434
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes an RNN descriptor rnn_desc for backward propagation using prop_kind, activation, direction, and memory descriptors.
8-bit unsigned integer.
Definition: dnnl_types.h:82
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Initializes an LBR_GRU descriptor for backward propagation using prop_kind, direction, and memory descriptors.
Definition: dnnl.hpp:5311
plain 2D tensor
Definition: dnnl_types.h:184
permuted 4D tensor
Definition: dnnl_types.h:204
Backward weights propagation.
Definition: dnnl_types.h:608
A descriptor of an element-wise operation.
Definition: dnnl_types.h:1017
A descriptor of a Softmax operation.
Definition: dnnl_types.h:1053
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive)
Deletes a primitive.
Descriptor for convolution weight update.
Definition: dnnl.hpp:2310
An opaque structure for a chain of post operations.
algorithm
Kinds of algorithms.
Definition: dnnl.hpp:296
A descriptor of an inner product operation.
Definition: dnnl_types.h:1188
dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine)
Destroys an engine.
An LRN primitive.
Definition: dnnl_types.h:637
Descriptor for deconvolution weight update.
Definition: dnnl.hpp:2752
int len() const
Returns the length of post operations.
Definition: dnnl.hpp:598
A batch normalization primitive.
Definition: dnnl_types.h:639
Pooling for forward propagation.
Definition: dnnl.hpp:3062
An RNN primitive.
Definition: dnnl_types.h:645
Average pooling include padding.
Definition: dnnl_types.h:697
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:2174
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4416
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Initializes primitive descriptor for convolution backward propagation.
Definition: dnnl.hpp:2259
execution engine
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for inner product forward propagation from a C primitive descripto...
Definition: dnnl.hpp:3978
dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc, const dnnl_memory_desc_t *src1_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a binary descriptor binary_desc using alg_kind (possible values are dnnl_binary_add and dnnl_b...
Out-of-order execution.
Local response normalization (LRN) across multiple channels.
Definition: dnnl_types.h:702
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a pooling descriptor for backward propagation using aalgorithm, memory descriptors...
Definition: dnnl.hpp:3139
handle(T t, bool weak=false)
Constructs a C handle wrapper from a C handle.
Definition: dnnl.hpp:123
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3180
Undefined data type, used for empty memory descriptors.
memory::desc dst_iter_c_desc() const
Queries destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:4634
error(dnnl_status_t astatus, const char *amessage)
Constructs an error instance.
Definition: dnnl.hpp:60
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3120
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4772
memory::desc diff_weights_layer_desc() const
Queries diff weights layer memory descriptor.
Definition: dnnl.hpp:5127
Primitive descriptor for batch normalization forward propagation.
Definition: dnnl.hpp:3509
dnnl_rnn_flags_t
Flags for RNN cell.
Definition: dnnl_types.h:1217
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops)
Deletes a post_ops sequence.
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(dnnl_primitive_desc_t *sum_primitive_desc, const dnnl_memory_desc_t *dst_mds, int n, const float *scales, const dnnl_memory_desc_t *src_mds, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
weights memory desc
Use scale and shift parameters.
Definition: dnnl_types.h:751
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4585
lrn descriptor
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for layer normalization backward propagation from a C primitive de...
Definition: dnnl.hpp:3864
primitive_desc(const desc &desc, const engine &e, bool allow_empty=false)
Initializes a primitive descriptor for deconvolution forward propagation.
Definition: dnnl.hpp:2591
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:5081
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc and dif...
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3320
4D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:231
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(dnnl_primitive_desc_t *primitive_desc, const_dnnl_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
dnnl_format_tag_t
Memory format tag specification.
Definition: dnnl_types.h:170
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:5262
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:5277
stream & wait()
Waits for all primitives in the stream to finish.
Definition: dnnl.hpp:994
weights grad. memory desc
for creating scratchpad memory
Definition: dnnl_types.h:1576
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3647
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for batch normalization forward propagation from a C primitive des...
Definition: dnnl.hpp:3524
Shuffle for forward propagation.
Definition: dnnl.hpp:5475
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Eltwise: gelu.
Definition: dnnl_types.h:691
post_ops()
Creates an empty sequence of post operations.
Definition: dnnl.hpp:590
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3759
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes an LSTM descriptor rnn_desc for backward propagation using prop_kind, direction...
bool next_impl()
Advances to the next implementation for the given op descriptor.
Definition: dnnl.hpp:1998
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for RNN backward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:4411
weights grad. memory desc
Definition: dnnl_types.h:1605
Backward propagation (with respect to all parameters).
Definition: dnnl_types.h:604
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3451
dnnl_status_t
Status values returned by the library functions.
Definition: dnnl_types.h:50
const void * const_dnnl_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: dnnl_types.h:954
Implements primitive descriptor and primitive for concat.
Definition: dnnl.hpp:1838
Primitive attributes.
Definition: dnnl.hpp:675
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:5367
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(const_dnnl_primitive_t primitive, const_dnnl_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of given primitive.
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:2286
2D RNN statistics tensor, an alias to dnnl_ab
Definition: dnnl_types.h:344
4D CNN weights tensor, an alias to dnnl_cdba
Definition: dnnl_types.h:377
4D CNN activations tensor, an alias to dnnl_abcd
Definition: dnnl_types.h:352
kind
Kinds of primitives.
Definition: dnnl.hpp:189
number of outputs expected
Definition: dnnl_types.h:1565
primitive_desc(dnnl_primitive_desc_t cpd)
Initializes a primitive descriptor for inner product weights update from a C primitive descriptor cpd...
Definition: dnnl.hpp:4129
primitive_attr(dnnl_primitive_attr_t attr)
Creates primitive attributes from a C dnnl_primitive_attr_t handle.
Definition: dnnl.hpp:689
Vanilla RNN for backward propagation.
Definition: dnnl.hpp:4335
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for layer normalization forward propagation from a C primitive des...
Definition: dnnl.hpp:3752
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:2893
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3034
Layer normalization backward propagation.
Definition: dnnl.hpp:3807
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in the spatial domain: strides, kernel sizes, padding_l, and padding_r.
4D RNN states tensor in the format (num_layers, num_directions, batch, state channels).
Definition: dnnl_types.h:410
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4302
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(dnnl_primitive_desc_t *reorder_primitive_desc, const dnnl_memory_desc_t *src_md, dnnl_engine_t src_engine, const dnnl_memory_desc_t *dst_md, dnnl_engine_t dst_engine, const_dnnl_primitive_attr_t attr)
Initializes a reorder_primitive_desc using the description of the source (src_engine and src_md) and ...
memory::desc diff_src_layer_desc() const
Queries diff source layer memory descriptor.
Definition: dnnl.hpp:4467
Unidirectional execution of RNN primitive from left to right.
Definition: dnnl_types.h:1222
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:2446
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *weights_scales)
Sets quantization scales weights_scales for RNN weights tensors.
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for LSTM forward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:4579
The operation failed because of incorrect function arguments.
Definition: dnnl_types.h:56
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2)
Constructs a primitive_desc from a C counterpart.
Definition: dnnl.hpp:1691
const char * what() const noexcept override
Returns the explanatory string.
Definition: dnnl.hpp:64
dnnl_rnn_direction_t
A direction of RNN primitive execution.
Definition: dnnl_types.h:1220
Use global statistics.
Definition: dnnl_types.h:738
8-bit signed integer.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc)
Returns a memory_desc associated with memory.
kind
Kinds of engines.
Definition: dnnl.hpp:826
dnnl_status_t DNNL_API dnnl_memory_get_engine(const_dnnl_memory_t memory, dnnl_engine_t *engine)
Returns an engine associated with memory.
Average pooling exclude padding, alias for dnnl::algorithm::pooling_avg_exclude_padding.
Direct deconvolution.
Definition: dnnl_types.h:662
Primitive descriptor for batch normalization backward propagation.
Definition: dnnl.hpp:3606
GRU descriptor for backward propagation.
Definition: dnnl.hpp:4990
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: dnnl_types.h:668
dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream)
Waits for all primitives in the execution stream to finish.
Batch normalization backward propagation.
Definition: dnnl.hpp:3579
dnnl_status_t DNNL_API dnnl_engine_get_ocl_context(dnnl_engine_t engine, cl_context *context)
Returns an OpenCL context associated with an engine.
memory::desc weights_desc() const
Queries weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3762
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:2622
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:2291
Default stream configuration.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
4D tensor blocked by 2nd dimension with block size 8
Descriptor for batch normalization forward propagation.
Definition: dnnl.hpp:3486
Primitive descriptor for local response normalization backward propagation.
Definition: dnnl.hpp:3006
dnnl_alg_kind_t
Kinds of algorithms.
Definition: dnnl_types.h:653
Descriptor for binary.
Definition: dnnl.hpp:5586
dnnl_status_t DNNL_API dnnl_engine_create(dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
memory::desc diff_src_layer_desc() const
Queries diff source layer memory descriptor.
Definition: dnnl.hpp:5114
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated deconvolution descriptor conv_desc for forward propagation using prop_kind (p...
dnnl_status_t DNNL_API dnnl_memory_map_data(const_dnnl_memory_t memory, void **mapped_ptr)
For a memory, maps the data of the memory to mapped_ptr.
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3887
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr)
Makes a copy of an existing_attr.
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for GRU backward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:5058
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:4060
dnnl_status_t DNNL_API dnnl_engine_get_kind(dnnl_engine_t engine, dnnl_engine_kind_t *kind)
Returns the kind of an engine.
LSTM for forward propagation.
Definition: dnnl.hpp:4516
Eltwise: swish.
Definition: dnnl_types.h:693
dnnl_status_t DNNL_API dnnl_primitive_desc_query(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index, void *result)
Queries primitive descriptor.
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3895
In-order execution.
A descriptor of a shuffle operation.
Definition: dnnl_types.h:1000
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: dnnl.hpp:128
Forward data propagation (inference mode).
A deconvolution primitive.
Definition: dnnl_types.h:629
5D CNN activations tensor, an alias to dnnl_abcde
Definition: dnnl_types.h:358
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
3D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBc8b ...
Definition: dnnl_types.h:463
8-bit unsigned integer.
const_dnnl_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: dnnl.hpp:239
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:2723
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3778
desc submemory_desc(const dims &adims, const dims &offsets)
Constructs a sub-memory descriptor.
Definition: dnnl.hpp:1419
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:5089
Primitive descriptor for LBR_GRU backward propagation.
Definition: dnnl.hpp:5344
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
A convolution primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4429
2D CNN weights tensor, an alias to dnnl_ab
Definition: dnnl_types.h:363
Convolution backward propagation.
Definition: dnnl.hpp:2201
Weights format used in 8bit Winograd convolution.
Definition: dnnl_types.h:97
Primitive descriptor for deconvolution forward propagation.
Definition: dnnl.hpp:2586
execution engine
Definition: dnnl_types.h:1561
Local response normalization (LRN) across multiple channels.
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:3995
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels, input_channels).
Definition: dnnl_types.h:424
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Initializes a primitive descriptor for convolution weight update with attributes defined by attr...
Definition: dnnl.hpp:2424
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4442
Primitive descriptor for RNN forward propagation.
Definition: dnnl.hpp:4255
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:4455
reorder source engine
Softmax for forward propagation.
Definition: dnnl.hpp:3349
LRN within a single channel.
Definition: dnnl_types.h:704
Descriptor for deconvolution backward propagation.
Definition: dnnl.hpp:2644
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes an LBR GRU descriptor rnn_desc for backward propagation using prop_kind, direction, and memory descriptors.
memory::desc diff_dst_layer_desc() const
Queries diff destination layer memory descriptor.
Definition: dnnl.hpp:5446
5D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:267
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated convolution weight update with bias using aalgorithm, memory descriptors, strides, dilates, padding_l, and padding_r.
Definition: dnnl.hpp:2365
dnnl_stream_flags_t
Stream flags.
Definition: dnnl_types.h:1618
#define DNNL_MAX_NDIMS
Maximum number of dimensions a tensor can have.
Definition: dnnl_types.h:775
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4462
4D CNN activations tensor, an alias to dnnl_bcda
Definition: dnnl_types.h:356
A user shall query and provide the scratchpad memory to primitives.
Definition: dnnl_types.h:1382
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition: dnnl_types.h:407
memory::desc diff_src_iter_c_desc() const
Queries diff source recurrent cell state memory descriptor.
Definition: dnnl.hpp:4815
4D CNN weights tensor, an alias to dnnl_bacd
Definition: dnnl_types.h:383
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, dnnl_format_tag_t tag)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and format tag...
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for softmax backward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:3438
Descriptor for GRU forward propagation.
Definition: dnnl.hpp:4865
3D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:217
Average pooling exclude padding.
Definition: dnnl_types.h:699
A deconvolution primitive.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Initializes an RNN descriptor for backward propagation using prop_kind, activation, direction, and memory descriptors.
Definition: dnnl.hpp:4358
Primitive descriptor for GRU backward propagation.
Definition: dnnl.hpp:5041
memory::desc diff_dst_iter_desc() const
Queries diff destination iteration memory descriptor.
Definition: dnnl.hpp:5454
engine get_engine() const
Returns the engine of the memory.
Definition: dnnl.hpp:1481
destination grad. memory desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:4000
memory::desc variance_desc() const
Queries variance memory descriptor.
Definition: dnnl.hpp:3773
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4941
Eltwise: hyperbolic tangent non-linearity (tanh)
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(dnnl_primitive_desc_t *concat_primitive_desc, const dnnl_memory_desc_t *dst_md, int n, int concat_dimension, const dnnl_memory_desc_t *src_mds, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for deconvolution backward propagation from a C primitive descript...
Definition: dnnl.hpp:2718
CPU engine.
Definition: dnnl_types.h:1325
A descriptor of a Layer Normalization operation.
Definition: dnnl_types.h:1156
RNN cell.
Definition: dnnl_types.h:706
3D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:213
Primitive descriptor for eltwise forward propagation.
Definition: dnnl.hpp:3243
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(dnnl_primitive_desc_iterator_t iterator)
Iterates over primitive descriptors.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine, const_dnnl_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator for given op_desc, attr, engine, and optionally a hint primit...
dnnl_scratchpad_mode_t
Scratchpad mode.
Definition: dnnl_types.h:1378
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops)
Returns post_ops for given attr.
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated deconvolution forward propagation with bias using aprop_kind (po...
Definition: dnnl.hpp:2538
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, const dnnl_memory_desc_t *diff_data_desc, int axis, dnnl_dim_t group_size)
Initializes a shuffle_desc for backward propagation using memory descriptor diff_data_desc, axis, and group_size.
5D CNN weights tensor (incl. groups), an alias to dnnl_decab
Definition: dnnl_types.h:398
The library manages scratchpad (default)
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to weights using...
reorder destination engine
Undefined memory format kind, used for empty memory descriptors.
Definition: dnnl_types.h:88
void set_output_scales(int mask, const std::vector< float > &scales)
Sets output scales for primitive operations.
Definition: dnnl.hpp:739
Primitive descriptor for LSTM forward propagation.
Definition: dnnl.hpp:4564
static void wrap_c_api(dnnl_status_t status, const char *message)
A convenience function for wrapping calls to the C API.
Definition: dnnl.hpp:71
source memory desc
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a deconvolution descriptor conv_desc for forward propagation using prop_kind (possible ...
convolution_forward(const primitive_desc &pd)
Creates a convolution forward propagation primitive from the corresponding primitive descriptor...
Definition: dnnl.hpp:2194
dnnl_status_t DNNL_API dnnl_memory_unmap_data(const_dnnl_memory_t memory, void *mapped_ptr)
For a memory, unmaps a mapped pointer to the data of the memory.
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(const_dnnl_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor.
A softmax primitive.
Definition: dnnl_types.h:633
source engine
Definition: dnnl_types.h:1578
GRU for backward propagation.
Definition: dnnl.hpp:4987
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated convolution backward propagation using aalgorithm, memory descriptors, strides, dilates, padding_l, and padding_r.
Definition: dnnl.hpp:2234
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for deconvolution weight update without bias using aalgorithm, memory descriptors, strides, padding_l, and padding_r.
Definition: dnnl.hpp:2785
Inner product for backward propagation with respect to data.
Definition: dnnl.hpp:4010
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets output scales for primitive operations.
LBR_GRU for forward propagation.
Definition: dnnl.hpp:5164
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3907
A base class for descriptors of all primitives that have an operation descriptor and that support ite...
Definition: dnnl.hpp:1968
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:2733
Undefined propagation type.
Definition: dnnl_types.h:591
A class for wrapping a DNNL handle.
Definition: dnnl.hpp:94
5D CNN weights tensor (incl. groups), an alias to dnnl_abcde
Definition: dnnl_types.h:396
4D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcd16b ...
Definition: dnnl_types.h:448
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for LBR GRU backward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:5362
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3112
A shuffle primitive.
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:4830
Memory descriptor.
Definition: dnnl_types.h:884
runtime estimation (seconds), unimplemented
Implements descriptor, primitive descriptor, and primitive for the binary.
Definition: dnnl.hpp:5583
Forward data propagation (alias for dnnl_forward_inference).
Definition: dnnl_types.h:600
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for LSTM backward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:4736
dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
memory::desc mean_desc() const
Queries mean memory descriptor.
Definition: dnnl.hpp:3874
5D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcde16b ...
Definition: dnnl_types.h:439
6D CNN weights tensor (incl. groups), an alias to dnnl_abcdef
Definition: dnnl_types.h:402
batch normalization descriptor
Definition: dnnl_types.h:1593
void get_params_sum(int index, float &scale) const
Gets the parameters of the accumulation (sum) post operation with index index.
Definition: dnnl.hpp:635
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3392
Descriptor for softmax backward propagation.
Definition: dnnl.hpp:3405
primitive_attr get_primitive_attr() const
Returns the attributes.
Definition: dnnl.hpp:1658
Deconvolution backward propagation.
Definition: dnnl.hpp:2641
Eltwise: ReLU.
Definition: dnnl_types.h:666
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory, const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine, void *handle)
Creates a memory for given memory_desc and engine.
Binary mul.
Definition: dnnl_types.h:722
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes an LBR GRU descriptor rnn_desc for forward propagation using prop_kind, direction, and memory descriptors.
dnnl_status_t DNNL_API dnnl_stream_create(dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags)
Creates an execution stream for engine and with flags.
void append_sum(float scale=1.)
Appends accumulation (sum) post operation.
Definition: dnnl.hpp:628
Unidirectional execution of RNN primitive from right to left.
Definition: dnnl_types.h:1224
Undefined propagation kind.
Backward data propagation.
Primitive descriptor for convolution forward propagation.
Definition: dnnl.hpp:2146
memory::desc diff_dst_iter_c_desc() const
Queries diff destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:4848
memory::desc diff_weights_layer_desc() const
Queries diff weights layer memory descriptor.
Definition: dnnl.hpp:4480
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
binary descriptor
eltwise descriptor
Forward data propagation, alias for dnnl::prop_kind::forward_inference.
Memory that describes the data.
Definition: dnnl.hpp:1031
convolution_backward_weights(const primitive_desc &pd)
Creates convolution weight update primitive from corresponding primitive descriptor.
Definition: dnnl.hpp:2460
engine(const dnnl_engine_t &aengine)
Constructs an engine from another engine aengine.
Definition: dnnl.hpp:871
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4276
Descriptor for pooling backward propagation.
Definition: dnnl.hpp:3133
op descriptor
memory::desc weights_desc() const
Queries weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3534
4D CNN weights tensor (incl. groups), an alias to dnnl_abcd
Definition: dnnl_types.h:394
memory::desc diff_weights_iter_desc() const
Queries diff weights iteration memory descriptor.
Definition: dnnl.hpp:5436
number of outputs expected
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(dnnl_primitive_attr_t attr, const float scale, const float shift)
Sets quantization scale and shift for RNN data tensors.
32-bit/single-precision floating point.
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4603
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4616
memory::desc variance_desc() const
Queries variance memory descriptor.
Definition: dnnl.hpp:3637
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3041
convolution descriptor
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4959
engine scratchpad_engine() const
Returns the engine that owns the scratchpad memory.
Definition: dnnl.hpp:1647
Primitive descriptor for softmax backward propagation.
Definition: dnnl.hpp:3420
Initializes a reorder primitive using the description of the source (src_engine and src_md) and desti...
Definition: dnnl.hpp:1752
Initializes an inner product descriptor for forward propagation using prop_kind (possible values are ...
Definition: dnnl.hpp:3939
memory::desc src_iter_c_desc() const
Queries source recurrent cell state memory descriptor.
Definition: dnnl.hpp:4598
memory::desc diff_weights_desc() const
Queries diff weights memory descriptor.
Definition: dnnl.hpp:4137
Descriptor for shuffle forward propagation.
Definition: dnnl.hpp:5478
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3871
3D CNN weights tensor, an alias to dnnl_abc
Definition: dnnl_types.h:367
pooling descriptor
query
Primitive descriptor query specification.
Definition: dnnl.hpp:483
LRN within a single channel.
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3029
Primitive descriptor for LSTM backward propagation.
Definition: dnnl.hpp:4718
permuted 4D tensor
Definition: dnnl_types.h:201
memory::desc mean_desc() const
Queries mean memory descriptor.
Definition: dnnl.hpp:3549
Placeholder memory format tag.
engine(const handle< dnnl_primitive_desc_t > &pd)
Constructs an engine from the primitive descriptor pd by querying its engine.
Definition: dnnl.hpp:875
Undefined memory format tag.
Definition: dnnl_types.h:175
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates, output_channels).
Definition: dnnl_types.h:417
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, int ndims, const dnnl_dims_t dims)
Initializes an out_memory_desc with new ndims and dims from an in_memory_desc.
memory::desc diff_weights_iter_desc() const
Queries diff weights iteration memory descriptor.
Definition: dnnl.hpp:5132
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3657
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3323
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(const_dnnl_memory_t memory, void **handle)
For a memory, returns the data handle.
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for LBR GRU forward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:5225
non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point.
shuffle descriptor
dnnl_status_t DNNL_API dnnl_memory_get_ocl_mem_object(const_dnnl_memory_t memory, cl_mem *mem_object)
For a memory returns the OpenCL memory object associated with it.
Fuse with ReLU.
Definition: dnnl_types.h:764
3D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBc4b ...
Definition: dnnl_types.h:460
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:2451
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Initializes a primitive descriptor for deconvolution weight update.
Definition: dnnl.hpp:2858
destination memory desc
Undefined memory format tag.
size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind)
Returns the number of engines of a particular kind.
Primitive descriptor for RNN backward propagation.
Definition: dnnl.hpp:4393
plain 3D tensor
Definition: dnnl_types.h:185
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes an LSTM descriptor rnn_desc for forward propagation using prop_kind, direction, and memory descriptors.
desc get_desc() const
Returns the descriptor of the memory.
Definition: dnnl.hpp:1473
The base class for all primitive descriptors.
Definition: dnnl.hpp:1599
Primitive descriptor for layer normalization backward propagation.
Definition: dnnl.hpp:3846
Descriptor for RNN forward propagation.
Definition: dnnl.hpp:4210
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:2627
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: dnnl_types.h:431
prop_kind
Propagation kind.
Definition: dnnl.hpp:265
Eltwise: parametric exponential linear unit (elu)
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(dnnl_softmax_desc_t *softmax_desc, const dnnl_memory_desc_t *diff_desc, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc...
permuted 3D tensor
Definition: dnnl_types.h:203
deconvolution descriptor
6D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:289
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, bool allow_empty=false)
Initializes a primitive descriptor for convolution forward propagation with attributes defined by att...
Definition: dnnl.hpp:2158
An inner product primitive.
A descriptor of a Local Response Normalization (LRN) operation.
Definition: dnnl_types.h:1102
Eltwise: x*sigmoid(a*x)
int ndims
Number of dimensions.
Definition: dnnl_types.h:886
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Initializes a descriptor for forward propagation using prop_kind (possible values are dnnl::forward_t...
Definition: dnnl.hpp:2928
number of inputs expected
Definition: dnnl_types.h:1564
Eltwise: exponent.
An element-wise primitive.
Definition: dnnl_types.h:631
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Sets quantization scales (scales) for RNN weights tensors.
Definition: dnnl.hpp:798
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops, float scale, dnnl_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters alg, alpha, and beta (...
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Initializes a layer normalization descriptor for forward propagation using prop_kind (possible values...
Definition: dnnl.hpp:3714
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for convolution weight update with bias using aalgorithm, memory descriptors, strides, padding_l, and padding_r.
Definition: dnnl.hpp:2319
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for RNN forward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:4270
4D CNN activations tensor, an alias to dnnl::memory::format_tag::acdb
Primitive descriptor for inner product forward propagation.
Definition: dnnl.hpp:3963
Direct convolution.
Definition: dnnl_types.h:656
Forward data propagation (alias for dnnl_forward_training).
Definition: dnnl_types.h:602
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Initializes a shuffle descriptor for forward propagation using prop_kind, memory descriptor data_desc...
Definition: dnnl.hpp:5484
Descriptor for convolution forward propagation.
Definition: dnnl.hpp:2478
cl_mem get_ocl_mem_object() const
Returns the OpenCL memory object associated with the memory.
Definition: dnnl.hpp:1541
workspace memory desc
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:5239
A shuffle primitive.
Definition: dnnl_types.h:621
5D CNN weights tensor, an alias to dnnl_cdeba
Definition: dnnl_types.h:387
Default order execution.
Definition: dnnl_types.h:1621
handle()=default
Empty constructor.
The library manages scratchpad (default)
Definition: dnnl_types.h:1380
Backward bias propagation.
stream(const engine &aengine, flags aflags=flags::default_flags)
Constructs a stream.
Definition: dnnl.hpp:966
memory::desc diff_dst_layer_desc() const
Queries diff destination layer memory descriptor.
Definition: dnnl.hpp:4495
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:5063
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for pooling backward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:3175
Primitive descriptor for GRU forward propagation.
Definition: dnnl.hpp:4907
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4608
A descriptor of a Batch Normalization operation.
Definition: dnnl_types.h:1128
Eltwise: square root.
Definition: dnnl_types.h:676
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4447
(scratch) memory, additional to all inputs and outputs memory (bytes)
Definition: dnnl_types.h:1573
permuted 5D tensor
Definition: dnnl_types.h:196
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Initializes a primitive descriptor for deconvolution backward propagation.
Definition: dnnl.hpp:2701
Primitive or engine failed on execution.
Definition: dnnl_types.h:62
dnnl_prop_kind_t
Kinds of propagation.
Definition: dnnl_types.h:588
scratchpad_mode
Scratchpad mode.
Definition: dnnl.hpp:253
lrn descriptor
Definition: dnnl_types.h:1592
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3389
permuted 4D tensor
Definition: dnnl_types.h:199
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated convolution weight update without bias using aalgorithm...
Definition: dnnl.hpp:2391
Weights format used in 8bit Winograd convolution.
dnnl_status_t DNNL_API dnnl_memory_set_ocl_mem_object(dnnl_memory_t memory, cl_mem mem_object)
Sets the OpenCL memory object associated with the given memory.
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Initializes an LBR GRU descriptor for forward propagation using prop_kind, direction, and memory descriptors.
Definition: dnnl.hpp:5187
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for deconvolution weight update with bias using aalgorithm, memory descriptors, strides, padding_l, and padding_r.
Definition: dnnl.hpp:2761
An opaque structure to describe a primitive descriptor.
LBR_GRU descriptor for backward propagation.
Definition: dnnl.hpp:5293
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(const_dnnl_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3662
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(dnnl_primitive_desc_iterator_t iterator)
Deletes a primitive descriptor iterator.
bool is_zero() const
Returns true if the memory descriptor describes an empty memory.
Definition: dnnl.hpp:1441
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3544
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Initializes a batch normalization descriptor for backward propagation with respect to data and scale-...
Definition: dnnl.hpp:3593
permuted 2D tensor
Definition: dnnl_types.h:197
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for local response normalization backward propagation from a C pri...
Definition: dnnl.hpp:3024
A binary primitive.
Definition: dnnl_types.h:649
LSTM descriptor for backward propagation.
Definition: dnnl.hpp:4657
An opaque structure to describe an execution stream.
Primitive descriptor for inner product backward propagation with respect to data. ...
Definition: dnnl.hpp:4032
Forward data propagation (inference mode).
Definition: dnnl_types.h:598
Descriptor for layer normalization backward propagation.
Definition: dnnl.hpp:3810
memory::desc diff_src_iter_desc() const
Queries diff source iteration memory descriptor.
Definition: dnnl.hpp:5122
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition: dnnl_types.h:660
permuted 3D tensor
Definition: dnnl_types.h:198
destination engine
Definition: dnnl_types.h:1579
A layer normalization primitive.
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3531
GRU cell with linear before reset.
Definition: dnnl_types.h:718
memory::desc diff_src_iter_desc() const
Queries diff source iteration memory descriptor.
Definition: dnnl.hpp:4475
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Initializes an LSTM descriptor for backward propagation using prop_kind, direction, and memory descriptors.
Definition: dnnl.hpp:4678
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for deconvolution forward propagation with bias using prop_kind (possible va...
Definition: dnnl.hpp:2489
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for element-wise operations for forward propagation from a C primi...
Definition: dnnl.hpp:3258
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:2611
memory::desc diff_dst_layer_desc() const
Queries diff destination layer memory descriptor.
Definition: dnnl.hpp:5142
Primitive descriptor for deconvolution backward propagation.
Definition: dnnl.hpp:2696
Softmax for backward propagation.
Definition: dnnl.hpp:3402
LSTM cell.
Definition: dnnl_types.h:708
layer normalization descriptor
Definition: dnnl_types.h:1594
convolution descriptor
Definition: dnnl_types.h:1586
layer normalization descriptor
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:2728
Primitive descriptor for LBR_GRU forward propagation.
Definition: dnnl.hpp:5210
void set_post_ops(post_ops ops)
Sets post_ops for future use.
Definition: dnnl.hpp:756
void set_ocl_mem_object(cl_mem mem_object)
Sets the OpenCL memory object mem_object associated with the memory.
Definition: dnnl.hpp:1549
Undefined memory format kind, used for empty memory descriptors.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
memory::desc dst_iter_desc() const
Queries destination recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4785
Eltwise: logistic.
Definition: dnnl_types.h:684
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for convolution weights update from a C primitive descriptor pd...
Definition: dnnl.hpp:2433
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Initializes a softmax descriptor for forward propagation using prop_kind (possible values are dnnl::f...
Definition: dnnl.hpp:3358
scratchpad_mode get_scratchpad_mode() const
Returns the scratchpad mode.
Definition: dnnl.hpp:693
dnnl_status_t DNNL_API dnnl_stream_create_ocl(dnnl_stream_t *stream, dnnl_engine_t engine, cl_command_queue queue)
Creates an execution stream for a given engine associated with an OpenCL command queue.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Initializes a primitive descriptor for convolution backward propagation with attributes defined by attr...
Definition: dnnl.hpp:2267
GRU cell.
Definition: dnnl_types.h:710
Undefined data type, used for empty memory descriptors.
Definition: dnnl_types.h:70
memory::desc diff_dst_layer_desc() const
Queries diff destination layer memory descriptor.
Definition: dnnl.hpp:4835
T * map_data() const
Maps the data of the memory.
Definition: dnnl.hpp:1519
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated deconvolution weight update without bias using aalgorithm...
Definition: dnnl.hpp:2833
2D CNN weights tensor, an alias to dnnl_ba
Definition: dnnl_types.h:365
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated deconvolution forward propagation without bias using aprop_kind ...
Definition: dnnl.hpp:2565
Descriptor for pooling forward propagation.
Definition: dnnl.hpp:3065
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Initializes a primitive descriptor for convolution weight update.
Definition: dnnl.hpp:2416
memory::desc weights_desc() const
Queries weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3642
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for deconvolution forward propagation without bias using prop_kind (possible...
Definition: dnnl.hpp:2514
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for convolution weight update without bias using aalgorithm, memory descriptors, strides, padding_l, and padding_r.
Definition: dnnl.hpp:2343
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Initializes a layer normalization descriptor for backward propagation with respect to data and scale-...
Definition: dnnl.hpp:3821
Convolution weight update.
Definition: dnnl.hpp:2307
A convolution primitive.
Definition: dnnl_types.h:627
desc()
Constructs a zero memory descriptor.
Definition: dnnl.hpp:1378
4D tensor blocked by 1st and 2nd dimension with block size 8
Definition: dnnl_types.h:250
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for element-wise operations for backward propagation from a C prim...
Definition: dnnl.hpp:3315
Primitive descriptor for softmax forward propagation.
Definition: dnnl.hpp:3368
Primitive descriptor for eltwise backward propagation.
Definition: dnnl.hpp:3297
Descriptor for softmax forward propagation.
Definition: dnnl.hpp:3352
const post_ops get_post_ops() const
Returns post_ops previously set by set_post_ops.
Definition: dnnl.hpp:746
primitive_desc(const desc &desc, const engine &e, bool allow_empty=false)
Initializes a primitive descriptor for convolution forward propagation.
Definition: dnnl.hpp:2151
memory::desc dst_iter_c_desc() const
Queries destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:4790
32-bit/single-precision floating point.
Definition: dnnl_types.h:76
destination memory desc
Definition: dnnl_types.h:1606
memory::desc variance_desc() const
Queries variance memory descriptor.
Definition: dnnl.hpp:3552
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4797
format_tag
Memory format tag specification.
Definition: dnnl.hpp:1078
Unspecified format kind.
Definition: dnnl_types.h:91
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm, memory descriptors diff...
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:4315
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:5231
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for deconvolution weights update from a C primitive descriptor pd...
Definition: dnnl.hpp:2875
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for inner product backward propagation from a C primitive descript...
Definition: dnnl.hpp:4050
primitive_desc(dnnl_primitive_desc_t pd)
Initializes a primitive descriptor for pooling forward propagation from a C primitive descriptor pd...
Definition: dnnl.hpp:3106
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4294
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3443
Just a sentinel, not real memory format tag.
Definition: dnnl_types.h:333
rnn descriptor
Definition: dnnl_types.h:1596
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for dilated deconvolution weight update with bias using aalgorithm, memory descriptors, strides, dilates, padding_l, and padding_r.
Definition: dnnl.hpp:2807
eltwise descriptor
Definition: dnnl_types.h:1589
normalization_flags
Flags for batch normalization primitive.
Definition: dnnl.hpp:372
4D CNN activations tensor, an alias to dnnl::memory::format_tag::bcda
Winograd deconvolution.
Definition: dnnl_types.h:664
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3631
primitive_attr()
Creates default primitive attributes.
Definition: dnnl.hpp:679
memory consumption – extra
Definition: dnnl_types.h:1568
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: dnnl.hpp:1438
plain 5D tensor
Definition: dnnl_types.h:187
A layer normalization primitive.
Definition: dnnl_types.h:641
Unspecified format kind.
dnnl_engine_kind_t
Kinds of engines.
Definition: dnnl_types.h:1321
An unspecified engine.
Descriptor for convolution backward propagation.
Definition: dnnl.hpp:2204
A batch normalization primitive.
An unspecified engine.
Definition: dnnl_types.h:1323
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3539
Eltwise: square.
Definition: dnnl_types.h:672
size_t DNNL_API dnnl_memory_desc_get_size(const dnnl_memory_desc_t *memory_desc)
Returns the size (in bytes) that is required for the given memory_desc.
An execution stream.
Definition: dnnl.hpp:947
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind (possi...
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a GRU descriptor rnn_desc for forward propagation using prop_kind, direction, and memory descriptors.
permuted 4D tensor
Definition: dnnl_types.h:195
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:5441
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3767
memory::desc diff_weights_iter_desc() const
Queries diff weights iteration memory descriptor.
Definition: dnnl.hpp:4485
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes an RNN descriptor rnn_desc for forward propagation using prop_kind, activation, direction, and memory descriptors.
5D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:260
GPU engine.
Definition: dnnl_types.h:1327
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition: dnnl_types.h:405
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are dnnl_forward...
Post operations.
Definition: dnnl.hpp:586
4D tensor blocked by 1st and 2nd dimension with block size 8
4D CNN activations tensor, an alias to dnnl_acdb
Definition: dnnl_types.h:354
scratchpad memory desc
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size)
Initializes a shuffle_desc for forward propagation using prop_kind, memory descriptor data_desc...