Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)  0.21.0
Performance library for Deep Learning
mkldnn.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19 
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27 
28 #include "mkldnn.h"
29 #endif
30 
31 namespace mkldnn {
32 
35 
38 
/// Primary template for the per-type traits used by handle<>.
///
/// Intentionally empty: each C handle type wrapped by handle<> gets an
/// explicit specialization that exposes a static `destructor`. It is declared
/// with `struct` so the class-key matches those `template <> struct ...`
/// specializations. A class/struct mismatch is harmless semantically but
/// triggers clang -Wmismatched-tags and MSVC C4099.
template <typename T> struct handle_traits {};
41 
55 template <typename T, typename traits=handle_traits<T>> class handle {
56 private:
57  std::shared_ptr<typename std::remove_pointer<T>::type> _data;
58  handle(const handle &&) = delete;
59  handle &operator=(const handle &&other) = delete;
60 protected:
61  bool operator==(const T other) const { return other == _data.get(); }
62  bool operator!=(const T other) const { return !(*this == other); }
63 public:
67  handle(T t = 0, bool weak = false): _data(0) {
68  reset(t, weak);
69  }
70 
71  handle(const handle &other): _data(other._data) {}
72  handle &operator=(const handle &other) {
73  _data = other._data;
74  return *this;
75  }
79  void reset(T t, bool weak = false) {
80  auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); };
81  _data.reset(t, weak ? dummy_destructor : traits::destructor);
82  }
83 
85  T get() const { return _data.get(); }
86 
87  bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
88  bool operator!=(const handle &other) const { return !(*this == other); }
89 };
90 
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95 
96 template <> struct handle_traits<mkldnn_primitive_t> {
97  static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99 
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101  static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
102 };
103 #endif
104 
106 class primitive: public handle<mkldnn_primitive_t> {
107  friend struct error;
108  friend struct stream;
109  friend class primitive_at;
110  using handle::handle;
111 public:
113  enum class kind {
116  view = mkldnn_view,
120  sum = mkldnn_sum,
127  lrn = mkldnn_lrn,
130  rnn = mkldnn_rnn,
131  };
132 
134  struct at {
142 
143  at(const primitive &aprimitive, size_t at = 0)
144  : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
146  inline operator primitive() const;
147  };
148 
151  // TODO: use the C++ API wrapper structure.
152 };
153 
155  return static_cast<mkldnn_primitive_kind_t>(akind);
156 }
161 struct error: public std::exception {
163  std::string message;
165 
172 
173  error(mkldnn_status_t astatus, std::string amessage,
174  mkldnn_primitive_t aerror_primitive = 0)
175  : status(astatus)
176  , message(amessage)
177  , error_primitive(aerror_primitive, true)
178  {}
179 
187 
189  const std::string &message,
191  {
192  if (status != mkldnn_success) {
193  if (nullptr != error_primitive)
195  else
196  throw error(status, message, nullptr);
197  }
198  }
199 };
200 
201 inline primitive::at::operator primitive() const {
204  mkldnn_primitive_get_output(data.primitive,
205  data.output_index, &output),
206  "could not get an output primitive");
207  return primitive(const_cast<mkldnn_primitive_t>(output), true);
208 }
209 
213  "could not get primitive descriptor by primitive");
214  return pd;
215 }
217 
222 
226 };
227 
229  return static_cast<mkldnn_round_mode_t>(mode);
230 }
231 
234 };
235 
237  return static_cast<mkldnn_padding_kind_t>(kind);
238 }
239 
240 enum prop_kind {
249 };
250 
252  return static_cast<mkldnn_prop_kind_t>(kind);
253 }
254 
255 enum algorithm {
284 };
285 
287  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
288 }
289 
294 };
295 
297  batch_normalization_flag aflag) {
298  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
299 }
300 
307 };
308 
310  return static_cast<mkldnn_rnn_direction_t>(adir);
311 }
312 
313 enum query {
315 
318 
321 
324 
326 
339 
349 };
350 
352  return static_cast<mkldnn_query_t>(aquery);
353 }
354 
356 
362 
363 #ifndef DOXYGEN_SHOULD_SKIP_THIS
364 template <> struct handle_traits<mkldnn_post_ops_t> {
365  static constexpr auto destructor = &mkldnn_post_ops_destroy;
366 };
367 #endif
368 
369 struct post_ops: public handle<mkldnn_post_ops_t> {
371  mkldnn_post_ops_t result;
373  "could not create post operation sequence");
374  reset(result);
375  }
376 
377  int len() const { return mkldnn_post_ops_len(get()); }
378 
379  primitive::kind kind(int index) const {
382  "post_ops index is out of range");
383  return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
384  index));
385  }
386 
387  void append_sum(float scale = 1.) {
389  "could not append sum");
390  }
391 
392  void get_params_sum(int index, float &scale) const {
394  "could not get sum params");
395  }
396 
397  void append_eltwise(float scale, algorithm alg, float alpha,
398  float beta) {
400  convert_to_c(alg), alpha, beta),
401  "could not append eltwise");
402  }
403 
404  void get_params_eltwise(int index, float &scale, algorithm &alg,
405  float &alpha, float &beta) const {
406  mkldnn_alg_kind_t c_alg;
408  &scale, &c_alg, &alpha, &beta),
409  "could not get eltwise params");
410  alg = static_cast<algorithm>(c_alg);
411  }
412 };
413 
414 #ifndef DOXYGEN_SHOULD_SKIP_THIS
415 template <> struct handle_traits<mkldnn_primitive_attr_t> {
416  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
417 };
418 #endif
419 
420 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
424  "could not create a primitive attr");
425  reset(result);
426  }
427 
429  mkldnn_round_mode_t result;
431  get(), &result), "could not get int output round mode");
432  return round_mode(result);
433  }
434 
437  get(), mkldnn::convert_to_c(mode)),
438  "could not set int output round mode");
439  }
440 
441  void get_output_scales(int &mask, std::vector<float> &scales) const
442  {
443  int count, c_mask;
444  const float *c_scales;
446  &count, &c_mask, &c_scales),
447  "could not get int output scales");
448  scales.resize(count);
449 
450  mask = c_mask;
451  for (int c = 0; c < count; ++c)
452  scales[c] = c_scales[c];
453  }
454 
455  void set_output_scales(int mask, const std::vector<float> &scales)
456  {
458  (int)scales.size(), mask, &scales[0]),
459  "could not set int output scales");
460  }
461 
462  const post_ops get_post_ops() const {
463  post_ops result;
464  const_mkldnn_post_ops_t c_result;
466  "could not get post operation sequence");
467  result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
468  return result;
469  }
470 
471  void set_post_ops(post_ops ops) {
473  "could not set post operation sequence");
474  }
475 
476  void set_rnn_data_qparams(const float scale, const float shift)
477  {
479  scale, shift), "could not set rnn data int scale/shift");
480  }
481 
482  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
483  {
485  (int)scales.size(), mask, &scales[0]),
486  "could not set rnn weights int scales");
487  }
488 };
489 
491 
497 
498 #ifndef DOXYGEN_SHOULD_SKIP_THIS
499 template <> struct handle_traits<mkldnn_engine_t> {
500  static constexpr auto destructor = &mkldnn_engine_destroy;
501 };
502 #endif
503 
505 struct engine: public handle<mkldnn_engine_t> {
506  friend class primitive;
507  // gcc bug??? using handle::handle;
508 
510  enum kind {
515  };
516 
520 
521  static size_t get_count(kind akind) {
522  return mkldnn_engine_get_count(convert_to_c(akind));
523  }
524 
530 
531  engine(kind akind, size_t index) {
532  mkldnn_engine_t aengine;
534  mkldnn_engine_create(&aengine,
535  convert_to_c(akind), index),
536  "could not create an engine");
537  reset(aengine);
538  }
539 
540  explicit engine(const mkldnn_engine_t& aengine)
541  : handle(aengine, true) {}
542 
544  mkldnn_engine_t engine_q;
547  mkldnn::convert_to_c(eengine), 0, &engine_q),
548  "could not get engine from primitive_desc");
549  reset(engine_q, true);
550  }
551 
552  template <class primitive_desc>
553  static engine query(const primitive_desc &pd) {
554  mkldnn_engine_t engine_q;
557  mkldnn::convert_to_c(eengine), 0, &engine_q),
558  "could not get engine from primitive_desc");
559 
560  return engine(engine_q);
561  }
562 
563 private:
564  static mkldnn_engine_kind_t convert_to_c(kind akind) {
565  return static_cast<mkldnn_engine_kind_t>(akind);
566  }
567 };
568 
570 
573 
579 
581 struct memory: public primitive {
582  private:
583  std::shared_ptr<char> _handle;
584 
585  public:
586  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
587 
588  template <typename T> static void validate_dims(std::vector<T> v) {
589  if (v.size() > TENSOR_MAX_DIMS)
591  "invalid dimensions");
592  }
593 
596  enum data_type {
604  };
605 
608  enum format {
765  };
766 
768  struct desc {
769  friend struct memory;
772 
778  desc(dims adims, data_type adata_type,
779  format aformat) {
780  validate_dims(adims);
782  mkldnn_memory_desc_init(&data, (int)adims.size(),
783  adims.size() == 0 ? nullptr : &adims[0],
784  convert_to_c(adata_type), convert_to_c(aformat)),
785  "could not initialize a memory descriptor");
786  }
787 
791  desc(const mkldnn_memory_desc_t &adata): data(adata) {}
792  };
793 
795  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
796  friend struct memory;
797 
798  // TODO: make private
800 
802  primitive_desc(const desc &adesc, const engine &aengine) {
806  &adesc.data, aengine.get()),
807  "could not initialize a memory primitive descriptor");
808  reset(result);
809  }
810 
814  return memory::desc(*memory_d); }
815 
818  size_t get_size() const {
820  }
821 
822  bool operator==(const primitive_desc &other) const {
823  return (0 == mkldnn_memory_primitive_desc_equal(get(),
824  other.get())) ? false : true;
825  }
826 
827  bool operator!=(const primitive_desc &other) const {
828  return !operator==(other);
829  }
830 
831  engine get_engine() { return engine::query(*this); }
832  };
833 
837  memory(const primitive &aprimitive): primitive(aprimitive) {}
841  memory(const primitive_desc &adesc) {
842  mkldnn_primitive_t result;
844  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
845  "could not create a memory primitive");
846  reset(result);
847  auto _malloc = [](size_t size, int alignment) {
848  void *ptr;
849 #ifdef _WIN32
850  ptr = _aligned_malloc(size, alignment);
851  int rc = ((ptr)? 0 : errno);
852 #else
853  int rc = ::posix_memalign(&ptr, alignment, size);
854 #endif /* _WIN32 */
855  return (rc == 0) ? (char*)ptr : nullptr;
856  };
857  auto _free = [](char* p) {
858 #ifdef _WIN32
859  _aligned_free((void*)p);
860 #else
861  ::free((void*)p);
862 #endif /* _WIN32 */
863  };
864  _handle.reset(_malloc(adesc.get_size(), 4096), _free);
865  set_data_handle(_handle.get());
866  }
867 
868  memory(const primitive_desc &adesc, void *ahandle) {
869  mkldnn_primitive_t result;
871  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
872  "could not create a memory primitive");
873  reset(result);
874  set_data_handle(ahandle);
875  }
876 
879  primitive_desc adesc;
882  &cdesc),
883  "could not get primitive descriptor from a memory primitive");
884  /* FIXME: no const_cast should be here */
885  adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
886  return adesc;
887  }
888 
891  inline void *get_data_handle() const {
892  void *handle;
894  "could not get native handle");
895  return handle;
896  }
897 
898  inline void set_data_handle(void *handle) const {
900  "could not set native handle");
901  }
902 
903  // Must go away or be private:
905  return static_cast<mkldnn_data_type_t>(adata_type);
906  }
908  return static_cast<mkldnn_memory_format_t>(aformat);
909  }
910 };
911 
913  auto zero = mkldnn_memory_desc_t();
914  zero.primitive_kind = mkldnn_memory;
915  return memory::desc(zero);
916 }
917 
918 inline memory null_memory(engine eng) {
920  return memory({zero, eng}, nullptr);
921 }
922 
924  &aprimitive_desc, int n_inputs, int n_outputs,
925  const std::string &prim_name) {
926  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
927  aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
928  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
929  aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
930  if (n_outputs_expected > n_outputs ) {
931  std::string message = "could not create " + prim_name +
932  " primitive, not enought output parameters";
933  throw error(mkldnn_invalid_arguments, message, nullptr);
934  }
935  if (n_inputs_expected > n_inputs ) {
936  std::string message = "could not create " + prim_name +
937  " primitive, not enought input parameters";
938  throw error(mkldnn_invalid_arguments, message, nullptr);
939  }
940 }
941 
942 
943 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
944  const_mkldnn_primitive_desc_t aprimitive_pd;
945  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
947  aprimitive_pd);
948 
949  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
950 }
951 
953  return a == memory::convert_to_c(b);
954 }
956  return !(a == b);
957 }
959  return b == a;
960 }
962  return !(a == b);
963 }
964 
966  return a == memory::convert_to_c(b);
967 }
969  return !(a == b);
970 }
972  return b == a;
973 }
975  return !(a == b);
976 }
977 
979 
985 
986 struct reorder : public primitive {
987  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
989  const memory::primitive_desc &output) {
992  &result, input.get(), output.get()),
993  "could not create a reorder primitive descriptor");
994  reset(result);
995  }
996 
998  const memory::primitive_desc &output,
999  const primitive_attr &aattr) {
1000  mkldnn_primitive_desc_t result;
1002  &result, input.get(), output.get(), aattr.get()),
1003  "could not create a reorder primitive descriptor");
1004  reset(result);
1005  }
1006 
1007  engine get_engine() { return engine::query(*this); }
1008  };
1009 
1010  reorder(const primitive_desc &aprimitive_desc,
1011  const primitive::at &input, const memory &output) {
1012  mkldnn_primitive_t result;
1013  mkldnn_primitive_at_t inputs[] = { input.data };
1014  const_mkldnn_primitive_t outputs[] = { output.get() };
1016  aprimitive_desc.get(), inputs, outputs),
1017  "could not create a reorder primitive");
1018  reset(result);
1019  }
1020 
1021  reorder(const primitive::at &input, const memory &output) {
1022  auto input_mpd = memory(input).get_primitive_desc();
1023  auto output_mpd = output.get_primitive_desc();
1024 
1025  auto reorder_d = primitive_desc(input_mpd, output_mpd);
1026 
1027  mkldnn_primitive_t result;
1028  mkldnn_primitive_at_t inputs[] = { input.data };
1029  const_mkldnn_primitive_t outputs[] = { output.get() };
1031  reorder_d.get(), inputs, outputs),
1032  "could not create a reorder primitive");
1033  reset(result);
1034  }
1035 };
1036 
1038 
1044 
1045 struct view : public primitive {
1046  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1048  memory::dims offsets) {
1049  mkldnn_primitive_desc_t result;
1050 
1052  &result, input.get(), &dims[0], &offsets[0]),
1053  "could not create a view primitive descriptor");
1054  reset(result);
1055  }
1056 
1058  memory::primitive_desc adesc;
1060  const_mkldnn_primitive_desc_t const_cdesc =
1064  const_cdesc),
1065  "could not clone a dst primitive descriptor");
1066  adesc.reset(cdesc);
1067  return adesc;
1068  }
1069 
1070  engine get_engine() { return engine::query(*this); }
1071  };
1072 
1073  view(const primitive_desc &view_pd, primitive::at input) {
1074  mkldnn_primitive_t result;
1075  mkldnn_primitive_at_t inputs[] = { input.data };
1077  view_pd.get(), inputs, nullptr),
1078  "could not create a view primitive");
1079  reset(result);
1080  }
1081 
1082  view(memory input, memory::dims dims, memory::dims offsets) {
1083  mkldnn_primitive_t result;
1084  primitive_desc view_pd(input.get_primitive_desc(), dims,
1085  offsets);
1086  mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1088  view_pd.get(), inputs, nullptr),
1089  "could not create a view primitive");
1090  reset(result);
1091  }
1092 };
1093 
1095 
1101 
1102 struct concat : public primitive {
1103  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1104  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1105  std::vector<memory::primitive_desc> inputs) {
1106  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1107  c_api_inputs.reserve(inputs.size());
1108  auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1109  std::transform(inputs.begin(), inputs.end(),
1110  std::back_inserter(c_api_inputs), convert_to_c);
1111  return c_api_inputs;
1112  }
1113 
1114  primitive_desc(const memory::desc &output, int concat_dimension,
1115  std::vector<memory::primitive_desc> inputs) {
1116  mkldnn_primitive_desc_t result;
1117 
1118  auto c_api_inputs = cpp_to_c(inputs);
1119 
1121  &result, &output.data, (int)c_api_inputs.size(),
1122  concat_dimension, &c_api_inputs[0]),
1123  "could not create a concat primitive descriptor");
1124  reset(result);
1125  }
1126 
1127  primitive_desc(int concat_dimension,
1128  std::vector<memory::primitive_desc> inputs) {
1129  mkldnn_primitive_desc_t result;
1130 
1131  auto c_api_inputs = cpp_to_c(inputs);
1132 
1134  &result, nullptr, (int)c_api_inputs.size(),
1135  concat_dimension, &c_api_inputs[0]),
1136  "could not create a concat primitive descriptor");
1137  reset(result);
1138  }
1139 
1141  memory::primitive_desc adesc;
1143  const_mkldnn_primitive_desc_t const_cdesc =
1146  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1147  "could not clone a dst primitive descriptor");
1148  adesc.reset(cdesc);
1149  return adesc;
1150  }
1151 
1152  engine get_engine() { return engine::query(*this); }
1153  };
1154 
1155  concat(const primitive_desc &concat_pd,
1156  std::vector<primitive::at> &inputs, const memory &output) {
1157  mkldnn_primitive_t result;
1158 
1159  std::vector<mkldnn_primitive_at_t> p_inputs;
1160  for (size_t i = 0; i < inputs.size(); i++)
1161  p_inputs.push_back(inputs[i].data);
1162  const_mkldnn_primitive_t outputs[] = { output.get() };
1163 
1165  concat_pd.get(), &p_inputs[0], outputs),
1166  "could not create a concat primitive");
1167  reset(result);
1168  }
1169 };
1170 
1172 
1178 
1179 struct sum : public primitive {
1180  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1181  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1182  std::vector<memory::primitive_desc> inputs) {
1183  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1184  c_api_inputs.reserve(inputs.size());
1185  auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1186  std::transform(inputs.begin(), inputs.end(),
1187  std::back_inserter(c_api_inputs), convert_to_c);
1188  return c_api_inputs;
1189  }
1190 
1192  const std::vector<float> &scales,
1193  std::vector<memory::primitive_desc> inputs) {
1194  mkldnn_primitive_desc_t result;
1195 
1196  auto c_api_inputs = cpp_to_c(inputs);
1197 
1199  scales.size() == inputs.size() ? mkldnn_success
1201  "number of scales not equal to number of inputs");
1202 
1204  &result, &output.data, (int)c_api_inputs.size(),
1205  &scales[0], &c_api_inputs[0]),
1206  "could not create a sum primitive descriptor");
1207  reset(result);
1208  }
1209 
1210  primitive_desc(const std::vector<float> &scales,
1211  std::vector<memory::primitive_desc> inputs) {
1212  mkldnn_primitive_desc_t result;
1213 
1214  auto c_api_inputs = cpp_to_c(inputs);
1215 
1217  scales.size() == inputs.size() ? mkldnn_success
1219  "number of scales not equal to number of inputs");
1220 
1222  &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1223  &c_api_inputs[0]),
1224  "could not create a sum primitive descriptor");
1225  reset(result);
1226  }
1227 
1229  memory::primitive_desc adesc;
1231  const_mkldnn_primitive_desc_t const_cdesc =
1235  const_cdesc),
1236  "could not clone a dst primitive descriptor");
1237  adesc.reset(cdesc);
1238  return adesc;
1239  }
1240 
1241  engine get_engine() { return engine::query(*this); }
1242  };
1243 
1244  sum(const primitive_desc &sum_pd,
1245  std::vector<primitive::at> &inputs, const memory &output) {
1246  mkldnn_primitive_t result;
1247 
1248  std::vector<mkldnn_primitive_at_t> p_inputs;
1249  for (size_t i = 0; i < inputs.size(); i++)
1250  p_inputs.push_back(inputs[i].data);
1251  const_mkldnn_primitive_t outputs[] = { output.get() };
1252 
1254  sum_pd.get(), &p_inputs[0], outputs),
1255  "could not create a sum primitive");
1256  reset(result);
1257  }
1258 };
1259 
1261 
1263 
1266 
1269 
1271 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1273  const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1274  mkldnn_primitive_desc_iterator_t iterator = nullptr;
1276  &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1277  hint_fwd_pd);
1278  error::wrap_c_api(status,
1279  "could not create a primitive descriptor iterator");
1280  pd_iterator.reset(iterator);
1281  fetch_impl();
1282  }
1283 
1284  engine get_engine() { return engine::query(*this); }
1285 
1287  const_mkldnn_primitive_attr_t const_cattr;
1289  "could not get attributes");
1291  error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1292  "could not clone attributes");
1293 
1294  primitive_attr attr;
1295  attr.reset(cattr);
1296  return attr;
1297  }
1298 
1300  const char *impl_info_str() const {
1301  const char *res;
1303  mkldnn_query_impl_info_str, 0, &res),
1304  "could not query implementation info string");
1305  return res;
1306  }
1307 
1314  bool next_impl() {
1316  pd_iterator.get());
1317  if (status == mkldnn_iterator_ends) return false;
1318  error::wrap_c_api(status, "primitive descriptor iterator next failed");
1319 
1320  fetch_impl();
1321  return true;
1322  }
1323 
1325  memory::primitive_desc query_mpd(query what, int idx = 0) const {
1326  std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1328  if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1329  [=](query q) { return what == q; }))
1330  throw error(mkldnn_invalid_arguments, "invalid memory query");
1331 
1332  const_mkldnn_primitive_desc_t const_cdesc
1334  mkldnn::convert_to_c(what), idx);
1335 
1336  // TODO: is there a better way to inform about this?
1337  if (const_cdesc == nullptr)
1338  throw error(mkldnn_not_required, "queried memory is not required");
1339 
1341  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1342  "could not clone a memory primitive descriptor");
1343 
1345  ret.reset(cdesc);
1346  return ret;
1347  }
1348 
1349  // register specialized queries, e.g. src_primitive_desc()
1350 # define REG_QUERY_MPD(name, what, idx) \
1351  memory::primitive_desc name ## _primitive_desc() const \
1352  { return query_mpd(what ## _pd, idx); }
1353 
1354  private:
1355  handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1356  void fetch_impl() {
1358  pd_iterator.get());
1360  "could not fetch a primitive descriptor from the iterator");
1361  reset(pd);
1362  }
1363 };
1364 
1366 
1372 
1374  struct desc {
1376  desc(prop_kind aprop_kind, algorithm aalgorithm,
1377  const memory::desc &src_desc,
1378  const memory::desc &weights_desc,
1379  const memory::desc &bias_desc,
1380  const memory::desc &dst_desc,
1381  const memory::dims strides,
1382  const memory::dims padding_l,
1383  const memory::dims padding_r,
1384  const padding_kind apadding_kind) {
1385  memory::validate_dims(strides);
1386  memory::validate_dims(padding_l);
1387  memory::validate_dims(padding_r);
1389  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1390  &src_desc.data, &weights_desc.data, &bias_desc.data,
1391  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1392  mkldnn::convert_to_c(apadding_kind)),
1393  "could not create a convolution forward descriptor");
1394  }
1395  desc(prop_kind aprop_kind, algorithm aalgorithm,
1396  const memory::desc &src_desc,
1397  const memory::desc &weights_desc,
1398  const memory::desc &dst_desc,
1399  const memory::dims strides,
1400  const memory::dims padding_l,
1401  const memory::dims padding_r,
1402  const padding_kind apadding_kind) {
1403  memory::validate_dims(strides);
1404  memory::validate_dims(padding_l);
1405  memory::validate_dims(padding_r);
1407  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1408  &src_desc.data, &weights_desc.data, nullptr,
1409  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1410  mkldnn::convert_to_c(apadding_kind)),
1411  "could not create a convolution forward descriptor");
1412  }
1413  desc(prop_kind aprop_kind, algorithm aalgorithm,
1414  const memory::desc &src_desc,
1415  const memory::desc &weights_desc,
1416  const memory::desc &bias_desc,
1417  const memory::desc &dst_desc,
1418  const memory::dims strides,
1419  const memory::dims dilates,
1420  const memory::dims padding_l,
1421  const memory::dims padding_r,
1422  const padding_kind apadding_kind) {
1423  memory::validate_dims(strides);
1424  memory::validate_dims(dilates);
1425  memory::validate_dims(padding_l);
1426  memory::validate_dims(padding_r);
1429  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1430  &src_desc.data, &weights_desc.data, &bias_desc.data,
1431  &dst_desc.data, &strides[0], &dilates[0],
1432  &padding_l[0], &padding_r[0],
1433  mkldnn::convert_to_c(apadding_kind)),
1434  "could not create a dilated convolution forward descriptor");
1435  }
1436  desc(prop_kind aprop_kind, algorithm aalgorithm,
1437  const memory::desc &src_desc,
1438  const memory::desc &weights_desc,
1439  const memory::desc &dst_desc,
1440  const memory::dims strides,
1441  const memory::dims dilates,
1442  const memory::dims padding_l,
1443  const memory::dims padding_r,
1444  const padding_kind apadding_kind) {
1445  memory::validate_dims(strides);
1446  memory::validate_dims(dilates);
1447  memory::validate_dims(padding_l);
1448  memory::validate_dims(padding_r);
1451  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1452  &src_desc.data, &weights_desc.data, nullptr,
1453  &dst_desc.data, &strides[0], &dilates[0],
1454  &padding_l[0], &padding_r[0],
1455  mkldnn::convert_to_c(apadding_kind)),
1456  "could not create a dilated convolution forward descriptor");
1457  }
1458  };
1459 
1461  primitive_desc(const desc &desc, const engine &e)
1462  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1463 
1464  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1465  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1466 
1467  REG_QUERY_MPD(src, src, 0);
1468  REG_QUERY_MPD(weights, weights, 0);
1469  REG_QUERY_MPD(bias, weights, 1);
1470  REG_QUERY_MPD(dst, dst, 0);
1471  };
1472 
1473  convolution_forward(const primitive_desc &aprimitive_desc,
1474  const primitive::at &src, const primitive::at &weights,
1475  const primitive::at &bias, const memory &dst) {
1476  mkldnn_primitive_t result;
1477  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1478  bias.data };
1479  const_mkldnn_primitive_t outputs[] = { dst.get() };
1481  aprimitive_desc.get(), inputs, outputs),
1482  "could not create a convolution forward bias primitive");
1483  reset(result);
1484  }
1485 
1486  convolution_forward(const primitive_desc &aprimitive_desc,
1487  const primitive::at &src, const primitive::at &weights,
1488  const memory &dst) {
1489  mkldnn_primitive_t result;
1490  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1491  const_mkldnn_primitive_t outputs[] = { dst.get() };
1492  check_num_parameters(aprimitive_desc.get(), 2, 1,
1493  "convolution forward");
1495  aprimitive_desc.get(), inputs, outputs),
1496  "could not create a convolution forward primitive");
1497  reset(result);
1498  }
1499 };
1500 
1502  struct desc {
1504  desc(algorithm aalgorithm,
1505  const memory::desc &diff_src_desc,
1506  const memory::desc &weights_desc,
1507  const memory::desc &diff_dst_desc,
1508  const memory::dims strides,
1509  const memory::dims padding_l,
1510  const memory::dims padding_r,
1511  const padding_kind apadding_kind) {
1512  memory::validate_dims(strides);
1513  memory::validate_dims(padding_l);
1514  memory::validate_dims(padding_r);
1516  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1517  &weights_desc.data, &diff_dst_desc.data,
1518  &strides[0], &padding_l[0], &padding_r[0],
1519  mkldnn::convert_to_c(apadding_kind)),
1520  "could not create a convolution backward data descriptor");
1521  }
1522  desc(algorithm aalgorithm,
1523  const memory::desc &diff_src_desc,
1524  const memory::desc &weights_desc,
1525  const memory::desc &diff_dst_desc,
1526  const memory::dims strides,
1527  const memory::dims dilates,
1528  const memory::dims padding_l,
1529  const memory::dims padding_r,
1530  const padding_kind apadding_kind) {
1531  memory::validate_dims(strides);
1532  memory::validate_dims(dilates);
1533  memory::validate_dims(padding_l);
1534  memory::validate_dims(padding_r);
1537  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1538  &weights_desc.data, &diff_dst_desc.data,
1539  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1540  mkldnn::convert_to_c(apadding_kind)),
1541  "could not create a convolution backward data descriptor");
1542  }
1543  };
1544 
1546  primitive_desc(const desc &desc, const engine &e,
1547  const convolution_forward::primitive_desc &hint_fwd_pd)
1548  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1549 
1550  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1551  const convolution_forward::primitive_desc &hint_fwd_pd)
1552  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1553 
1554  REG_QUERY_MPD(diff_src, diff_src, 0);
1555  REG_QUERY_MPD(weights, weights, 0);
1556  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1557  };
1558 
1560  const primitive::at &diff_dst, const primitive::at &weights,
1561  const memory &diff_src) {
1562  mkldnn_primitive_t result;
1563  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1564  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1565  check_num_parameters(aprimitive_desc.get(), 2, 1,
1566  "convolution backward data");
1568  aprimitive_desc.get(), inputs, outputs),
1569  "could not create a convolution backward data primitive");
1570  reset(result);
1571  }
1572 };
1573 
// convolution_backward_weights::desc has four constructors covering
// {with, without} bias x {non-dilated, dilated}. Each one:
//   1. validates its dims arguments with memory::validate_dims;
//   2. initializes `data` through a C API call (the opening line of that
//      call is not shown in this listing).
// The no-bias variants pass nullptr in place of the diff_bias descriptor.
// NOTE(review): &strides[0], &padding_l[0], ... take the address of element
// 0. If memory::validate_dims (defined elsewhere) does not reject empty
// vectors, an empty dims argument makes this undefined behaviour -- confirm.
1575  struct desc {
// With bias, non-dilated.
1577  desc(algorithm aalgorithm,
1578  const memory::desc &src_desc,
1579  const memory::desc &diff_weights_desc,
1580  const memory::desc &diff_bias_desc,
1581  const memory::desc &diff_dst_desc,
1582  const memory::dims strides,
1583  const memory::dims padding_l,
1584  const memory::dims padding_r,
1585  const padding_kind apadding_kind) {
1586  memory::validate_dims(strides);
1587  memory::validate_dims(padding_l);
1588  memory::validate_dims(padding_r);
1590  &data, convert_to_c(aalgorithm), &src_desc.data,
1591  &diff_weights_desc.data, &diff_bias_desc.data,
1592  &diff_dst_desc.data,
1593  &strides[0], &padding_l[0], &padding_r[0],
1594  mkldnn::convert_to_c(apadding_kind)),
1595  "could not create a convolution backward weights descriptor");
1596  }
// Without bias, non-dilated: the bias descriptor argument is nullptr.
1597  desc(algorithm aalgorithm,
1598  const memory::desc &src_desc,
1599  const memory::desc &diff_weights_desc,
1600  const memory::desc &diff_dst_desc,
1601  const memory::dims strides,
1602  const memory::dims padding_l,
1603  const memory::dims padding_r,
1604  const padding_kind apadding_kind) {
1605  memory::validate_dims(strides);
1606  memory::validate_dims(padding_l);
1607  memory::validate_dims(padding_r);
1609  &data, convert_to_c(aalgorithm), &src_desc.data,
1610  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1611  &strides[0], &padding_l[0], &padding_r[0],
1612  mkldnn::convert_to_c(apadding_kind)),
1613  "could not create a convolution backward weights descriptor");
1614  }
// With bias, dilated: additionally validates and passes `dilates`.
1615  desc(algorithm aalgorithm,
1616  const memory::desc &src_desc,
1617  const memory::desc &diff_weights_desc,
1618  const memory::desc &diff_bias_desc,
1619  const memory::desc &diff_dst_desc,
1620  const memory::dims strides,
1621  const memory::dims dilates,
1622  const memory::dims padding_l,
1623  const memory::dims padding_r,
1624  const padding_kind apadding_kind) {
1625  memory::validate_dims(strides);
1626  memory::validate_dims(dilates);
1627  memory::validate_dims(padding_l);
1628  memory::validate_dims(padding_r);
1630  &data, convert_to_c(aalgorithm), &src_desc.data,
1631  &diff_weights_desc.data, &diff_bias_desc.data,
1632  &diff_dst_desc.data,
1633  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1634  mkldnn::convert_to_c(apadding_kind)),
1635  "could not create a convolution backward weights descriptor");
1636  }
// Without bias, dilated.
1637  desc(algorithm aalgorithm,
1638  const memory::desc &src_desc,
1639  const memory::desc &diff_weights_desc,
1640  const memory::desc &diff_dst_desc,
1641  const memory::dims strides,
1642  const memory::dims dilates,
1643  const memory::dims padding_l,
1644  const memory::dims padding_r,
1645  const padding_kind apadding_kind) {
1646  memory::validate_dims(strides);
1647  memory::validate_dims(dilates);
1648  memory::validate_dims(padding_l);
1649  memory::validate_dims(padding_r);
1651  &data, convert_to_c(aalgorithm), &src_desc.data,
1652  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1653  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1654  mkldnn::convert_to_c(apadding_kind)),
1655  "could not create a convolution backward weights descriptor");
1656  }
1657 
1658  };
1659 
// primitive_desc (its struct header line is not shown). Both constructors
// delegate to mkldnn::primitive_desc with the forward convolution primitive
// descriptor as the hint; the second also passes &attr.
1661  primitive_desc(const desc &desc, const engine &e,
1662  const convolution_forward::primitive_desc &hint_fwd_pd)
1663  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1664 
1665  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1666  const convolution_forward::primitive_desc &hint_fwd_pd)
1667  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1668 
// diff_bias is queried as the second (index 1) diff_weights descriptor.
1669  REG_QUERY_MPD(src, src, 0);
1670  REG_QUERY_MPD(diff_weights, diff_weights, 0);
1671  REG_QUERY_MPD(diff_bias, diff_weights, 1);
1672  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1673  };
1674 
// Constructor with bias (name line not shown): inputs {src, diff_dst},
// outputs {diff_weights, diff_bias}; checked against 2 inputs / 2 outputs.
1676  const primitive::at &src, const primitive::at &diff_dst,
1677  const memory &diff_weights, const memory &diff_bias) {
1678  mkldnn_primitive_t result;
1679  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1680  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1681  diff_bias.get() };
1682  check_num_parameters(aprimitive_desc.get(), 2, 2,
1683  "convolution backward weights");
1685  aprimitive_desc.get(), inputs, outputs),
1686  "could not create a convolution backward weights primitive");
1687  reset(result);
1688  }
// Constructor without bias (name line not shown): inputs {src, diff_dst},
// single output diff_weights; checked against 2 inputs / 1 output.
1690  const primitive::at &src, const primitive::at &diff_dst,
1691  const memory &diff_weights) {
1692  mkldnn_primitive_t result;
1693  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1694  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1695  check_num_parameters(aprimitive_desc.get(), 2, 1,
1696  "convolution backward weights");
1698  aprimitive_desc.get(), inputs, outputs),
1699  "could not create a convolution backward weights primitive");
1700  reset(result);
1701  }
1702 };
1703 
1705 //
1711 
// deconvolution_forward::desc has four constructors covering
// {with, without} bias x {non-dilated, dilated}. Each one:
//   1. takes a propagation kind and an algorithm;
//   2. validates its dims arguments with memory::validate_dims;
//   3. initializes `data` through a C API call (the opening line of each
//      call is not shown in this listing).
// The no-bias variants pass nullptr for the bias descriptor. The dilated
// variants also pass &dilates[0] and report a "dilated" error message.
// NOTE(review): &strides[0] etc. assume non-empty dims; confirm that
// memory::validate_dims (defined elsewhere) rejects empty vectors.
1713  struct desc {
// With bias, non-dilated.
1715  desc(prop_kind aprop_kind, algorithm aalgorithm,
1716  const memory::desc &src_desc,
1717  const memory::desc &weights_desc,
1718  const memory::desc &bias_desc,
1719  const memory::desc &dst_desc,
1720  const memory::dims strides,
1721  const memory::dims padding_l,
1722  const memory::dims padding_r,
1723  const padding_kind apadding_kind) {
1724  memory::validate_dims(strides);
1725  memory::validate_dims(padding_l);
1726  memory::validate_dims(padding_r);
1728  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1729  &src_desc.data, &weights_desc.data, &bias_desc.data,
1730  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1731  mkldnn::convert_to_c(apadding_kind)),
1732  "could not create a deconvolution forward descriptor");
1733  }
// Without bias, non-dilated: bias descriptor argument is nullptr.
1734  desc(prop_kind aprop_kind, algorithm aalgorithm,
1735  const memory::desc &src_desc,
1736  const memory::desc &weights_desc,
1737  const memory::desc &dst_desc,
1738  const memory::dims strides,
1739  const memory::dims padding_l,
1740  const memory::dims padding_r,
1741  const padding_kind apadding_kind) {
1742  memory::validate_dims(strides);
1743  memory::validate_dims(padding_l);
1744  memory::validate_dims(padding_r);
1746  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1747  &src_desc.data, &weights_desc.data, nullptr,
1748  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1749  mkldnn::convert_to_c(apadding_kind)),
1750  "could not create a deconvolution forward descriptor");
1751  }
// With bias, dilated.
1752  desc(prop_kind aprop_kind, algorithm aalgorithm,
1753  const memory::desc &src_desc,
1754  const memory::desc &weights_desc,
1755  const memory::desc &bias_desc,
1756  const memory::desc &dst_desc,
1757  const memory::dims strides,
1758  const memory::dims dilates,
1759  const memory::dims padding_l,
1760  const memory::dims padding_r,
1761  const padding_kind apadding_kind) {
1762  memory::validate_dims(strides);
1763  memory::validate_dims(dilates);
1764  memory::validate_dims(padding_l);
1765  memory::validate_dims(padding_r);
1767  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1768  &src_desc.data, &weights_desc.data, &bias_desc.data,
1769  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1770  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1771  "could not create a dilated deconvolution forward descriptor");
1772  }
// Without bias, dilated.
1773  desc(prop_kind aprop_kind, algorithm aalgorithm,
1774  const memory::desc &src_desc,
1775  const memory::desc &weights_desc,
1776  const memory::desc &dst_desc,
1777  const memory::dims strides,
1778  const memory::dims dilates,
1779  const memory::dims padding_l,
1780  const memory::dims padding_r,
1781  const padding_kind apadding_kind) {
1782  memory::validate_dims(strides);
1783  memory::validate_dims(dilates);
1784  memory::validate_dims(padding_l);
1785  memory::validate_dims(padding_r);
1787  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1788  &src_desc.data, &weights_desc.data, nullptr,
1789  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1790  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1791  "could not create a dilated deconvolution forward descriptor");
1792  }
1793  };
1794 
// primitive_desc (struct header line not shown). A forward primitive takes
// no hint, so the last argument to mkldnn::primitive_desc is nullptr.
1796  primitive_desc(const desc &desc, const engine &e)
1797  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1798 
1799  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1800  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1801 
// bias is queried as the second (index 1) weights descriptor.
1802  REG_QUERY_MPD(src, src, 0);
1803  REG_QUERY_MPD(weights, weights, 0);
1804  REG_QUERY_MPD(bias, weights, 1);
1805  REG_QUERY_MPD(dst, dst, 0);
1806  };
1807 
// With bias: here bias is an INPUT (inputs {src, weights, bias}); single
// output dst; checked against 3 inputs / 1 output.
1808  deconvolution_forward(const primitive_desc &aprimitive_desc,
1809  const primitive::at &src, const primitive::at &weights,
1810  const primitive::at &bias, const memory &dst) {
1811  mkldnn_primitive_t result;
1812  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1813  bias.data };
1814  const_mkldnn_primitive_t outputs[] = { dst.get() };
1815  check_num_parameters(aprimitive_desc.get(), 3, 1,
1816  "deconvolution forward");
1818  aprimitive_desc.get(), inputs, outputs),
1819  "could not create a deconvolution forward bias primitive");
1820  reset(result);
1821  }
1822 
// Without bias: inputs {src, weights}, output dst; 2 inputs / 1 output.
1823  deconvolution_forward(const primitive_desc &aprimitive_desc,
1824  const primitive::at &src, const primitive::at &weights,
1825  const memory &dst) {
1826  mkldnn_primitive_t result;
1827  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1828  const_mkldnn_primitive_t outputs[] = { dst.get() };
1829  check_num_parameters(aprimitive_desc.get(), 2, 1,
1830  "deconvolution forward");
1832  aprimitive_desc.get(), inputs, outputs),
1833  "could not create a deconvolution forward primitive");
1834  reset(result);
1835  }
1836 };
1837 
// deconvolution_backward_data::desc has two constructors (non-dilated and
// dilated). Neither takes a prop_kind. Each validates its dims arguments and
// initializes `data` via a C API call whose opening line is not shown here.
// The descriptors are passed in the order diff_src, weights, diff_dst.
// NOTE(review): &strides[0] etc. assume non-empty dims -- confirm that
// memory::validate_dims rejects empty vectors.
1839  struct desc {
// Non-dilated.
1841  desc(algorithm aalgorithm,
1842  const memory::desc &diff_src_desc,
1843  const memory::desc &weights_desc,
1844  const memory::desc &diff_dst_desc,
1845  const memory::dims strides,
1846  const memory::dims padding_l,
1847  const memory::dims padding_r,
1848  const padding_kind apadding_kind) {
1849  memory::validate_dims(strides);
1850  memory::validate_dims(padding_l);
1851  memory::validate_dims(padding_r);
1853  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1854  &weights_desc.data, &diff_dst_desc.data,
1855  &strides[0], &padding_l[0], &padding_r[0],
1856  mkldnn::convert_to_c(apadding_kind)),
1857  "could not create a deconvolution backward data descriptor");
1858  }
// Dilated: additionally validates and passes `dilates`.
1859  desc(algorithm aalgorithm,
1860  const memory::desc &diff_src_desc,
1861  const memory::desc &weights_desc,
1862  const memory::desc &diff_dst_desc,
1863  const memory::dims strides,
1864  const memory::dims dilates,
1865  const memory::dims padding_l,
1866  const memory::dims padding_r,
1867  const padding_kind apadding_kind) {
1868  memory::validate_dims(strides);
1869  memory::validate_dims(dilates);
1870  memory::validate_dims(padding_l);
1871  memory::validate_dims(padding_r);
1873  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1874  &weights_desc.data, &diff_dst_desc.data,
1875  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1876  mkldnn::convert_to_c(apadding_kind)),
1877  "could not create a dilated deconvolution backward data descriptor");
1878  }
1879  };
1880 
// primitive_desc (struct header line not shown): uses the
// deconvolution_forward primitive descriptor as the hint.
1882  primitive_desc(const desc &desc, const engine &e,
1883  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1884  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1885 
1886  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1887  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1888  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1889 
1890  REG_QUERY_MPD(diff_src, diff_src, 0);
1891  REG_QUERY_MPD(weights, weights, 0);
1892  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1893  };
1894 
// Primitive constructor (name line not shown): inputs {diff_dst, weights},
// output diff_src; checked against 2 inputs / 1 output.
1896  const primitive::at &diff_dst, const primitive::at &weights,
1897  const memory &diff_src) {
1898  mkldnn_primitive_t result;
1899  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1900  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1901  check_num_parameters(aprimitive_desc.get(), 2, 1,
1902  "deconvolution backward data");
1904  aprimitive_desc.get(), inputs, outputs),
1905  "could not create a deconvolution backward data primitive");
1906  reset(result);
1907  }
1908 };
1909 
// deconvolution_backward_weights::desc has four constructors covering
// {with, without} bias x {non-dilated, dilated}. The no-bias variants pass
// nullptr for the diff_bias descriptor. Each initializes `data` via a C API
// call whose opening line is not shown in this listing.
// NOTE(review): &strides[0] etc. assume non-empty dims -- confirm that
// memory::validate_dims rejects empty vectors.
1911  struct desc {
// With bias, non-dilated.
1913  desc(algorithm aalgorithm,
1914  const memory::desc &src_desc,
1915  const memory::desc &diff_weights_desc,
1916  const memory::desc &diff_bias_desc,
1917  const memory::desc &diff_dst_desc,
1918  const memory::dims strides,
1919  const memory::dims padding_l,
1920  const memory::dims padding_r,
1921  const padding_kind apadding_kind) {
1922  memory::validate_dims(strides);
1923  memory::validate_dims(padding_l);
1924  memory::validate_dims(padding_r);
1926  &data, convert_to_c(aalgorithm), &src_desc.data,
1927  &diff_weights_desc.data, &diff_bias_desc.data,
1928  &diff_dst_desc.data,
1929  &strides[0], &padding_l[0], &padding_r[0],
1930  mkldnn::convert_to_c(apadding_kind)),
1931  "could not create a deconvolution backward weights descriptor");
1932  }
// Without bias, non-dilated.
1933  desc(algorithm aalgorithm,
1934  const memory::desc &src_desc,
1935  const memory::desc &diff_weights_desc,
1936  const memory::desc &diff_dst_desc,
1937  const memory::dims strides,
1938  const memory::dims padding_l,
1939  const memory::dims padding_r,
1940  const padding_kind apadding_kind) {
1941  memory::validate_dims(strides);
1942  memory::validate_dims(padding_l);
1943  memory::validate_dims(padding_r);
1945  &data, convert_to_c(aalgorithm), &src_desc.data,
1946  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1947  &strides[0], &padding_l[0], &padding_r[0],
1948  mkldnn::convert_to_c(apadding_kind)),
1949  "could not create a deconvolution backward weights descriptor");
1950  }
// With bias, dilated.
1951  desc(algorithm aalgorithm,
1952  const memory::desc &src_desc,
1953  const memory::desc &diff_weights_desc,
1954  const memory::desc &diff_bias_desc,
1955  const memory::desc &diff_dst_desc,
1956  const memory::dims strides,
1957  const memory::dims dilates,
1958  const memory::dims padding_l,
1959  const memory::dims padding_r,
1960  const padding_kind apadding_kind) {
1961  memory::validate_dims(strides);
1962  memory::validate_dims(dilates);
1963  memory::validate_dims(padding_l);
1964  memory::validate_dims(padding_r);
1966  &data, convert_to_c(aalgorithm), &src_desc.data,
1967  &diff_weights_desc.data, &diff_bias_desc.data,
1968  &diff_dst_desc.data,
1969  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1970  mkldnn::convert_to_c(apadding_kind)),
1971  "could not create a dilated deconvolution backward weights descriptor");
1972  }
// Without bias, dilated.
1973  desc(algorithm aalgorithm,
1974  const memory::desc &src_desc,
1975  const memory::desc &diff_weights_desc,
1976  const memory::desc &diff_dst_desc,
1977  const memory::dims strides,
1978  const memory::dims dilates,
1979  const memory::dims padding_l,
1980  const memory::dims padding_r,
1981  const padding_kind apadding_kind) {
1982  memory::validate_dims(strides);
1983  memory::validate_dims(dilates);
1984  memory::validate_dims(padding_l);
1985  memory::validate_dims(padding_r);
1987  &data, convert_to_c(aalgorithm), &src_desc.data,
1988  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1989  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1990  mkldnn::convert_to_c(apadding_kind)),
1991  "could not create a dilated deconvolution backward weights descriptor");
1992  }
1993  };
1994 
// primitive_desc (struct header line not shown): hinted by the
// deconvolution_forward primitive descriptor.
1996  primitive_desc(const desc &desc, const engine &e,
1997  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1998  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1999 
2000  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2001  const deconvolution_forward::primitive_desc &hint_fwd_pd)
2002  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2003 
// diff_bias is queried as the second (index 1) diff_weights descriptor.
2004  REG_QUERY_MPD(src, src, 0);
2005  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2006  REG_QUERY_MPD(diff_bias, diff_weights, 1);
2007  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2008  };
2009 
// With bias (name line not shown): inputs {src, diff_dst}, outputs
// {diff_weights, diff_bias}; checked against 2 inputs / 2 outputs.
2011  const primitive::at &src, const primitive::at &diff_dst,
2012  const memory &diff_weights, const memory &diff_bias) {
2013  mkldnn_primitive_t result;
2014  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2015  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
2016  diff_bias.get() };
2017  check_num_parameters(aprimitive_desc.get(), 2, 2,
2018  "deconvolution backward weights");
2020  aprimitive_desc.get(), inputs, outputs),
2021  "could not create a deconvolution backward weights primitive");
2022  reset(result);
2023  }
// Without bias (name line not shown): inputs {src, diff_dst}, single output
// diff_weights; checked against 2 inputs / 1 output.
2025  const primitive::at &src, const primitive::at &diff_dst,
2026  const memory &diff_weights) {
2027  mkldnn_primitive_t result;
2028  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2029  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2030  check_num_parameters(aprimitive_desc.get(), 2, 1,
2031  "deconvolution backward weights");
2033  aprimitive_desc.get(), inputs, outputs),
2034  "could not create a deconvolution backward weights primitive");
2035  reset(result);
2036  }
2037 };
2038 
2040 
2047 
// Local response normalization, forward pass.
2048 struct lrn_forward : public primitive {
// desc: the first constructor takes an explicit k. The second passes
// float(1.0) for k. Both forward their arguments to a C API init call
// whose opening line is not shown in this listing.
2049  struct desc {
2051  desc(prop_kind aprop_kind, algorithm aalgorithm,
2052  const memory::desc &src_desc,
2053  int local_size, float alpha, float beta, float k)
2054  {
2056  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2057  &src_desc.data, local_size, alpha, beta, k),
2058  "could not create a lrn forward descriptor");
2059  }
2060  desc(prop_kind aprop_kind, algorithm aalgorithm,
2061  const memory::desc &src_desc,
2062  int local_size, float alpha, float beta)
2063  {
2065  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2066  &src_desc.data, local_size, alpha, beta, float(1.0)),
2067  "could not create a lrn forward descriptor");
2068  }
2069  };
2070 
// primitive_desc (struct header line not shown): no hint for a forward
// primitive. Exposes src, dst and workspace query helpers.
2072  primitive_desc(const desc &desc, const engine &e)
2073  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2074 
2075  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2076  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2077 
2078  REG_QUERY_MPD(src, src, 0);
2079  REG_QUERY_MPD(dst, dst, 0);
2080  REG_QUERY_MPD(workspace, workspace, 0);
2081  };
2082 
// With workspace: note that the parameter order is (src, workspace, dst),
// but the outputs array is {dst, workspace}. Checked against 1 input /
// 2 outputs.
2083  lrn_forward(const primitive_desc &aprimitive_desc,
2084  const primitive::at &src, const memory &workspace,
2085  const memory &dst) {
2086  mkldnn_primitive_t result;
2087  mkldnn_primitive_at_t inputs[] = { src.data };
2088  const_mkldnn_primitive_t outputs[] = { dst.get(),
2089  workspace.get() };
2090  check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2092  aprimitive_desc.get(), inputs, outputs),
2093  "could not create a lrn forward primitive");
2094  reset(result);
2095  }
2096 
// Without workspace: 1 input / 1 output.
2097  lrn_forward(const primitive_desc &aprimitive_desc,
2098  const primitive::at &src, const memory &dst) {
2099  mkldnn_primitive_t result;
2100  mkldnn_primitive_at_t inputs[] = { src.data };
2101  const_mkldnn_primitive_t outputs[] = { dst.get() };
2102  check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2104  aprimitive_desc.get(), inputs, outputs),
2105  "could not create a lrn forward primitive");
2106  reset(result);
2107  }
2108 };
2109 
// Local response normalization, backward pass.
2110 struct lrn_backward : public primitive {
// desc: the constructors take (data_desc, diff_data_desc), but the C init
// call receives diff_data_desc FIRST and data_desc second. The constructor
// without k passes float(1.0) for k.
2111  struct desc {
2113  desc(algorithm aalgorithm,
2114  const memory::desc &data_desc,
2115  const memory::desc &diff_data_desc,
2116  int local_size, float alpha, float beta, float k)
2117  {
2119  convert_to_c(aalgorithm), &diff_data_desc.data,
2120  &data_desc.data, local_size, alpha, beta, k),
2121  "could not create a lrn backward descriptor");
2122  }
2123  desc(algorithm aalgorithm,
2124  const memory::desc &data_desc,
2125  const memory::desc &diff_data_desc,
2126  int local_size, float alpha, float beta)
2127  {
2129  convert_to_c(aalgorithm), &diff_data_desc.data,
2130  &data_desc.data, local_size, alpha, beta, float(1.0)),
2131  "could not create a lrn backward descriptor");
2132  }
2133  };
2134 
// primitive_desc (struct header line not shown): hinted by the lrn_forward
// primitive descriptor.
2136  primitive_desc(const desc &desc, const engine &e,
2137  const lrn_forward::primitive_desc &hint_fwd_pd)
2138  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2139 
2140  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2141  const lrn_forward::primitive_desc &hint_fwd_pd)
2142  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2143 
2144  REG_QUERY_MPD(diff_src, diff_src, 0);
2145  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2146  REG_QUERY_MPD(workspace, workspace, 0);
2147  };
2148 
// With workspace: inputs {src, diff_dst, workspace}, output diff_src;
// checked against 3 inputs / 1 output.
2149  lrn_backward(const primitive_desc &aprimitive_desc,
2150  const primitive::at &src, const primitive::at &diff_dst,
2151  const primitive::at &workspace, const memory &diff_src) {
2152  mkldnn_primitive_t result;
2153  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2154  workspace.data };
2155  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2156  check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2158  aprimitive_desc.get(), inputs, outputs),
2159  "could not create a lrn backward primitive");
2160  reset(result);
2161  }
2162 
// Without workspace: inputs {src, diff_dst}; 2 inputs / 1 output.
2163  lrn_backward(const primitive_desc &aprimitive_desc,
2164  const primitive::at &src, const primitive::at &diff_dst,
2165  const memory &diff_src) {
2166  mkldnn_primitive_t result;
2167  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2168  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2169  check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2171  aprimitive_desc.get(), inputs, outputs),
2172  "could not create a lrn backward primitive");
2173  reset(result);
2174  }
2175 };
2176 
2178 
2184 
// Pooling, forward pass.
2185 struct pooling_forward : public primitive {
// desc: validates strides, kernel and both paddings, then initializes
// `data` via a C API call (opening line not shown) using &kernel[0] etc.
// NOTE(review): &strides[0], &kernel[0], ... assume non-empty dims; confirm
// that memory::validate_dims rejects empty vectors.
2186  struct desc {
2188  desc(prop_kind aprop_kind, algorithm aalgorithm,
2189  const memory::desc &src_desc,
2190  const memory::desc &dst_desc,
2191  const memory::dims strides,
2192  const memory::dims kernel,
2193  const memory::dims padding_l,
2194  const memory::dims padding_r,
2195  const padding_kind apadding_kind) {
2196  memory::validate_dims(strides);
2197  memory::validate_dims(kernel);
2198  memory::validate_dims(padding_l);
2199  memory::validate_dims(padding_r);
2201  mkldnn::convert_to_c(aprop_kind),
2202  convert_to_c(aalgorithm),
2203  &src_desc.data, &dst_desc.data,
2204  &strides[0], &kernel[0],
2205  &padding_l[0], &padding_r[0],
2206  mkldnn::convert_to_c(apadding_kind)),
2207  "could not init a forward pooling descriptor");
2208  }
2209  };
2210 
// primitive_desc (struct header line not shown): no hint for a forward
// primitive. Exposes src, dst and workspace query helpers.
2212  primitive_desc(const desc &desc, const engine &e)
2213  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2214 
2215  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2216  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2217 
2218  REG_QUERY_MPD(src, src, 0);
2219  REG_QUERY_MPD(dst, dst, 0);
2220  REG_QUERY_MPD(workspace, workspace, 0);
2221  };
2222 
// Without workspace: checked against 1 input / 1 output, yet the outputs
// array has a second nullptr slot.
// NOTE(review): presumably a placeholder for the absent workspace that the
// C API ignores -- confirm, since the sibling constructors size their
// arrays exactly.
2223  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2224  const memory &dst) {
2225  mkldnn_primitive_t result;
2226  mkldnn_primitive_at_t inputs[] = { src.data };
2227  const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2228  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2230  aprimitive_desc.get(), inputs, outputs),
2231  "could not create a pooling forward primitive");
2232  reset(result);
2233  }
2234 
// With workspace: outputs {dst, workspace}; 1 input / 2 outputs.
2235  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2236  const memory &dst, const memory &workspace) {
2237  mkldnn_primitive_t result;
2238  mkldnn_primitive_at_t inputs[] = { src.data };
2239  const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2240  check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2242  aprimitive_desc.get(), inputs, outputs),
2243  "could not create a pooling forward primitive");
2244  reset(result);
2245  }
2246 };
2247 
// Pooling, backward pass.
2248 struct pooling_backward : public primitive {
// desc: unlike pooling_forward::desc, the dims are taken by const
// reference. Validates strides, kernel and paddings before the C init call
// (opening line not shown).
// NOTE(review): &strides[0] etc. assume non-empty dims.
2249  struct desc {
2251  desc(algorithm aalgorithm,
2252  const memory::desc &diff_src_desc,
2253  const memory::desc &diff_dst_desc,
2254  const memory::dims &strides,
2255  const memory::dims &kernel,
2256  const memory::dims &padding_l,
2257  const memory::dims &padding_r,
2258  const padding_kind apadding_kind) {
2259  memory::validate_dims(strides);
2260  memory::validate_dims(kernel);
2261  memory::validate_dims(padding_l);
2262  memory::validate_dims(padding_r);
2264  convert_to_c(aalgorithm),
2265  &diff_src_desc.data, &diff_dst_desc.data,
2266  &strides[0], &kernel[0],
2267  &padding_l[0], &padding_r[0],
2268  mkldnn::convert_to_c(apadding_kind)),
2269  "could not init a backward pooling descriptor");
2270  }
2271  };
2272 
// primitive_desc (struct header line not shown): hinted by the
// pooling_forward primitive descriptor.
2274  primitive_desc(const desc &desc, const engine &e,
2275  const pooling_forward::primitive_desc &hint_fwd_pd)
2276  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2277 
2278  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2279  const pooling_forward::primitive_desc &hint_fwd_pd)
2280  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2281 
2282  REG_QUERY_MPD(diff_src, diff_src, 0);
2283  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2284  REG_QUERY_MPD(workspace, workspace, 0);
2285  };
2286 
// Without workspace: only diff_dst is an input (no src); 1 input / 1 output.
2287  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2288  const memory &diff_src) {
2289  mkldnn_primitive_t result;
2290  mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2291  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2292  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2294  aprimitive_desc.get(), inputs, outputs),
2295  "could not create a pooling backward primitive");
2296  reset(result);
2297  }
2298 
// With workspace: inputs {diff_dst, workspace}; 2 inputs / 1 output.
2299  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2300  const primitive::at &workspace, const memory &diff_src) {
2301  mkldnn_primitive_t result;
2302  mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2303  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2304  check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2306  aprimitive_desc.get(), inputs, outputs),
2307  "could not create a pooling backward primitive");
2308  reset(result);
2309  }
2310 };
2311 
2313 
2320 
// Element-wise operation, forward pass.
2321 struct eltwise_forward : public primitive {
// desc: alpha and beta are of template type T and are converted to float
// with static_cast before the C init call (opening line not shown).
// NOTE(review): T is deduced only from alpha/beta, and default arguments do
// not take part in template argument deduction. As written, a call that
// omits both alpha and beta cannot deduce T and will not compile; callers
// must pass at least alpha.
2322  struct desc {
2324  template <typename T>
2325  desc(prop_kind aprop_kind, algorithm alg_kind,
2326  const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2328  mkldnn::convert_to_c(aprop_kind),
2329  mkldnn::convert_to_c(alg_kind), &src_desc.data,
2330  static_cast<float>(alpha), static_cast<float>(beta)),
2331  "could not create a eltwise forward descriptor");
2332  }
2333  };
2334 
// primitive_desc (struct header line not shown): no hint; src/dst queries.
2336  primitive_desc(const desc &desc, const engine &e)
2337  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2338 
2339  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2340  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2341 
2342  REG_QUERY_MPD(src, src, 0);
2343  REG_QUERY_MPD(dst, dst, 0);
2344  };
2345 
// Single input src, single output dst; checked against 1 input / 1 output.
2346  eltwise_forward(const primitive_desc &aprimitive_desc,
2347  const primitive::at &src, const memory &dst) {
2348  mkldnn_primitive_t result;
2349  mkldnn_primitive_at_t inputs[] = { src.data };
2350  const_mkldnn_primitive_t outputs[] = { dst.get() };
2351  check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2353  aprimitive_desc.get(), inputs, outputs),
2354  "could not create a eltwise forward primitive");
2355  reset(result);
2356  }
2357 };
2358 
// Element-wise operation, backward pass.
2359 struct eltwise_backward : public primitive {
// desc: the C init call receives diff_data_desc first, then data_desc;
// alpha and beta are converted to float with static_cast.
// NOTE(review): as in eltwise_forward::desc, T cannot be deduced when both
// alpha and beta are omitted.
2360  struct desc {
2362 
2363  template <typename T>
2364  desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2365  const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2367  mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2368  &data_desc.data, static_cast<float>(alpha),
2369  static_cast<float>(beta)),
2370  "could not create a eltwise backward descriptor");
2371  }
2372  };
2373 
// primitive_desc (struct header line not shown): hinted by the
// eltwise_forward primitive descriptor.
2375  primitive_desc(const desc &desc, const engine &e,
2376  const eltwise_forward::primitive_desc &hint_fwd_pd)
2377  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2378 
2379  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2380  const eltwise_forward::primitive_desc &hint_fwd_pd)
2381  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2382 
2383  REG_QUERY_MPD(src, src, 0);
2384  REG_QUERY_MPD(diff_src, diff_src, 0);
2385  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2386  };
2387 
// Inputs {src, diff_dst}, output diff_src; 2 inputs / 1 output.
2388  eltwise_backward(const primitive_desc &aprimitive_desc,
2389  const primitive::at &src, const primitive::at &diff_dst,
2390  const memory &diff_src) {
2391  mkldnn_primitive_t result;
2392  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2393  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2394  check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2396  aprimitive_desc.get(), inputs, outputs),
2397  "could not create a eltwise backward primitive");
2398  reset(result);
2399  }
2400 };
2401 
2403 
2409 
// Softmax, forward pass.
2410 struct softmax_forward : public primitive {
// desc: passes the propagation kind, the data descriptor and softmax_axis
// to a C init call (opening line not shown).
2411  struct desc {
2413  desc(prop_kind aprop_kind, const memory::desc &data_desc,
2414  int softmax_axis) {
2416  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2417  softmax_axis),
2418  "could not create a softmax forward descriptor");
2419  }
2420  };
2421 
// primitive_desc (struct header line not shown): no hint; src/dst queries.
2423  primitive_desc(const desc &desc, const engine &e)
2424  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2425 
2426  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2427  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2428 
2429  REG_QUERY_MPD(src, src, 0);
2430  REG_QUERY_MPD(dst, dst, 0);
2431  };
2432 
// Single input src, single output dst; 1 input / 1 output.
2433  softmax_forward(const primitive_desc &aprimitive_desc,
2434  const primitive::at &src, const memory &dst) {
2435  mkldnn_primitive_t result;
2436  mkldnn_primitive_at_t inputs[] = { src.data };
2437  const_mkldnn_primitive_t outputs[] = { dst.get() };
2438  check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2440  aprimitive_desc.get(), inputs, outputs),
2441  "could not create a softmax forward primitive");
2442  reset(result);
2443  }
2444 };
2445 
// Softmax, backward pass.
2446 struct softmax_backward : public primitive {
// desc: passes diff_desc, data_desc and softmax_axis to a C init call
// (opening line not shown).
2447  struct desc {
2449  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2450  int softmax_axis) {
2452  &diff_desc.data, &data_desc.data, softmax_axis),
2453  "could not init a backward softmax descriptor");
2454  }
2455  };
2456 
// primitive_desc (struct header line not shown): hinted by the
// softmax_forward primitive descriptor.
2458  primitive_desc(const desc &desc, const engine &e,
2459  const softmax_forward::primitive_desc &hint_fwd_pd)
2460  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2461 
2462  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2463  const softmax_forward::primitive_desc &hint_fwd_pd)
2464  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2465 
2466  REG_QUERY_MPD(dst, dst, 0);
2467  REG_QUERY_MPD(diff_src, diff_src, 0);
2468  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2469  REG_QUERY_MPD(workspace, workspace, 0);
2470  };
2471 
// Inputs {dst, diff_dst} (the forward OUTPUT dst, not src), output diff_src.
// NOTE(review): unlike every sibling primitive constructor, no
// check_num_parameters(aprimitive_desc.get(), 2, 1, "softmax backward")
// call appears here, so an incompatible primitive_desc is not pre-validated
// on the C++ side -- confirm whether the omission is intentional.
2472  softmax_backward(const primitive_desc &aprimitive_desc,
2473  const primitive::at &dst, const primitive::at &diff_dst,
2474  const memory &diff_src) {
2475  mkldnn_primitive_t result;
2476  mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2477  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2479  aprimitive_desc.get(), inputs, outputs),
2480  "could not create a softmax backward primitive");
2481  reset(result);
2482  }
2483 };
2484 
2486 
2492 
// batch_normalization_forward::desc. epsilon is of template type T (deduced
// from the required epsilon argument) and is converted to float with
// static_cast. flags is forwarded unchanged to the C init call, whose
// opening lines are not shown in this listing.
2494  struct desc {
2496  template <typename T>
2497  desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2498  unsigned flags) {
2501  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2502  static_cast<float>(epsilon), flags),
2503  "could not create a batch normalization forward descriptor");
2504  }
2505  };
2506 
// batch_normalization_forward::primitive_desc: no hint for a forward
// primitive. The second constructor also passes &attr.
2508  primitive_desc(const desc &desc, const engine &e)
2509  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2510 
2511  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2512  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2513 
2514  REG_QUERY_MPD(src, src, 0);
2515  REG_QUERY_MPD(weights, weights, 0);
2516  REG_QUERY_MPD(dst, dst, 0);
2517  REG_QUERY_MPD(workspace, workspace, 0);
2518 
// Mean / variance accessors (their signature lines are not shown). Both
// delegate to stat_primitive_desc with the enum value mean (1) or var (2).
2520  { return stat_primitive_desc(mean); }
2522  { return stat_primitive_desc(var); }
2523 
2524  private:
// The enum values double as the memory-descriptor index passed to query_mpd.
2525  enum { mean = 1, var = 2, };
// `p` is obtained in lines not shown here (presumably the batch-norm op
// descriptor -- confirm). Since `&` binds tighter than `?:`, the return
// reads (p->flags & use_global_stats) ? src_pd : dst_pd:
//   - global stats in use: mean/variance are taken from the src side at
//     index `kind`;
//   - otherwise: they are taken from the dst side at index `kind`.
2526  memory::primitive_desc stat_primitive_desc(int kind) const {
2530  "could not get a batch-normalization descriptor");
2531  return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
2532  }
2533  };
2534 
2536  const primitive::at &src, const primitive::at &mean,
2537  const primitive::at &variance, const primitive::at &weights,
2538  const memory &dst) {
2539  mkldnn_primitive_t result;
2540  mkldnn_primitive_at_t inputs[] = { src.data,
2541  mean.data, variance.data, weights.data };
2542  const_mkldnn_primitive_t outputs[] = { dst.get() };
2543  check_num_parameters(aprimitive_desc.get(), 4, 1,
2544  "batch normalization forward");
2546  aprimitive_desc.get(), inputs, outputs),
2547  "could not create a batch normalization forward primitive");
2548  reset(result);
2549  }
2550 
2552  const primitive::at &src, const primitive::at &mean,
2553  const primitive::at &variance, const memory &dst) {
2554  mkldnn_primitive_t result;
2555  mkldnn_primitive_at_t inputs[] = { src.data,
2556  mean.data, variance.data };
2557  const_mkldnn_primitive_t outputs[] = { dst.get() };
2558  check_num_parameters(aprimitive_desc.get(), 3, 1,
2559  "batch normalization forward");
2561  aprimitive_desc.get(), inputs, outputs),
2562  "could not create a batch normalization forward primitive");
2563  reset(result);
2564  }
2565 
2574  const primitive::at &src, const primitive::at &weights,
2575  const memory &dst, const memory &mean, const memory &variance) {
2576  mkldnn_primitive_t result;
2577  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2578  const_mkldnn_primitive_t outputs[] = { dst.get(),
2579  mean.get(), variance.get() };
2580  check_num_parameters(aprimitive_desc.get(), 2, 3,
2581  "batch normalization forward");
2583  aprimitive_desc.get(), inputs, outputs),
2584  "could not create a batch normalization forward primitive");
2585  reset(result);
2586  }
2587 
2589  const primitive::at &src, const primitive::at &weights,
2590  const memory &dst, const memory &mean, const memory &variance,
2591  const memory &workspace) {
2592  mkldnn_primitive_t result;
2593  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2594  const_mkldnn_primitive_t outputs[] = { dst.get(),
2595  mean.get(), variance.get(), workspace.get() };
2596  check_num_parameters(aprimitive_desc.get(), 2, 4,
2597  "batch normalization forward");
2599  aprimitive_desc.get(), inputs, outputs),
2600  "could not create a batch normalization forward primitive");
2601  reset(result);
2602  }
2603 
2605  const primitive::at &src, const memory &dst, const memory &mean,
2606  const memory &variance) {
2607  mkldnn_primitive_t result;
2608  mkldnn_primitive_at_t inputs[] = { src.data };
2609  const_mkldnn_primitive_t outputs[] = { dst.get(),
2610  mean.get(), variance.get() };
2611  check_num_parameters(aprimitive_desc.get(), 1, 3,
2612  "batch normalization forward");
2614  aprimitive_desc.get(), inputs, outputs),
2615  "could not create a batch normalization forward primitive");
2616  reset(result);
2617  }
2618 
2630  const primitive::at &src, const memory &dst, const memory &mean,
2631  const memory &variance, const memory &workspace) {
2632  mkldnn_primitive_t result;
2633  mkldnn_primitive_at_t inputs[2] = { src.data };
2634  const_mkldnn_primitive_t outputs[4] = { dst.get(),
2635  mean.get(), variance.get(), workspace.get() };
2636 
2637  if (1) { // check whether this is the `wrong` constructor
2638  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2639  aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2640  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2641  aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2642  if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2643  // shift parameters, get rid of workspace, and add weights...
2644  auto _weights = dst;
2645  inputs[1] = {_weights.get(), 0};
2646 
2647  auto _dst = mean, _mean = variance, _variance = workspace;
2648  outputs[0] = _dst.get();
2649  outputs[1] = _mean.get();
2650  outputs[2] = _variance.get();
2651  outputs[3] = nullptr;
2652  }
2653  }
2655  aprimitive_desc.get(), inputs, outputs),
2656  "could not create a batch normalization forward primitive");
2657  reset(result);
2658  }
2659 
2661  const primitive::at &src, const primitive::at &weights,
2662  const memory &dst) {
2663  mkldnn_primitive_t result;
2664  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2665  const_mkldnn_primitive_t outputs[] = { dst.get() };
2666  check_num_parameters(aprimitive_desc.get(), 2, 1,
2667  "batch normalization forward");
2669  aprimitive_desc.get(), inputs, outputs),
2670  "could not create a batch normalization forward primitive");
2671  reset(result);
2672  }
2673 
2675  const primitive::at &src, const memory &dst) {
2676  mkldnn_primitive_t result;
2677  mkldnn_primitive_at_t inputs[] = { src.data };
2678  const_mkldnn_primitive_t outputs[] = { dst.get() };
2679  check_num_parameters(aprimitive_desc.get(), 1, 1,
2680  "batch normalization forward");
2682  aprimitive_desc.get(), inputs, outputs),
2683  "could not create a batch normalization forward primitive");
2684  reset(result);
2685  }
2686 };
2687 
// --- batch_normalization_backward members. The enclosing `struct` header
// (source line 2688) and other lines missing from the numbering are absent
// from this rendered listing. ---
2689  struct desc {
// Descriptor: `epsilon` (any type T) is static_cast to float; `flags` is
// forwarded unchanged.
2691  template <typename T>
2692  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2693  const memory::desc &data_desc, T epsilon, unsigned flags) {
2696  mkldnn::convert_to_c(aprop_kind),
2697  &diff_data_desc.data, &data_desc.data,
2698  static_cast<float>(epsilon), flags),
2699  "could not create a batch normalization backward descriptor");
2700  }
2701  };
2702 
// Primitive descriptor: takes a forward primitive descriptor hint. The lines
// declaring the hint's type (2705, 2709) are not shown; presumably it is
// batch_normalization_forward::primitive_desc. The second overload also takes
// attributes.
2704  primitive_desc(const desc &desc, const engine &e,
2706  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2707 
2708  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2710  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2711 
// mean and variance are src-side entries 1 and 2, consistent with every
// constructor below taking them as inputs right after src.
2712  REG_QUERY_MPD(src, src, 0);
2713  REG_QUERY_MPD(mean, src, 1);
2714  REG_QUERY_MPD(variance, src, 2);
2715  REG_QUERY_MPD(weights, weights, 0);
2716  REG_QUERY_MPD(dst, dst, 0);
2717  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2718  REG_QUERY_MPD(workspace, workspace, 0);
2719 
2720  REG_QUERY_MPD(diff_src, diff_src, 0);
2721  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2722  };
2723 
2724  // Prop_kind == backward
// (header line 2725 not shown) inputs {src, mean, variance, diff_dst,
// weights}, outputs {diff_src, diff_weights}; checked as 5 in / 2 out.
2726  const primitive::at &src, const primitive::at &mean,
2727  const primitive::at &variance, const primitive::at &diff_dst,
2728  const primitive::at &weights, const memory &diff_src,
2729  const memory &diff_weights) {
2730  mkldnn_primitive_t result;
2731  mkldnn_primitive_at_t inputs[] = { src.data,
2732  mean.data, variance.data, diff_dst.data, weights.data };
2733  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2734  diff_weights.get() };
2735  check_num_parameters(aprimitive_desc.get(), 5, 2,
2736  "batch normalization backward");
2738  aprimitive_desc.get(), inputs, outputs),
2739  "could not create a batch normalization backward primitive");
2740  reset(result);
2741  }
2742 
2743  // Prop_kind == backward (+ws)
// (header line 2744 not shown) as above with a trailing workspace input;
// checked as 6 in / 2 out.
2745  const primitive::at &src, const primitive::at &mean,
2746  const primitive::at &variance, const primitive::at &diff_dst,
2747  const primitive::at &weights, const primitive::at &workspace,
2748  const memory &diff_src, const memory &diff_weights) {
2749  mkldnn_primitive_t result;
2750  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2751  diff_dst.data, weights.data, workspace.data };
2752  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2753  diff_weights.get() };
2754  check_num_parameters(aprimitive_desc.get(), 6, 2,
2755  "batch normalization backward");
2757  aprimitive_desc.get(), inputs, outputs),
2758  "could not create a batch normalization backward primitive");
2759  reset(result);
2760  }
2761 
2762  // Prop_kind == backward_data (+ws or +weights)
// (header line 2766 not shown) the fifth input is either weights or the
// workspace; which one is expected is presumably decided by how the primitive
// descriptor was created. Output {diff_src}; checked as 5 in / 1 out.
2767  const primitive::at &src, const primitive::at &mean,
2768  const primitive::at &variance,const primitive::at &diff_dst,
2769  const primitive::at &weights_or_workspace, const memory &diff_src) {
2770  mkldnn_primitive_t result;
2771  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2772  diff_dst.data, weights_or_workspace.data };
2773  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2774  check_num_parameters(aprimitive_desc.get(), 5, 1,
2775  "batch normalization backward");
2777  aprimitive_desc.get(), inputs, outputs),
2778  "could not create a batch normalization backward primitive");
2779  reset(result);
2780  }
2781 
2782  // Prop_kind == backward_data
// (header line 2783 not shown) inputs {src, mean, variance, diff_dst},
// output {diff_src}; checked as 4 in / 1 out.
2784  const primitive::at &src, const primitive::at &mean,
2785  const primitive::at &variance, const primitive::at &diff_dst,
2786  const memory &diff_src) {
2787  mkldnn_primitive_t result;
2788  mkldnn_primitive_at_t inputs[] = { src.data,
2789  mean.data, variance.data, diff_dst.data };
2790  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2791  check_num_parameters(aprimitive_desc.get(), 4, 1,
2792  "batch normalization backward");
2794  aprimitive_desc.get(), inputs, outputs),
2795  "could not create a batch normalization backward primitive");
2796  reset(result);
2797  }
2798 };
2799 
2801 
2807 
// --- inner_product_forward members. The enclosing `struct` header (source
// line 2808) and other lines missing from the numbering are absent from this
// rendered listing. ---
2809  struct desc {
// Descriptor with an explicit bias memory descriptor.
2811  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2812  const memory::desc &weights_desc,
2813  const memory::desc &bias_desc,
2814  const memory::desc &dst_desc) {
2817  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2818  &weights_desc.data, &bias_desc.data, &dst_desc.data),
2819  "could not create a inner product forward descriptor");
2820  }
2821 
// Descriptor without bias: nullptr is passed in place of the bias descriptor.
2822  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2823  const memory::desc &weights_desc,
2824  const memory::desc &dst_desc) {
2827  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2828  &weights_desc.data, nullptr, &dst_desc.data),
2829  "could not create a inner product forward descriptor");
2830  }
2831  };
2832 
// Primitive descriptor: no forward hint (nullptr); optional attributes.
2834  primitive_desc(const desc &desc, const engine &e)
2835  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2836 
2837  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2838  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2839 
// bias is exposed as entry 1 of the weights query.
2840  REG_QUERY_MPD(src, src, 0);
2841  REG_QUERY_MPD(weights, weights, 0);
2842  REG_QUERY_MPD(bias, weights, 1);
2843  REG_QUERY_MPD(dst, dst, 0);
2844  };
2845 
// Constructor with bias: inputs {src, weights, bias}, output {dst}; checked
// as 3 inputs / 1 output.
// NOTE(review): `weights` is taken by value (`const primitive::at`) while src
// and bias are `const primitive::at &`. A const reference would be consistent
// and would not affect callers.
2846  inner_product_forward(const primitive_desc &aprimitive_desc,
2847  const primitive::at &src, const primitive::at weights,
2848  const primitive::at &bias, const memory &dst) {
2849  mkldnn_primitive_t result;
2850  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
2851  bias.data };
2852  const_mkldnn_primitive_t outputs[] = { dst.get() };
2853  check_num_parameters(aprimitive_desc.get(), 3, 1,
2854  "inner product forward");
2856  aprimitive_desc.get(), inputs, outputs),
2857  "could not create a inner product forward primitive");
2858  reset(result);
2859  }
2860 
// Constructor without bias: inputs {src, weights}, output {dst}; checked as
// 2 inputs / 1 output. (Same by-value `weights` note as above.)
2861  inner_product_forward(const primitive_desc &aprimitive_desc,
2862  const primitive::at &src, const primitive::at weights,
2863  const memory &dst) {
2864  mkldnn_primitive_t result;
2865  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2866  const_mkldnn_primitive_t outputs[] = { dst.get() };
2867  check_num_parameters(aprimitive_desc.get(), 2, 1,
2868  "inner product forward");
2870  aprimitive_desc.get(), inputs, outputs),
2871  "could not create a inner product forward primitive");
2872  reset(result);
2873  }
2874 };
2875 
// --- inner_product_backward_data members. The enclosing `struct` header
// (source line 2876) and other lines missing from the numbering are absent
// from this rendered listing. ---
2877  struct desc {
// Descriptor from the diff_src, weights and diff_dst memory descriptors (no
// prop_kind argument).
2879  desc(const memory::desc &diff_src_desc,
2880  const memory::desc &weights_desc,
2881  const memory::desc &diff_dst_desc) {
2884  &diff_src_desc.data, &weights_desc.data,
2885  &diff_dst_desc.data),
2886  "could not create a inner product backward data descriptor");
2887  }
2888  };
2889 
// Primitive descriptor: requires the inner_product_forward primitive
// descriptor as a hint; the second overload also takes attributes.
2891  primitive_desc(const desc &desc, const engine &e,
2892  const inner_product_forward::primitive_desc &hint_fwd_pd)
2893  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2894 
2895  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2896  const inner_product_forward::primitive_desc &hint_fwd_pd)
2897  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2898 
2899  REG_QUERY_MPD(diff_src, diff_src, 0);
2900  REG_QUERY_MPD(weights, weights, 0);
2901  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2902  };
2903 
// Constructor (header line 2904 not shown): inputs {diff_dst, weights},
// output {diff_src}; checked as 2 inputs / 1 output.
// NOTE(review): `weights` is taken by value, unlike `diff_dst`.
2905  const primitive::at &diff_dst, const primitive::at weights,
2906  const memory &diff_src) {
2907  mkldnn_primitive_t result;
2908  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
2909  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2910  check_num_parameters(aprimitive_desc.get(), 2, 1,
2911  "inner product backward data");
2913  aprimitive_desc.get(), inputs, outputs),
2914  "could not create a inner product backward data primitive");
2915  reset(result);
2916  }
2917 };
2918 
// --- inner_product_backward_weights members. The enclosing `struct` header
// (source line 2919) and other lines missing from the numbering are absent
// from this rendered listing. ---
2920  struct desc {
// Descriptor with a diff_bias memory descriptor.
2922  desc(const memory::desc &src_desc,
2923  const memory::desc &diff_weights_desc,
2924  const memory::desc &diff_bias_desc,
2925  const memory::desc &diff_dst_desc) {
2928  &data, &src_desc.data, &diff_weights_desc.data,
2929  &diff_bias_desc.data, &diff_dst_desc.data),
2930  "could not create a inner product backward weights descriptor");
2931  }
// Descriptor without bias gradient: nullptr replaces the diff_bias descriptor.
2932  desc(const memory::desc &src_desc,
2933  const memory::desc &diff_weights_desc,
2934  const memory::desc &diff_dst_desc) {
2937  &data, &src_desc.data, &diff_weights_desc.data,
2938  nullptr, &diff_dst_desc.data),
2939  "could not create a inner product backward weights descriptor");
2940  }
2941  };
2942 
// Primitive descriptor: requires the inner_product_forward primitive
// descriptor as a hint; the second overload also takes attributes.
2944  primitive_desc(const desc &desc, const engine &e,
2945  const inner_product_forward::primitive_desc &hint_fwd_pd)
2946  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2947 
2948  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2949  const inner_product_forward::primitive_desc &hint_fwd_pd)
2950  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2951 
// diff_bias is exposed as entry 1 of the diff_weights query.
2952  REG_QUERY_MPD(src, src, 0);
2953  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2954  REG_QUERY_MPD(diff_bias, diff_weights, 1);
2955  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2956  };
2957 
// Constructor (header line 2958 not shown): inputs {src, diff_dst}, output
// {diff_weights}; checked as 2 inputs / 1 output.
// NOTE(review): `diff_dst` is taken by value, unlike `src`.
2959  const primitive::at &src, const primitive::at diff_dst,
2960  const memory &diff_weights) {
2961  mkldnn_primitive_t result;
2962  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2963  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2964  check_num_parameters(aprimitive_desc.get(), 2, 1,
2965  "inner product backward weights");
2967  aprimitive_desc.get(), inputs, outputs),
2968  "could not create a inner product backward weights primitive");
2969  reset(result);
2970  }
2971 
// Constructor (header line 2972 not shown): inputs {src, diff_dst}, outputs
// {diff_weights, diff_bias}; checked as 2 inputs / 2 outputs.
2973  const primitive::at &src, const primitive::at diff_dst,
2974  const memory &diff_weights, const memory &diff_bias) {
2975  mkldnn_primitive_t result;
2976  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2977  const_mkldnn_primitive_t outputs[] =
2978  { diff_weights.get(), diff_bias.get()};
2979  check_num_parameters(aprimitive_desc.get(), 2, 2,
2980  "inner product backward weights");
2982  aprimitive_desc.get(), inputs, outputs),
2983  "could not create a inner product backward weights primitive");
2984  reset(result);
2985  }
2986 };
2987 
2989 
2995 
/// RNN cell descriptor wrapper around a C mkldnn_rnn_cell_desc_t (c_rnn_cell_).
/// NOTE(review): rendered listing -- lines missing from the numbering (e.g.
/// 2998, 3001, 3006, 3010, 3012-3013, 3017, 3023, 3028, 3031) are not shown,
/// including the data member declaration and several statements/bodies.
2996 struct rnn_cell {
2997  struct desc {
2999 
// Initializes the C cell descriptor from a cell kind and an activation
// function. The trailing arguments 0U, 0, 0 are presumably flags, alpha and
// clipping (confirm against mkldnn_rnn_cell_desc_init in mkldnn.h).
3000  desc(algorithm kind, algorithm activation_f) {
3002  mkldnn::convert_to_c(kind),
3003  mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3004  "could not init an rnn cell descriptor");
3005  }
3007 
// Non-explicit conversion to a pointer to the wrapped C descriptor, so a desc
// can be passed directly where the C API expects const mkldnn_rnn_cell_desc_t*
// (as the RNN forward/backward descriptors do with `cell`).
3008  operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3009 
// Cell-kind accessor (declaring line 3010 not shown): converts the stored C
// cell_kind back to the C++ algorithm enum.
3011  { return algorithm(c_rnn_cell_.cell_kind); }
3014 
// alpha accessors. NOTE(review): line 3017 of set_alpha is not shown, so this
// listing does not reveal everything the setter does.
3015  float get_alpha() const { return c_rnn_cell_.alpha; }
3016  void set_alpha(float alpha) {
3018  c_rnn_cell_.alpha = alpha;
3019  }
3020 
// clipping accessors. Likewise, line 3023 of set_clipping is not shown.
3021  float get_clipping() const { return c_rnn_cell_.clipping; }
3022  void set_clipping(float clipping) {
3024  c_rnn_cell_.clipping = clipping;
3025  }
3026 
// Gate / state counts for the cell; the bodies (lines 3028, 3031) are not
// shown in this listing.
3027  int get_gates_count() const {
3029  }
3030  int get_state_count() const {
3032  }
3033  };
3034 };
3035 
/// RNN forward primitive.
/// NOTE(review): rendered listing -- lines missing from the numbering (3038,
/// 3049, 3061, 3100) are not shown, including the desc data member and the
/// openers of the C calls whose trailing arguments appear below.
3036 struct rnn_forward : public primitive {
3037  struct desc {
// Descriptor from prop kind, cell (passed where the C API takes the cell
// descriptor, presumably via rnn_cell::desc's conversion operator), direction,
// and memory descriptors for layer/iteration sources, layer/iteration weights,
// bias and layer/iteration destinations. Every memory descriptor is passed by
// address; none is replaced by nullptr here.
3039  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3040  const rnn_direction direction,
3041  const memory::desc &src_layer_desc,
3042  const memory::desc &src_iter_desc,
3043  const memory::desc &weights_layer_desc,
3044  const memory::desc &weights_iter_desc,
3045  const memory::desc &bias_desc,
3046  const memory::desc &dst_layer_desc,
3047  const memory::desc &dst_iter_desc
3048  ) {
3050  mkldnn::convert_to_c(aprop_kind), cell,
3051  mkldnn::convert_to_c(direction),
3052  &src_layer_desc.data, &src_iter_desc.data,
3053  &weights_layer_desc.data, &weights_iter_desc.data,
3054  &bias_desc.data,
3055  &dst_layer_desc.data, &dst_iter_desc.data),
3056  "could not create an RNN forward descriptor");
3057  }
3058 
3059  };
3060 
// Primitive descriptor: no forward hint (nullptr); optional attributes.
3062  primitive_desc(const desc &desc, const engine &e)
3063  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3064 
3065  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3066  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3067 
// Query layout: src {layer, iter}, weights {layer, iter, bias},
// dst {layer, iter}, plus workspace.
3068  REG_QUERY_MPD(src_layer, src, 0);
3069  REG_QUERY_MPD(src_iter, src, 1);
3070  REG_QUERY_MPD(weights_layer, weights, 0);
3071  REG_QUERY_MPD(weights_iter, weights, 1);
3072  REG_QUERY_MPD(bias, weights, 2);
3073  REG_QUERY_MPD(dst_layer, dst, 0);
3074  REG_QUERY_MPD(dst_iter, dst, 1);
3075  REG_QUERY_MPD(workspace, workspace, 0);
3076  };
3077 
// Creates the primitive. The C input/output arrays are filled compactly:
// src_iter and bias are left out of `inputs` when they refer to null memory
// (is_null_memory), and dst_iter / workspace likewise out of `outputs`, so
// later entries shift down. At most 5 inputs and 3 outputs are used.
// NOTE(review): the arrays are not initialized and there is no
// check_num_parameters call. Slots past the final `idx` stay indeterminate, so
// correctness relies on the non-null arguments matching what aprimitive_desc
// expects.
3078  rnn_forward(const primitive_desc &aprimitive_desc,
3079  const primitive::at &src_layer, const primitive::at &src_iter,
3080  const primitive::at &weights_layer,
3081  const primitive::at &weights_iter, const primitive::at &bias,
3082  const memory &dst_layer, const memory &dst_iter,
3083  const memory &workspace) {
3084  mkldnn_primitive_t result;
3085  mkldnn_primitive_at_t inputs[5];
3086  const_mkldnn_primitive_t outputs[3];
3087  int idx=0;
3088  inputs[idx++] = src_layer.data;
3089  if (!is_null_memory(src_iter.data.primitive))
3090  inputs[idx++] = src_iter.data;
3091  inputs[idx++] = weights_layer.data;
3092  inputs[idx++] = weights_iter.data;
3093  if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3094 
3095  idx=0;
3096  outputs[idx++] = dst_layer.get();
3097  if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3098  if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3099 
3101  aprimitive_desc.get(), inputs, outputs),
3102  "could not create an RNN forward primitive");
3103  reset(result);
3104  }
3105 };
3106 
/// RNN backward primitive.
/// NOTE(review): rendered listing -- lines missing from the numbering (3109,
/// 3126, 3142, 3212) are not shown, including the desc data member and the
/// openers of the C calls whose trailing arguments appear below.
3107 struct rnn_backward : public primitive {
3108  struct desc {
// Descriptor: the forward-side memory descriptors (as in rnn_forward) followed
// by their diff counterparts (diff_src layer/iter, diff_weights layer/iter,
// diff_bias, diff_dst layer/iter). All are passed by address.
3110  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3111  const rnn_direction direction,
3112  const memory::desc &src_layer_desc,
3113  const memory::desc &src_iter_desc,
3114  const memory::desc &weights_layer_desc,
3115  const memory::desc &weights_iter_desc,
3116  const memory::desc &bias_desc,
3117  const memory::desc &dst_layer_desc,
3118  const memory::desc &dst_iter_desc,
3119  const memory::desc &diff_src_layer_desc,
3120  const memory::desc &diff_src_iter_desc,
3121  const memory::desc &diff_weights_layer_desc,
3122  const memory::desc &diff_weights_iter_desc,
3123  const memory::desc &diff_bias_desc,
3124  const memory::desc &diff_dst_layer_desc,
3125  const memory::desc &diff_dst_iter_desc) {
3127  mkldnn::convert_to_c(aprop_kind), cell,
3128  mkldnn::convert_to_c(direction),
3129  &src_layer_desc.data, &src_iter_desc.data,
3130  &weights_layer_desc.data, &weights_iter_desc.data,
3131  &bias_desc.data,
3132  &dst_layer_desc.data, &dst_iter_desc.data,
3133  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3134  &diff_weights_layer_desc.data,
3135  &diff_weights_iter_desc.data, &diff_bias_desc.data,
3136  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3137  "could not create an RNN backward descriptor");
3138  }
3139 
3140  };
3141 
// Primitive descriptor: requires the rnn_forward primitive descriptor as a
// hint; the second overload also takes attributes.
3143  primitive_desc(const desc &desc, const engine &e,
3144  const rnn_forward::primitive_desc &hint_fwd_pd)
3145  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3146 
3147  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3148  const rnn_forward::primitive_desc &hint_fwd_pd)
3149  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3150 
3151  REG_QUERY_MPD(src_layer, src, 0);
3152  REG_QUERY_MPD(src_iter, src, 1);
3153  REG_QUERY_MPD(weights_layer, weights, 0);
3154  REG_QUERY_MPD(weights_iter, weights, 1);
3155  REG_QUERY_MPD(bias, weights, 2);
3156  REG_QUERY_MPD(dst_layer, dst, 0);
3157  REG_QUERY_MPD(dst_iter, dst, 1);
3158  REG_QUERY_MPD(workspace, workspace, 0);
3159 
3160  REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3161  REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3162  REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3163  REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3164  REG_QUERY_MPD(diff_bias, diff_weights, 2);
3165  REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3166  REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3167  };
3168 
3169  // With last iteration (with and without input src_iter)
// Inputs are filled compactly in the order src_layer, [src_iter],
// weights_layer, weights_iter, [bias], dst_layer, [dst_iter], diff_dst_layer,
// [diff_dst_iter], workspace (up to 10). Bracketed entries are skipped when
// they refer to null memory; workspace is always appended. Outputs are
// diff_src_layer, [diff_src_iter], diff_weights_layer, diff_weights_iter,
// [diff_bias] (up to 5).
// NOTE(review): as in rnn_forward, the arrays are uninitialized and no
// check_num_parameters call is made; correctness relies on the non-null
// arguments matching what aprimitive_desc expects.
3170  rnn_backward(const primitive_desc &aprimitive_desc,
3171  const primitive::at &src_layer,
3172  const primitive::at &src_iter,
3173  const primitive::at &weights_layer,
3174  const primitive::at &weights_iter,
3175  const primitive::at &bias,
3176  const primitive::at &dst_layer,
3177  const primitive::at &dst_iter,
3178  const memory &diff_src_layer,
3179  const memory &diff_src_iter,
3180  const memory &diff_weights_layer,
3181  const memory &diff_weights_iter,
3182  const memory &diff_bias,
3183  const primitive::at &diff_dst_layer,
3184  const primitive::at &diff_dst_iter,
3185  const primitive::at &workspace) {
3186  mkldnn_primitive_t result;
3187  mkldnn_primitive_at_t inputs[10];
3188  const_mkldnn_primitive_t outputs[5];
3189  int idx=0;
3190  inputs[idx++] = src_layer.data;
3191  if (!is_null_memory(src_iter.data.primitive))
3192  inputs[idx++] = src_iter.data;
3193  inputs[idx++] = weights_layer.data;
3194  inputs[idx++] = weights_iter.data;
3195  if (!is_null_memory(bias.data.primitive))
3196  inputs[idx++] = bias.data;
3197  inputs[idx++] = dst_layer.data;
3198  if (!is_null_memory(dst_iter.data.primitive))
3199  inputs[idx++] = dst_iter.data;
3200  inputs[idx++] = diff_dst_layer.data;
3201  if (!is_null_memory(diff_dst_iter.data.primitive))
3202  inputs[idx++] = diff_dst_iter.data;
3203  inputs[idx++] = workspace.data;
3204 
3205  idx = 0;
3206  outputs[idx++] = diff_src_layer.get();
3207  if (!is_null_memory(diff_src_iter.get()))
3208  outputs[idx++] = diff_src_iter.get();
3209  outputs[idx++] = diff_weights_layer.get();
3210  outputs[idx++] = diff_weights_iter.get();
3211  if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3213  aprimitive_desc.get(), inputs, outputs),
3214  "could not create an RNN backward primitive");
3215  reset(result);
3216  }
3217 };
3218 
3220 
3226 
/// Shuffle forward primitive.
/// NOTE(review): rendered listing -- lines 3229, 3232, 3239 and 3253 are
/// missing from the numbering and not shown.
3227 struct shuffle_forward : public primitive {
3228  struct desc {
// Descriptor from a prop kind, the data memory descriptor, an axis and a group
// size.
3230  desc(prop_kind aprop_kind, const memory::desc &data_desc,
3231  int axis, int group_size) {
3233  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3234  axis, group_size),
3235  "could not create a shuffle forward descriptor");
3236  }
3237  };
3238 
// Primitive descriptor: no hint, and (unlike most primitives in this file) no
// overload taking primitive attributes.
3240  primitive_desc(const desc &desc, const engine &e)
3241  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3242 
3243  REG_QUERY_MPD(src, src, 0);
3244  REG_QUERY_MPD(dst, dst, 0);
3245  };
3246 
// Creates the primitive: input {src}, output {dst}; checked as 1 in / 1 out.
3247  shuffle_forward(const primitive_desc &aprimitive_desc,
3248  const primitive::at &src, const memory &dst) {
3249  mkldnn_primitive_t result;
3250  mkldnn_primitive_at_t inputs[] = { src.data };
3251  const_mkldnn_primitive_t outputs[] = { dst.get() };
3252  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3254  aprimitive_desc.get(), inputs, outputs),
3255  "could not create a shuffle forward primitive");
3256  reset(result);
3257  }
3258 };
3259 
/// Shuffle backward primitive.
/// NOTE(review): rendered listing -- lines 3262, 3264, 3270 and 3285 are
/// missing from the numbering and not shown.
3260 struct shuffle_backward : public primitive {
3261  struct desc {
// Descriptor from the diff data memory descriptor, an axis and a group size
// (no prop_kind argument).
3263  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3265  &diff_data_desc.data, axis, group_size),
3266  "could not create a shuffle backward descriptor");
3267  }
3268  };
3269 
// Primitive descriptor: requires the shuffle_forward primitive descriptor as a
// hint; no attribute overload.
3271  primitive_desc(const desc &desc, const engine &e,
3272  const shuffle_forward::primitive_desc &hint_fwd_pd)
3273  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3274 
3275  REG_QUERY_MPD(diff_src, diff_src, 0);
3276  REG_QUERY_MPD(diff_dst, diff_dst, 0);
3277  };
3278 
// Creates the primitive: input {diff_dst}, output {diff_src}; checked as
// 1 in / 1 out.
3279  shuffle_backward(const primitive_desc &aprimitive_desc,
3280  const primitive::at &diff_dst, const memory &diff_src) {
3281  mkldnn_primitive_t result;
3282  mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3283  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3284  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3286  aprimitive_desc.get(), inputs, outputs),
3287  "could not create a shuffle backward primitive");
3288  reset(result);
3289  }
3290 };
3291 
3293 
3295 
3301 
// Destructor policy for stream handles: handle<mkldnn_stream_t> obtains the
// function that releases the underlying C stream from traits::destructor
// (read in handle::reset), which here is mkldnn_stream_destroy.
3302 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3303 template <> struct handle_traits<mkldnn_stream_t> {
3304  static constexpr auto destructor = &mkldnn_stream_destroy;
3305 };
3306 #endif
3307 
/// Execution stream: owns a C mkldnn_stream_t through handle<>.
/// NOTE(review): rendered listing -- lines missing from the numbering (e.g.
/// 3311-3313, 3315, 3321, 3342, 3369, 3371) are not shown, including the
/// `kind` enumeration, the convert_to_c declaration and the `rerun()`
/// declaration (listed in the index as mkldnn.hpp:3369).
3308 struct stream: public handle<mkldnn_stream_t> {
// Inherit handle's constructors (wrap an existing C stream, optionally weak).
3309  using handle::handle;
3310 
3314 
// Body of the kind -> mkldnn_stream_kind_t conversion (declaring line 3315 not
// shown). It is a plain static_cast, so kind's enumerators presumably mirror
// the C enum values.
3316  return static_cast<mkldnn_stream_kind_t>(akind);
3317  }
// Creates a C stream of the given kind (the creating call, line 3321, is not
// shown) and takes ownership of it via reset().
3319  stream(kind akind) {
3320  mkldnn_stream_t astream;
3322  convert_to_c(akind)),
3323  "could not create a stream");
3324  reset(astream);
3325  }
3326 
// Submits `primitives` to the stream and returns *this for chaining.
// The early return on an empty vector also keeps &c_api_primitives[0] from
// being evaluated on an empty vector. Each primitive is reduced to its raw C
// handle via get(). On failure, c_api_error_primitive is passed along with the
// error message.
// NOTE(review): both the vector parameter and the lambda parameter are taken
// by value, copying every primitive (each presumably a shared_ptr-backed
// handle). `const std::vector<primitive> &` and `const primitive &` would
// avoid the copies. The local lambda also shadows any outer convert_to_c
// within this function.
3331  stream &submit(std::vector<primitive> primitives) {
3332  // TODO: find a proper way to convert vector<primitive> to
3333  // vector<mkldnn_primitive_t>
3334  if (primitives.size() == 0) return *this;
3335  std::vector<mkldnn_primitive_t> c_api_primitives;
3336  c_api_primitives.reserve(primitives.size());
3337  auto convert_to_c = [](primitive p) { return p.get(); };
3338  std::transform(primitives.begin(), primitives.end(),
3339  std::back_inserter(c_api_primitives), convert_to_c);
3340 
3341  mkldnn_primitive_t c_api_error_primitive;
3343  mkldnn_stream_submit(get(),
3344  c_api_primitives.size(), &c_api_primitives[0],
3345  &c_api_error_primitive),
3346  "could not submit primitives to a stream",
3347  &c_api_error_primitive);
3348 
3349  return *this;
3350  }
3351 
// Waits on the stream (blocking unless block == false). Returns true on
// mkldnn_success and false on mkldnn_try_again. Any other status is reported
// through error::wrap_c_api together with the failing primitive.
3358  bool wait(bool block = true) {
3359  mkldnn_primitive_t c_api_error_primitive;
3360  mkldnn_status_t status = mkldnn_stream_wait(get(),
3361  block, &c_api_error_primitive);
3362  if (status != mkldnn_success
3363  && status != mkldnn_try_again)
3364  error::wrap_c_api(status, "could not wait on a stream",
3365  &c_api_error_primitive);
3366  return (status == mkldnn_success);
3367  }
3368 
// rerun() (declaring line 3369 not shown): calls mkldnn_stream_rerun on this
// stream, reports failures with the failing primitive, and returns *this.
3370  mkldnn_primitive_t c_api_error_primitive;
3372  mkldnn_stream_rerun(get(), &c_api_error_primitive),
3373  "could not rerun a stream", &c_api_error_primitive);
3374  return *this;
3375  }
3376 };
3377 
3378 #undef REG_QUERY_MPD
3379 
3381 
3383 
3384 } // namespace mkldnn
3385 
3386 #endif
void append_sum(float scale=1.)
Definition: mkldnn.hpp:387
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2423
Definition: mkldnn.hpp:2374
LRN within a single channel.
Definition: mkldnn_types.h:569
primitive error_primitive
Definition: mkldnn.hpp:164
A descriptor of a Local Response Normalization (LRN) operation.
Definition: mkldnn_types.h:908
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1522
Definition: mkldnn.hpp:743
Definition: mkldnn.hpp:344
blocked weights format
Definition: mkldnn_types.h:348
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst)
Definition: mkldnn.hpp:2861
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2215
Definition: mkldnn.hpp:269
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1104
blocked weights format
Definition: mkldnn_types.h:355
op descriptor
Definition: mkldnn_types.h:1250
primitive_desc(const memory::desc &output, int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1114
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1665
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:393
Definition: mkldnn.hpp:3107
blocked weights format
Definition: mkldnn_types.h:332
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(mkldnn_primitive_attr_t attr)
Deletes an attr.
Definition: mkldnn.hpp:711
Definition: mkldnn.hpp:654
blocked weights format
Definition: mkldnn_types.h:433
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(mkldnn_primitive_desc_t *sum_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, const float *scales, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
Definition: mkldnn.hpp:257
Definition: mkldnn.hpp:652
A Softmax primitive.
Definition: mkldnn_types.h:509
number of outputs expected
Definition: mkldnn_types.h:1239
bool operator!=(const handle &other) const
Definition: mkldnn.hpp:88
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream)
Destroys an execution stream.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:3065
blocked weights format
Definition: mkldnn_types.h:438
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1675
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2535
stream & submit(std::vector< primitive > primitives)
Submits a vector of primitives to a stream for computations.
Definition: mkldnn.hpp:3331
bool operator==(const primitive_desc &other) const
Definition: mkldnn.hpp:822
A base class for all primitive descriptors.
Definition: mkldnn.hpp:1271
Definition: mkldnn.hpp:601
Definition: mkldnn.hpp:2248
mkldnn_status_t
Status values returned by Intel(R) MKL-DNN functions.
Definition: mkldnn_types.h:49
stream & rerun()
Definition: mkldnn.hpp:3369
Definition: mkldnn.hpp:2211
A descriptor of a convolution operation.
Definition: mkldnn_types.h:760
Definition: mkldnn.hpp:302
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3230
Definition: mkldnn.hpp:2186
The operation failed and should be retried.
Definition: mkldnn_types.h:55
memory null_memory(engine eng)
Definition: mkldnn.hpp:918
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(mkldnn_primitive_desc_t *memory_primitive_desc, const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine)
Creates a memory_primitive_desc memory primitive descriptor using memory_desc and engine...
Definition: mkldnn.hpp:688
blocked weights format
Definition: mkldnn_types.h:292
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
Definition: mkldnn.hpp:658
Definition: mkldnn.hpp:331
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(mkldnn_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1615
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(mkldnn_primitive_desc_t *concat_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: mkldnn_types.h:266
4D data tensor with the physical layout chwn, used in Neon.
Definition: mkldnn_types.h:175
Definition: mkldnn.hpp:265
padding_kind
Definition: mkldnn.hpp:232
The operation failed because of incorrect function arguments.
Definition: mkldnn_types.h:57
Definition: mkldnn.hpp:703
Eltwise: exponent.
Definition: mkldnn_types.h:556
Forward data propagation (alias for mkldnn_forward_inference)
Definition: mkldnn_types.h:470
Definition: mkldnn.hpp:2049
Definition: mkldnn.hpp:675
An opaque structure to describe an engine.
Definition: mkldnn.hpp:750
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1577
Backward data propagation.
Definition: mkldnn_types.h:476
Definition: mkldnn.hpp:2447
Definition: mkldnn.hpp:712
static void validate_dims(std::vector< T > v)
Definition: mkldnn.hpp:588
Definition: mkldnn.hpp:3270
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr(const_mkldnn_primitive_desc_t primitive_desc, const_mkldnn_primitive_attr_t *attr)
Returns a constant reference to the attribute of a primitive_desc.
Definition: mkldnn.hpp:3260
Definition: mkldnn.hpp:644
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t data_type, mkldnn_memory_format_t format)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and data format...
Definition: mkldnn.hpp:721
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2413
Definition: mkldnn.hpp:276
blocked weights format
Definition: mkldnn_types.h:326
blocked weights format
Definition: mkldnn_types.h:404
Definition: mkldnn.hpp:735
Undefined memory format, used for empty memory descriptors.
Definition: mkldnn_types.h:149
Definition: mkldnn.hpp:683
const_mkldnn_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: mkldnn.hpp:210
Definition: mkldnn.hpp:687
concat(const primitive_desc &concat_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1155
memory::desc desc()
Returns the memory primitive descriptor.
Definition: mkldnn.hpp:812
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2010
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
Definition: mkldnn.hpp:648
float alpha
alpha is a negative slope parameter (used only if (flags & mkldnn_rnn_cell_with_relu) != 0) ...
Definition: mkldnn_types.h:1012
Definition: mkldnn.hpp:610
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(mkldnn_primitive_attr_t *attr, const_mkldnn_primitive_attr_t existing_attr)
Makes a copy of an existing_attr.
#define TENSOR_MAX_DIMS
Maximum number of dimensions a tensor can have.
Definition: mkldnn_types.h:634
format
Memory format specification. See mkldnn_memory_format_t for a detailed description.
Definition: mkldnn.hpp:608
Definition: mkldnn.hpp:292
4D weights tensor with physical layout oihw, used in Caffe.
Definition: mkldnn_types.h:199
A descriptor of a Softmax operation.
Definition: mkldnn_types.h:858
blocked weights format
Definition: mkldnn_types.h:439
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
Definition: mkldnn.hpp:662
softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2433
blocked weights format
Definition: mkldnn_types.h:440
blocked weights format
Definition: mkldnn_types.h:403
Definition: mkldnn.hpp:272
blocked data format
Definition: mkldnn_types.h:275
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(const_mkldnn_primitive_t memory, void **handle)
For a memory primitive, returns the data handle.
Definition: mkldnn.hpp:244
Definition: mkldnn.hpp:663
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
A descriptor of an inner product operation.
Definition: mkldnn_types.h:966
Definition: mkldnn.hpp:741
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops)
Deletes a post_ops sequence.
std::vector< std::remove_extent< mkldnn_dims_t >::type > dims
Definition: mkldnn.hpp:586
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition: mkldnn_types.h:242
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3240
An opaque structure for a chain of post operations.
An opaque structure to describe a primitive descriptor.
batch normalization descriptor
Definition: mkldnn_types.h:1259
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1734
mkldnn_rnn_direction_t
A direction of RNN primitive execution.
Definition: mkldnn_types.h:1019
Definition: mkldnn.hpp:659
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: mkldnn.hpp:79
A convolution primitive.
Definition: mkldnn_types.h:503
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1882
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2112
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(mkldnn_primitive_t memory, void *handle)
For a memory primitive, sets the data handle.
Definition: mkldnn.hpp:640
engine(const mkldnn_engine_t &aengine)
Definition: mkldnn.hpp:540
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:301
engine(const handle< mkldnn_primitive_desc_t > &pd)
Definition: mkldnn.hpp:543
Definition: mkldnn.hpp:751
engine get_engine()
Definition: mkldnn.hpp:1284
desc(dims adims, data_type adata_type, format aformat)
Constructs a memory descriptor.
Definition: mkldnn.hpp:778
blocked data format
Definition: mkldnn_types.h:276
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind (possi...
Definition: mkldnn.hpp:225
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2810
sum(const primitive_desc &sum_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1244
An execution engine.
Definition: mkldnn.hpp:505
memory(const primitive_desc &adesc, void *ahandle)
Definition: mkldnn.hpp:868
blocked weights format
Definition: mkldnn_types.h:429
Definition: mkldnn.hpp:764
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2878
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters kind, alpha, and beta (...
Definition: mkldnn.hpp:618
static void wrap_c_api(mkldnn_status_t status, const std::string &message, mkldnn_primitive_t *error_primitive=0)
A convenience function for wrapping calls to the C API. Checks the return status and throws an error ...
Definition: mkldnn.hpp:188
Definition: mkldnn.hpp:726
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2250
blocked weights format
Definition: mkldnn_types.h:339
Undefined primitive (XXX: why do we have it?).
Definition: mkldnn_types.h:487
Definition: mkldnn.hpp:698
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
An inner product primitive.
Definition: mkldnn_types.h:517
Packed weights format used in RNN.
Definition: mkldnn_types.h:444
void check_num_parameters(const const_mkldnn_primitive_desc_t &aprimitive_desc, int n_inputs, int n_outputs, const std::string &prim_name)
Definition: mkldnn.hpp:923
Definition: mkldnn.hpp:762
Round down.
Definition: mkldnn_types.h:94
4D grouped weights tensor with the physical layout goiw.
Definition: mkldnn_types.h:223
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2462
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1752
Definition: mkldnn.hpp:705
Definition: mkldnn.hpp:264
Definition: mkldnn.hpp:709
round_mode get_int_output_round_mode() const
Definition: mkldnn.hpp:428
blocked weights format
Definition: mkldnn_types.h:435
Definition: mkldnn.hpp:678
blocked weights format
Definition: mkldnn_types.h:294
primitive_attr()
Definition: mkldnn.hpp:421
Definition: mkldnn_types.h:565
Definition: mkldnn.hpp:2359
Definition: mkldnn.hpp:636
An unspecified engine.
Definition: mkldnn.hpp:512
Definition: mkldnn.hpp:722
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams(mkldnn_primitive_attr_t attr, int count, int mask, const float *weights_scales)
Sets quantization scales weights_scales for RNN weights tensors.
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(const_mkldnn_primitive_t primitive, size_t output_index)
Creates an mkldnn_primitive_at_t structure from a primitive and output_index.
Definition: mkldnn.hpp:598
primitive_desc(const desc &desc, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2458
Definition: mkldnn.hpp:685
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2448
Definition: mkldnn.hpp:2422
void get_params_sum(int index, float &scale) const
Definition: mkldnn.hpp:392
Definition: mkldnn.hpp:247
Definition: mkldnn.hpp:723
32-bit signed integer.
Definition: mkldnn_types.h:78
Definition: mkldnn.hpp:715
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2891
Max pooling.
Definition: mkldnn_types.h:560
Definition: mkldnn.hpp:737
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1436
memory::desc zero_md()
Definition: mkldnn.hpp:912
Definition: mkldnn.hpp:338
primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1047
mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are mkldnn_forwar...
blocked weights format
Definition: mkldnn_types.h:314
blocked weights format
Definition: mkldnn_types.h:338
Definition: mkldnn.hpp:646
const post_ops get_post_ops() const
Definition: mkldnn.hpp:462
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2188
Definition: mkldnn.hpp:631
execution engine
Definition: mkldnn_types.h:1235
stream(kind akind)
Constructs a stream.
Definition: mkldnn.hpp:3319
Definition: mkldnn.hpp:1046
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(mkldnn_primitive_desc_iterator_t iterator)
Iterates over primitive descriptors.
Definition: mkldnn.hpp:337
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2879
mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in the spatial domain: strides, kernel sizes, padding_l, padding_r, and padding_kind.
Definition: mkldnn.hpp:2185
blocked weights format
Definition: mkldnn_types.h:322
static mkldnn_memory_format_t convert_to_c(format aformat)
Definition: mkldnn.hpp:907
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2379
Definition: mkldnn.hpp:322
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(mkldnn_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
Definition: mkldnn_types.h:997
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream, size_t n, mkldnn_primitive_t primitives[], mkldnn_primitive_t *error_primitive)
Submits primitives to an execution stream.
algorithm
Definition: mkldnn.hpp:255
input memory primitive desc
Definition: mkldnn_types.h:1265
blocked weights format
Definition: mkldnn_types.h:341
Definition: mkldnn.hpp:757
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3229
5D grouped weights tensor with the physical layout goihw, used in Caffe.
Definition: mkldnn_types.h:227
const_mkldnn_primitive_t primitive
Primitive to specify the output for.
Definition: mkldnn_types.h:1195
Definition: mkldnn.hpp:291
blocked weights format
Definition: mkldnn_types.h:354
rnn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const memory &dst_layer, const memory &dst_iter, const memory &workspace)
Definition: mkldnn.hpp:3078
mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(mkldnn_rnn_cell_desc_t *rnn_cell_desc, mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, unsigned int flags, float alpha, float clipping)
Initializes a recurrent cell descriptor rnn_cell_desc using rnn_cell_desc, kind (possible values are ...
A descriptor of a element-wise operation.
Definition: mkldnn_types.h:822
Definition: mkldnn.hpp:708
rnn descriptor
Definition: mkldnn_types.h:1261
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2521
An element-wise primitive.
Definition: mkldnn_types.h:507
Definition: mkldnn.hpp:2446
blocked weights format
Definition: mkldnn_types.h:331
destination grad.
Definition: mkldnn_types.h:1272
algorithm get_cell_kind() const
Definition: mkldnn.hpp:3010
Definition: mkldnn.hpp:758
engine get_engine()
Definition: mkldnn.hpp:1241
Definition: mkldnn.hpp:2360
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream, int block, mkldnn_primitive_t *error_primitive)
Waits for all primitives in the execution stream to finish.
mkldnn_alg_kind_t activation_kind
Activation function used.
Definition: mkldnn_types.h:1007
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1228
blocked weights format
Definition: mkldnn_types.h:344
A descriptor for an RNN operation.
Definition: mkldnn_types.h:1034
Definition: mkldnn.hpp:623
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1413
Definition: mkldnn.hpp:1102
Definition: mkldnn.hpp:279
Definition: mkldnn.hpp:259
eltwise descriptor
Definition: mkldnn_types.h:1255
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2629
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1461
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams(mkldnn_primitive_attr_t attr, const float scale, const float shift)
Sets quantization scale and shift for RNN data tensors.
Definition: mkldnn.hpp:278
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src)
Definition: mkldnn.hpp:2766
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2097
size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind)
Returns the number of engines of a particular kind.
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2922
batch_normalization_flag
Definition: mkldnn.hpp:290
A memory primitive.
Definition: mkldnn_types.h:489
float clipping
clipping parameter (used only if (flags & mkldnn_rnn_cell_with_clipping) != 0)
Definition: mkldnn_types.h:1015
Definition: mkldnn.hpp:704
blocked weights format
Definition: mkldnn_types.h:311
blocked weights format
Definition: mkldnn_types.h:325
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc)
Definition: mkldnn.hpp:3110
Eltwise: soft_relu.
Definition: mkldnn_types.h:552
Definition: mkldnn.hpp:679
void set_post_ops(post_ops ops)
Definition: mkldnn.hpp:471
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:2846
Definition: mkldnn.hpp:343
Definition: mkldnn.hpp:718
Definition: mkldnn.hpp:649
Definition: mkldnn.hpp:261
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(const_mkldnn_post_ops_t post_ops, int index)
Returns the type of post operation with index index in given post_ops.
Definition: mkldnn.hpp:602
RNN cell.
Definition: mkldnn_types.h:571
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2212
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1773
bool is_null_memory(const const_mkldnn_primitive_t &aprimitive)
Definition: mkldnn.hpp:943
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2895
Definition: mkldnn.hpp:369
Eltwise: gelu.
Definition: mkldnn_types.h:558
blocked weights format
Definition: mkldnn_types.h:362
bool operator==(const handle &other) const
Definition: mkldnn.hpp:87
Definition: mkldnn.hpp:1373
Backward weights propagation.
Definition: mkldnn_types.h:478
void set_int_output_round_mode(round_mode mode)
Definition: mkldnn.hpp:435
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3038
blocked weights format
Definition: mkldnn_types.h:432
Definition: mkldnn.hpp:673
32-bit/single-precision floating point.
Definition: mkldnn_types.h:76
Definition: mkldnn.hpp:756
blocked weights format
Definition: mkldnn_types.h:288
blocked data format
Definition: mkldnn_types.h:273
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1597
algorithm get_activation() const
Definition: mkldnn.hpp:3012
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2223
2D weights tensor with physical layout oi.
Definition: mkldnn_types.h:184
Just a sentinel, not real memory format.
Definition: mkldnn_types.h:448
Memory descriptor.
Definition: mkldnn_types.h:719
Definition: mkldnn.hpp:732
mkldnn_query_t convert_to_c(query aquery)
Definition: mkldnn.hpp:351
Definition: mkldnn.hpp:2809
Definition: mkldnn.hpp:305
blocked weights format
Definition: mkldnn_types.h:350
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Base class for all computational primitives.
Definition: mkldnn.hpp:106
shuffle_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:3247
mkldnn_batch_normalization_flag_t
Flags for the batch-normalization primitive.
Definition: mkldnn_types.h:588
void set_clipping(float clipping)
Definition: mkldnn.hpp:3022
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1689
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2050
Definition: mkldnn.hpp:2808
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2497
Definition: mkldnn.hpp:282
Definition: mkldnn.hpp:611
pooling descriptor
Definition: mkldnn_types.h:1257
Definition: mkldnn.hpp:2249
const mkldnn_memory_desc_t MKLDNN_API * mkldnn_primitive_desc_query_memory_d(const_mkldnn_primitive_desc_t primitive_desc)
Queries primitive descriptor for memory descriptor.
prop_kind
Definition: mkldnn.hpp:240
Definition: mkldnn.hpp:728
Definition: mkldnn.hpp:634
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2187
Definition: mkldnn.hpp:267
blocked weights format
Definition: mkldnn_types.h:287
blocked data format
Definition: mkldnn_types.h:277
3D weights tensor with physical layout wio.
Definition: mkldnn_types.h:196
Definition: mkldnn.hpp:710
blocked weights format
Definition: mkldnn_types.h:414
blocked weights format
Definition: mkldnn_types.h:361
Definition: mkldnn.hpp:633
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor deconv_desc for forward propagation using prop_kind (p...
Definition: mkldnn.hpp:738
unsigned int flags
RNN cell flags.
Definition: mkldnn_types.h:1009
Definition: mkldnn.hpp:651
3D data tensor with the physical layout ncw.
Definition: mkldnn_types.h:163
blocked weights format
Definition: mkldnn_types.h:329
convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1559
The operation was successful.
Definition: mkldnn_types.h:51
Definition: mkldnn.hpp:273
Definition: mkldnn.hpp:635
blocked weights format with additional buffer with size equal to the number of groups and containing ...
Definition: mkldnn_types.h:424
mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, mkldnn_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
blocked weights format
Definition: mkldnn_types.h:386
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2948
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1661
desc(algorithm kind, algorithm activation_f)
Definition: mkldnn.hpp:3000
blocked weights format
Definition: mkldnn_types.h:400
Definition: mkldnn.hpp:328
Definition: mkldnn.hpp:245
primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd)
Definition: mkldnn.hpp:1272
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode)
Returns integer output rounding mode round_mode for a given attr, previously set by mkldnn_primitive_...
blocked weights format
Definition: mkldnn_types.h:430
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3109
Definition: mkldnn.hpp:657
Backward propagation (with respect to all parameters).
Definition: mkldnn_types.h:474
5D data tensor with the physical layout ndhwc, used in TensorFlow.
Definition: mkldnn_types.h:181
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2972
softmax descriptor
Definition: mkldnn_types.h:1256
mkldnn_round_mode_t
Rounding mode.
Definition: mkldnn_types.h:90
A deconvolution primitive.
Definition: mkldnn_types.h:505
Definition: mkldnn.hpp:332
Definition: mkldnn.hpp:277
primitive_desc(const desc &adesc, const engine &aengine)
Constructs a memory primitive descriptor.
Definition: mkldnn.hpp:802
Use global statistics.
Definition: mkldnn_types.h:601
Definition: mkldnn.hpp:31
primitive_desc(int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1127
Definition: mkldnn.hpp:672
blocked weights format
Definition: mkldnn_types.h:330
Definition: mkldnn.hpp:639
no query
Definition: mkldnn_types.h:1233
Definition: mkldnn.hpp:1713
blocked weights format
Definition: mkldnn_types.h:416
blocked weights format
Definition: mkldnn_types.h:346
blocked weights format
Definition: mkldnn_types.h:365
mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(mkldnn_primitive_desc_t *view_primitive_desc, const_mkldnn_primitive_desc_t memory_primitive_desc, const mkldnn_dims_t dims, const mkldnn_dims_t offsets)
Creates a view_primitive_desc for a given memory_primitive_desc, with dims sizes and offsets offsets...
8-bit unsigned integer.
Definition: mkldnn_types.h:84
Definition: mkldnn.hpp:668
blocked weights format
Definition: mkldnn_types.h:428
Definition: mkldnn.hpp:348
Average pooling include padding.
Definition: mkldnn_types.h:562
Unspecified format.
Definition: mkldnn_types.h:152
inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src)
Definition: mkldnn.hpp:2904
Definition: mkldnn.hpp:2071
destination memory primitive desc
Definition: mkldnn_types.h:1271
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2519
Definition: mkldnn.hpp:680
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates, output_channels).
Definition: mkldnn_types.h:252
GRU cell with linear before reset.
Definition: mkldnn_types.h:584
memory(const primitive_desc &adesc)
Constructs a memory primitive.
Definition: mkldnn.hpp:841
Definition: mkldnn.hpp:653
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2149
mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int axis, int group_size)
Initializes a shuffle_desc for forward propagation using prop_kind, memory descriptor data_desc...
Local response normalization (LRN) across multiple channels.
Definition: mkldnn_types.h:567
Definition: mkldnn.hpp:706
blocked weights format
Definition: mkldnn_types.h:310
GRU cell.
Definition: mkldnn_types.h:575
Eager stream.
Definition: mkldnn_types.h:1286
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr)
Definition: mkldnn.hpp:997
void set_output_scales(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:455
at(const primitive &aprimitive, size_t at=0)
Constructs a wrapper specifying aprimitive output with index at.
Definition: mkldnn.hpp:143
implementation name
Definition: mkldnn_types.h:1246
CPU engine.
Definition: mkldnn.hpp:514
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1951
Definition: mkldnn.hpp:1374
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3263
Definition: mkldnn.hpp:3261
Definition: mkldnn.hpp:256
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2287
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(const_mkldnn_primitive_attr_t attr, int *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and a pointer to a constant floating point array of output ...
3D weights tensor with physical layout oiw.
Definition: mkldnn_types.h:190
Eltwise: parametric exponential linear unit (elu)
Definition: mkldnn_types.h:540
kind
Kinds of engines.
Definition: mkldnn.hpp:510
Definition: mkldnn.hpp:2111
Definition: mkldnn.hpp:2876
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2426
Definition: mkldnn.hpp:701
Intel(R) MKL-DNN exception class.
Definition: mkldnn.hpp:161
round_mode
Definition: mkldnn.hpp:223
Definition: mkldnn.hpp:760
bool operator==(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:952
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1840
Eltwise: ReLU.
Definition: mkldnn_types.h:536
Definition: mkldnn.hpp:2410
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1375
Definition: mkldnn.hpp:233
1D data tensor.
Definition: mkldnn_types.h:158
REG_QUERY_MPD(diff_src, diff_src, 0)
mkldnn_primitive_at_t data
The underlying C API structure.
Definition: mkldnn.hpp:136
memory::primitive_desc query_mpd(query what, int idx=0) const
Queries and returns requested memory primitive descriptor.
Definition: mkldnn.hpp:1325
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2708
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3147
Definition: mkldnn.hpp:714
primitive_desc(const desc &desc, const engine &e, const shuffle_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3271
4D weights tensor with physical layout ihwo.
Definition: mkldnn_types.h:208
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2361
mkldnn_memory_format_t
Memory format specification.
Definition: mkldnn_types.h:147
Definition: mkldnn.hpp:1045
Eltwise: square.
Definition: mkldnn_types.h:542
blocked weights format
Definition: mkldnn_types.h:323
Definition: mkldnn.hpp:612
Definition: mkldnn.hpp:1179
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1395
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1057
Definition: mkldnn.hpp:283
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are mkldnn_forwa...
int MKLDNN_API mkldnn_memory_primitive_desc_equal(const_mkldnn_primitive_desc_t lhs, const_mkldnn_primitive_desc_t rhs)
Compares two descriptors of memory primitives.
void set_rnn_data_qparams(const float scale, const float shift)
Definition: mkldnn.hpp:476
static mkldnn_data_type_t convert_to_c(data_type adata_type)
Definition: mkldnn.hpp:904
4D data tensor with the physical layout nhwc, used in TensorFlow.
Definition: mkldnn_types.h:172
void set_data_handle(void *handle) const
Definition: mkldnn.hpp:898
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2604
Definition: mkldnn.hpp:268
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2113
Definition: mkldnn.hpp:619
Backward bias propagation.
Definition: mkldnn_types.h:480
Definition: mkldnn.hpp:986
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2060
blocked weights format
Definition: mkldnn_types.h:425
Use scale and shift parameters.
Definition: mkldnn_types.h:614
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1715
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor deconv_desc for forward propagation using prop_kind (possible ...
query
Definition: mkldnn.hpp:313
Definition: mkldnn.hpp:281
weights format with additional buffer size equal to the number of output channels multiplied by numbe...
Definition: mkldnn_types.h:384
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index, void *result)
Queries primitive descriptor.
Definition: mkldnn.hpp:665
float get_alpha() const
Definition: mkldnn.hpp:3015
blocked weights format
Definition: mkldnn_types.h:309
blocked weights format
Definition: mkldnn_types.h:402
A descriptor of a shuffle operation.
Definition: mkldnn_types.h:805
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Definition: mkldnn.hpp:404
Definition: mkldnn_types.h:1029
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2323
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1996
Definition: mkldnn.hpp:420
blocked weights format
Definition: mkldnn_types.h:419
blocked weights format
Definition: mkldnn_types.h:357
Definition: mkldnn.hpp:696
int get_gates_count() const
Definition: mkldnn.hpp:3027
int ndims
Number of dimensions.
Definition: mkldnn_types.h:724
reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:1010
Definition: mkldnn.hpp:2048
Definition: mkldnn.hpp:1103
kind
A proxy to C primitive kind enum.
Definition: mkldnn.hpp:113
blocked weights format with additional buffer with size equal to the number of groups and containing ...
Definition: mkldnn_types.h:377
5D grouped weights tensor with the physical layout giohw.
Definition: mkldnn_types.h:234
An opaque structure to describe an execution stream.
void set_alpha(float alpha)
Definition: mkldnn.hpp:3016
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm memory descriptors diff...
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2123
Definition: mkldnn.hpp:660
5D data tensor with the physical layout ncdhw.
Definition: mkldnn_types.h:178
Definition: mkldnn.hpp:3228
Definition: mkldnn.hpp:624
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(mkldnn_primitive_desc_iterator_t iterator)
Deletes a primitive descriptor iterator.
5D RNN states tensor in the format (num_layers, num_directions, num_states, batch, state channels).
Definition: mkldnn_types.h:245
Definition: mkldnn.hpp:2135
Definition: mkldnn.hpp:754
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: mkldnn.hpp:818
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(mkldnn_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Definition: mkldnn.hpp:1574
Definition: mkldnn.hpp:671
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1808
An RNN primitive.
Definition: mkldnn_types.h:519
Definition: mkldnn.hpp:693
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(const_mkldnn_primitive_t primitive, size_t index, const_mkldnn_primitive_t *output)
For a primitive, returns output at the index position.
blocked weights format
Definition: mkldnn_types.h:340
blocked weights format
Definition: mkldnn_types.h:283
mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, const mkldnn_memory_desc_t *diff_data_desc, int axis, int group_size)
Initializes a shuffle_desc for backward propagation using memory descriptor diff_data_desc, axis, and group_size.
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1912
Definition: mkldnn.hpp:626
Definition: mkldnn.hpp:2997
eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2388
mkldnn_prop_kind_t
Kinds of propagation.
Definition: mkldnn_types.h:458
Definition: mkldnn.hpp:691
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn.hpp:134
CPU engine.
Definition: mkldnn_types.h:1085
Definition: mkldnn.hpp:293
desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2364
Eltwise: square root.
Definition: mkldnn_types.h:546
Definition: mkldnn.hpp:744
Definition: mkldnn.hpp:752
blocked weights format
Definition: mkldnn_types.h:434
Definition: mkldnn.hpp:700
blocked weights format
Definition: mkldnn_types.h:290
mkldnn_stream_kind_t
Kinds of streams.
Definition: mkldnn_types.h:1282
Definition: mkldnn.hpp:271
Definition: mkldnn.hpp:689
Definition: mkldnn.hpp:613
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode)
Sets output rounding mode round_mode for integer operations for a given attr.
4D weights tensor with physical layout hwio, used in TensorFlow.
Definition: mkldnn_types.h:202
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn_types.h:1193
Winograd convolution.
Definition: mkldnn_types.h:528
Definition: mkldnn.hpp:641
Definition: mkldnn.hpp:246
Definition: mkldnn.hpp:345
Eltwise: linear.
Definition: mkldnn_types.h:548
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1841
bfloat 16-bit.
Definition: mkldnn_types.h:86
mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(mkldnn_softmax_desc_t *softmax_desc, const mkldnn_memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc...
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1913
reorder(const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:1021
Eltwise: logistic.
Definition: mkldnn_types.h:554
Definition: mkldnn.hpp:2688
Direct convolution.
Definition: mkldnn_types.h:526
Primitive iterator passed over last primitive descriptor.
Definition: mkldnn_types.h:64
Definition: mkldnn.hpp:340
Definition: mkldnn.hpp:270
Definition: mkldnn.hpp:734
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst)
Definition: mkldnn.hpp:2083
source gradient memory primitive desc
Definition: mkldnn_types.h:1268
mkldnn_alg_kind_t cell_kind
RNN cell kind.
Definition: mkldnn_types.h:1004
Definition: mkldnn.hpp:1502
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2690
Definition: mkldnn_types.h:1021
An opaque structure for primitive descriptor attributes.
Definition: mkldnn.hpp:314
blocked data format
Definition: mkldnn_types.h:279
mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
blocked weights format
Definition: mkldnn_types.h:345
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2051
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2660
Definition: mkldnn.hpp:3312
Definition: mkldnn.hpp:761
mkldnn_rnn_cell_desc_t c_rnn_cell_
Definition: mkldnn.hpp:2998
bool operator!=(const primitive_desc &other) const
Definition: mkldnn.hpp:827
runtime estimation (seconds)
Definition: mkldnn_types.h:1241
blocked weights format
Definition: mkldnn_types.h:418
bool operator==(const T other) const
Definition: mkldnn.hpp:61
A (in-place) concat primitive.
Definition: mkldnn_types.h:499
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, mkldnn_stream_kind_t stream_kind)
Creates an execution stream of stream_kind.
primitive_desc get_primitive_desc() const
Returns the descriptor of the memory primitive.
Definition: mkldnn.hpp:878
Definition: mkldnn.hpp:664
blocked weights format
Definition: mkldnn_types.h:312
Definition: mkldnn.hpp:677
LSTM cell.
Definition: mkldnn_types.h:573
blocked weights format
Definition: mkldnn_types.h:293
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
Definition: mkldnn.hpp:753
Definition: mkldnn_types.h:1030
Definition: mkldnn.hpp:727
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2508
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2834
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2837
Undefined data type, used for empty memory descriptors.
Definition: mkldnn_types.h:74
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:372
Definition: mkldnn.hpp:1838
16-bit signed integer.
Definition: mkldnn_types.h:80
Definition: mkldnn.hpp:2322
A shuffle primitive.
Definition: mkldnn_types.h:495
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:319
Definition: mkldnn.hpp:627
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3262
primitive_desc()
Definition: mkldnn.hpp:799
int len() const
Definition: mkldnn.hpp:377
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(const_mkldnn_primitive_t primitive, const_mkldnn_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of given primitive.
blocked weights format
Definition: mkldnn_types.h:328
primitive_desc(const memory::desc &output, const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1191
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2822
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(const_mkldnn_post_ops_t post_ops, int index, float *scale, mkldnn_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
blocked data format
Definition: mkldnn_types.h:271
Definition: mkldnn.hpp:242
blocked weights format
Definition: mkldnn_types.h:347
Definition: mkldnn.hpp:684
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(const_mkldnn_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1503
blocked weights format
Definition: mkldnn_types.h:337
A (out-of-place) concat primitive.
Definition: mkldnn_types.h:497
blocked weights format
Definition: mkldnn_types.h:358
Definition: mkldnn.hpp:632
Fuse with ReLU.
Definition: mkldnn_types.h:623
Definition: mkldnn.hpp:759
Definition: mkldnn.hpp:686
Definition: mkldnn.hpp:260
Definition: mkldnn.hpp:280
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: mkldnn.hpp:521
mkldnn_query_t
Primitive descriptor query specification.
Definition: mkldnn_types.h:1232
A descriptor of a Batch Normalization operation.
Definition: mkldnn_types.h:935
Definition: mkldnn.hpp:699
static engine query(const primitive_desc &pd)
Definition: mkldnn.hpp:553
Definition: mkldnn.hpp:3036
blocked weights format
Definition: mkldnn_types.h:373
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2024
blocked data format
Definition: mkldnn_types.h:278
blocked weights format
Definition: mkldnn_types.h:289
A sum primitive.
Definition: mkldnn_types.h:501
blocked weights format
Definition: mkldnn_types.h:360
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2783
Definition: mkldnn.hpp:304
Definition: mkldnn.hpp:630
blocked weights format
Definition: mkldnn_types.h:413
eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2346
blocked weights format
Definition: mkldnn_types.h:296
Definition: mkldnn.hpp:740
unsigned flags
Definition: mkldnn_types.h:962
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output, const_mkldnn_primitive_attr_t attr)
Initializes a reorder_primitive_desc using an attr attribute and descriptors of input and output memo...
blocked weights format
Definition: mkldnn_types.h:295
blocked weights format
Definition: mkldnn_types.h:363
Definition: mkldnn.hpp:2996
Definition: mkldnn.hpp:597
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition: mkldnn_types.h:530
softmax_backward(const primitive_desc &aprimitive_desc, const primitive::at &dst, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2472
blocked weights format
Definition: mkldnn_types.h:284
Definition: mkldnn.hpp:3037
Definition: mkldnn.hpp:258
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2336
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
blocked weights format
Definition: mkldnn_types.h:420
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream, mkldnn_primitive_t *error_primitive)
Reruns all the primitives within the stream.
2D weights tensor with physical layout io.
Definition: mkldnn_types.h:187
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes) ...
Definition: mkldnn_types.h:1242
blocked weights format
Definition: mkldnn_types.h:353
A batch normalization primitive.
Definition: mkldnn_types.h:515
A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base class for primitive (mkldnn_p...
Definition: mkldnn.hpp:55
Definition: mkldnn_types.h:524
engine(kind akind, size_t index)
Constructs an engine.
Definition: mkldnn.hpp:531
Definition: mkldnn.hpp:2321
A descriptor of a pooling operation.
Definition: mkldnn_types.h:874
Definition: mkldnn.hpp:642
Definition: mkldnn.hpp:3308
Definition: mkldnn.hpp:274
Definition: mkldnn.hpp:275
engine get_engine()
Definition: mkldnn.hpp:831
error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive=0)
Constructs an error instance.
Definition: mkldnn.hpp:173
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2000
const char * impl_info_str() const
Returns implementation name.
Definition: mkldnn.hpp:1300
deconvolution descriptor
Definition: mkldnn_types.h:1253
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1181
blocked weights format
Definition: mkldnn_types.h:366
shuffle_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:3279
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output)
Definition: mkldnn.hpp:988
primitive_desc(const desc &desc, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2274
mkldnn_memory_desc_t data
The underlying C API data structure.
Definition: mkldnn.hpp:771
mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(const_mkldnn_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1464
engine get_engine()
Definition: mkldnn.hpp:1007
Definition: mkldnn.hpp:600
int MKLDNN_API mkldnn_primitive_desc_query_s32(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for signed 32bit int.
8-bit signed integer.
Definition: mkldnn_types.h:82
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output)
Initializes a reorder_primitive_desc using descriptors of input and output memory primitives...
The data in padding regions is zero.
Definition: mkldnn_types.h:454
int MKLDNN_API mkldnn_rnn_cell_get_states_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of states of a particular rnn_cell_desc.
Definition: mkldnn.hpp:2335
friend struct error
Definition: mkldnn.hpp:107
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2932
Definition: mkldnn.hpp:742
source memory primitive desc
Definition: mkldnn_types.h:1267
mkldnn_primitive_kind_t
Kinds of primitives.
Definition: mkldnn_types.h:485
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1886
Definition: mkldnn.hpp:720
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1973
Definition: mkldnn.hpp:3239
Winograd deconvolution.
Definition: mkldnn_types.h:534
Definition: mkldnn.hpp:3313
Definition: mkldnn.hpp:248
number of inputs expected
Definition: mkldnn_types.h:1238
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2412
Definition: mkldnn.hpp:347
Definition: mkldnn.hpp:661
Definition: mkldnn.hpp:3061
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2511
desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2325
An unspecified stream kind.
Definition: mkldnn_types.h:1284
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1796
void * get_data_handle() const
Returns a handle of the data contained in the memory primitive. On the CPU engine, this is a pointer to the allocated memory.
Definition: mkldnn.hpp:891
A view primitive.
Definition: mkldnn_types.h:491
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(const_mkldnn_primitive_desc_t memory_primitive_desc)
Returns the size (in bytes) that is required for given memory_primitive_desc.
Definition: mkldnn.hpp:3108
Definition: mkldnn.hpp:262
Definition: mkldnn.hpp:670
Definition: mkldnn.hpp:330
Definition: mkldnn.hpp:763
Definition: mkldnn.hpp:625
Definition: mkldnn.hpp:3142
Definition: mkldnn.hpp:755
blocked weights format
Definition: mkldnn_types.h:327
mkldnn_primitive_kind_t convert_to_c(primitive::kind akind)
Definition: mkldnn.hpp:154
Definition: mkldnn.hpp:731
Definition: mkldnn.hpp:747
Definition: mkldnn.hpp:713
blocked data format
Definition: mkldnn_types.h:274
Definition: mkldnn.hpp:342
Definition: mkldnn.hpp:730
Definition: mkldnn.hpp:333
Definition: mkldnn.hpp:682
Definition: mkldnn.hpp:325
Definition: mkldnn.hpp:335
Average pooling, excluding padding.
Definition: mkldnn_types.h:564
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops)
Returns post_ops for given attr.
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(mkldnn_primitive_t *primitive, const_mkldnn_primitive_desc_t primitive_desc, const mkldnn_primitive_at_t *inputs, const_mkldnn_primitive_t *outputs)
Creates a primitive using a primitive_desc descriptor and arrays of inputs and outputs.
primitive::kind kind(int index) const
Definition: mkldnn.hpp:379
Definition: mkldnn_types.h:1000
Forward data propagation (inference mode).
Definition: mkldnn_types.h:468
primitive_attr get_primitive_attr() const
Definition: mkldnn.hpp:1286
6D grouped weights tensor with the physical layout goidhw, used in Caffe.
Definition: mkldnn_types.h:238
Definition: mkldnn.hpp:695
5D weights tensor with physical layout iodhw, used in Caffe.
Definition: mkldnn_types.h:214
A class that provides the destructor for an Intel(R) MKL-DNN C handle.
Definition: mkldnn.hpp:40
data_type
Data type specification. See mkldnn_data_type_t for a detailed description.
Definition: mkldnn.hpp:596
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst)
Definition: mkldnn.hpp:2551
Direct deconvolution.
Definition: mkldnn_types.h:532
Eltwise: abs.
Definition: mkldnn_types.h:544
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2573
blocked weights format
Definition: mkldnn_types.h:388
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2299
blocked weights format
Definition: mkldnn_types.h:313
A memory descriptor.
Definition: mkldnn.hpp:768
deconvolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1895
5D grouped weights tensor with the physical layout hwigo, used in TensorFlow.
Definition: mkldnn_types.h:231
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2339
blocked weights format
Definition: mkldnn_types.h:410
bool operator!=(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:955
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:482
handle(T t=0, bool weak=false)
Constructs a C handle wrapper.
Definition: mkldnn.hpp:67
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: mkldnn_types.h:538
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2921
mkldnn_status_t status
Definition: mkldnn.hpp:162
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1823
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:409
Definition: mkldnn.hpp:650
mkldnn_primitive_t get() const
Returns the value of the underlying C handle.
Definition: mkldnn.hpp:85
blocked weights format
Definition: mkldnn_types.h:401
mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine)
Destroys an engine.
Definition: mkldnn.hpp:690
view(const primitive_desc &view_pd, primitive::at input)
Definition: mkldnn.hpp:1073
blocked weights format
Definition: mkldnn_types.h:367
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1933
Definition: mkldnn.hpp:667
blocked weights format
Definition: mkldnn_types.h:364
2D data tensor.
Definition: mkldnn_types.h:160
primitive_desc(const desc &desc, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2704
Definition: mkldnn.hpp:628
blocked weights format
Definition: mkldnn_types.h:321
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2811
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
bool wait(bool block=true)
Waits for all computations submitted to the stream to complete.
Definition: mkldnn.hpp:3358
mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc and dif...
Primitive or engine failed on execution.
Definition: mkldnn_types.h:66
memory descriptor for memory and view
Definition: mkldnn_types.h:1251
Definition: mkldnn.hpp:719
view(memory input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1082
Definition: mkldnn.hpp:266
Definition: mkldnn.hpp:681
An LRN primitive.
Definition: mkldnn_types.h:513
Definition: mkldnn_types.h:1026
mkldnn_padding_kind_t
Kinds of padding.
Definition: mkldnn_types.h:452
rnn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const primitive::at &dst_layer, const primitive::at &dst_iter, const memory &diff_src_layer, const memory &diff_src_iter, const memory &diff_weights_layer, const memory &diff_weights_iter, const memory &diff_bias, const primitive::at &diff_dst_layer, const primitive::at &diff_dst_iter, const primitive::at &workspace)
Definition: mkldnn.hpp:3170
Lazy stream.
Definition: mkldnn_types.h:1288
Definition: mkldnn.hpp:334
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2449
blocked weights format
Definition: mkldnn_types.h:415
Definition: mkldnn.hpp:306
void get_output_scales(int &mask, std::vector< float > &scales) const
Definition: mkldnn.hpp:441
Definition: mkldnn.hpp:739
blocked weights format
Definition: mkldnn_types.h:286
desc(algorithm kind)
Definition: mkldnn.hpp:3006
Definition: mkldnn.hpp:697
primitive_desc(const desc &desc, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3143
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels, input_channels).
Definition: mkldnn_types.h:259
blocked weights format
Definition: mkldnn_types.h:356
const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for primitive descriptor.
Definition: mkldnn.hpp:615
Definition: mkldnn.hpp:643
Definition: mkldnn.hpp:2919
Definition: mkldnn.hpp:717
shuffle descriptor
Definition: mkldnn_types.h:1254
Forward data propagation (training mode).
Definition: mkldnn_types.h:464
Definition: mkldnn.hpp:746
Definition: mkldnn.hpp:669
Definition: mkldnn.hpp:346
primitive_desc(const desc &desc, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2136
Definition: mkldnn.hpp:629
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2958
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1576
memory(const primitive &aprimitive)
Constructs a memory primitive from a generic primitive.
Definition: mkldnn.hpp:837
3D data tensor with the physical layout nwc.
Definition: mkldnn_types.h:166
engine get_engine()
Definition: mkldnn.hpp:1152
Definition: mkldnn.hpp:616
post_ops()
Definition: mkldnn.hpp:370
An opaque structure to describe a primitive.
Definition: mkldnn.hpp:724
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2744
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: mkldnn_types.h:156
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1376
mkldnn_data_type_t
Data type specification.
Definition: mkldnn_types.h:72
Definition: mkldnn.hpp:1501
Definition: mkldnn.hpp:603
Definition: mkldnn.hpp:666
Definition: mkldnn.hpp:647
Definition: mkldnn.hpp:614
Definition: mkldnn.hpp:327
Definition: mkldnn.hpp:320
convolution descriptor
Definition: mkldnn_types.h:1252
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1550
Definition: mkldnn.hpp:716
Definition: mkldnn.hpp:609
A memory primitive descriptor.
Definition: mkldnn.hpp:795
Definition: mkldnn.hpp:316
Definition: mkldnn.hpp:2457
mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are mkldnn_forward_t...
blocked weights format
Definition: mkldnn_types.h:342
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1546
blocked weights format
Definition: mkldnn_types.h:333
handle & operator=(const handle &other)
Definition: mkldnn.hpp:72
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2674
Eltwise: bounded_relu.
Definition: mkldnn_types.h:550
Definition: mkldnn.hpp:2411
Definition: mkldnn_types.h:1023
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1486
Definition: mkldnn.hpp:748
Definition: mkldnn.hpp:702
Definition: mkldnn.hpp:620
mkldnn_engine_kind_t
Kinds of engines.
Definition: mkldnn_types.h:1081
Definition: mkldnn_types.h:996
int MKLDNN_API mkldnn_rnn_cell_get_gates_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of gates of a particular rnn_cell_desc.
Queried element is not required for given primitive.
Definition: mkldnn_types.h:68
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3062
blocked weights format
Definition: mkldnn_types.h:437
bool operator!=(const T other) const
Definition: mkldnn.hpp:62
blocked weights format
Definition: mkldnn_types.h:385
Memory primitive that describes the data.
Definition: mkldnn.hpp:581
Weights format used in 8bit Winograd convolution.
Definition: mkldnn_types.h:442
Definition: mkldnn.hpp:329
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2072
Definition: mkldnn.hpp:2110
Definition: mkldnn.hpp:303
Round nearest.
Definition: mkldnn_types.h:92
blocked weights format
Definition: mkldnn_types.h:436
Definition: mkldnn.hpp:243
Definition: mkldnn.hpp:3311
Definition: mkldnn.hpp:707
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2725
Definition: mkldnn.hpp:1712
const void * const_mkldnn_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: mkldnn_types.h:713
static mkldnn_stream_kind_t convert_to_c(kind akind)
Definition: mkldnn.hpp:3315
blocked weights format
Definition: mkldnn_types.h:285
blocked weights format
Definition: mkldnn_types.h:431
Definition: mkldnn.hpp:1910
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1140
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create_v2(mkldnn_primitive_desc_iterator_t *iterator, const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator for given op_desc, attr, engine, and optionally a hint primit...
Definition: mkldnn.hpp:2493
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace)
Definition: mkldnn.hpp:2235
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1473
4D weights tensor with physical layout iohw.
Definition: mkldnn_types.h:211
A reorder primitive.
Definition: mkldnn_types.h:493
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1799
rnn_direction
Definition: mkldnn.hpp:301
Definition: mkldnn.hpp:676
primitive_desc(const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1210
blocked weights format
Definition: mkldnn_types.h:411
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:399
blocked weights format
Definition: mkldnn_types.h:336
Definition: mkldnn.hpp:638
An unspecified engine.
Definition: mkldnn_types.h:1083
desc(const mkldnn_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: mkldnn.hpp:791
blocked weights format
Definition: mkldnn_types.h:359
Definition: mkldnn.hpp:599
Definition: mkldnn.hpp:1180
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
engine get_engine()
Definition: mkldnn.hpp:1070
Definition: mkldnn.hpp:692
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2278
friend class primitive_at
Definition: mkldnn.hpp:109
blocked weights format
Definition: mkldnn_types.h:412
Definition: mkldnn.hpp:733
blocked weights format
Definition: mkldnn_types.h:387
mkldnn_alg_kind_t
Kinds of algorithms.
Definition: mkldnn_types.h:523
Definition: mkldnn.hpp:729
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2944
Definition: mkldnn.hpp:263
inner product descriptor
Definition: mkldnn_types.h:1260
blocked weights format
Definition: mkldnn_types.h:394
A pooling primitive.
Definition: mkldnn_types.h:511
Definition: mkldnn.hpp:736
weights memory primitive descriptor desc
Definition: mkldnn_types.h:1269
output memory primitive desc
Definition: mkldnn_types.h:1266
Definition: mkldnn.hpp:2273
blocked weights format
Definition: mkldnn_types.h:417
blocked weights format
Definition: mkldnn_types.h:349
5D weights tensor with physical layout dhwio, used in TensorFlow.
Definition: mkldnn_types.h:217
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2075
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2495
Definition: mkldnn.hpp:987
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(mkldnn_primitive_t primitive)
Deletes a primitive.
Definition: mkldnn.hpp:336
Definition: mkldnn.hpp:637
std::string message
Definition: mkldnn.hpp:163
Definition: mkldnn.hpp:3227
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc, const mkldnn_memory_desc_t *diff_src_layer_desc, const mkldnn_memory_desc_t *diff_src_iter_desc, const mkldnn_memory_desc_t *diff_weights_layer_desc, const mkldnn_memory_desc_t *diff_weights_iter_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_layer, const mkldnn_memory_desc_t *diff_dst_iter_desc)
Initializes a rnn descriptor rnn_desc for backward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
Definition: mkldnn.hpp:655
primitive_desc(const desc &desc, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2375
Definition: mkldnn.hpp:317
blocked weights format
Definition: mkldnn_types.h:324
handle(const handle &other)
Definition: mkldnn.hpp:71
Forward data propagation (alias for mkldnn_forward_training)
Definition: mkldnn_types.h:472
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition: mkldnn_types.h:240
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(mkldnn_primitive_attr_t attr, int count, int mask, const float *scales)
Sets output scales for primitive operations.
Definition: mkldnn.hpp:241
lrn descriptor
Definition: mkldnn_types.h:1258
Definition: mkldnn.hpp:674
workspace memory primitive desc
Definition: mkldnn_types.h:1273
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2163
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1637
bool next_impl()
Advances the next implementation for the given op descriptor.
Definition: mkldnn.hpp:1314
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
Definition: mkldnn.hpp:622
blocked weights format
Definition: mkldnn_types.h:282
blocked weights format
Definition: mkldnn_types.h:291
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1714
Definition: mkldnn.hpp:645
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2692
blocked weights format
Definition: mkldnn_types.h:343
Definition: mkldnn.hpp:224
weights format with additional buffer size equal to the number of output channels and containing the ...
Definition: mkldnn_types.h:308
Definition: mkldnn.hpp:749
Definition: mkldnn.hpp:617
Definition: mkldnn.hpp:694
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2140
float get_clipping() const
Definition: mkldnn.hpp:3021
weights grad.
Definition: mkldnn_types.h:1270
4D data tensor with the physical layout nchw, used in Caffe.
Definition: mkldnn_types.h:169
Definition: mkldnn.hpp:323
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc)
Initializes a rnn descriptor rnn_desc for forward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
Definition: mkldnn.hpp:621
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Definition: mkldnn.hpp:397
primitive kind
Definition: mkldnn_types.h:1236
blocked data format
Definition: mkldnn_types.h:272
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1859
int get_state_count() const
Definition: mkldnn.hpp:3030
blocked weights format
Definition: mkldnn_types.h:320
Definition: mkldnn.hpp:319
Definition: mkldnn.hpp:656
An opaque structure to describe a primitive descriptor iterator.
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2251
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2588
kind
Definition: mkldnn.hpp:3311
Definition: mkldnn.hpp:745
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1504
Definition: mkldnn.hpp:341
Definition: mkldnn.hpp:725
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc)
Definition: mkldnn.hpp:3039
mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...