oneAPI Deep Neural Network Library (oneDNN)
Performance library for Deep Learning
2.2.4
dnnl.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
19 
20 #ifndef ONEAPI_DNNL_DNNL_HPP
21 #define ONEAPI_DNNL_DNNL_HPP
22 
23 #include "oneapi/dnnl/dnnl_config.h"
24 
26 #include <algorithm>
27 #include <cstdlib>
28 #include <iterator>
29 #include <memory>
30 #include <string>
31 #include <vector>
32 #include <unordered_map>
33 
34 #include "oneapi/dnnl/dnnl.h"
35 
37 
// Decide whether C++ exceptions are available, unless the user has already
// made that choice by defining DNNL_ENABLE_EXCEPTIONS.
// The __cpp_exceptions feature-test macro is described at
// https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html
// GCC < 5 does not define __cpp_exceptions but defines __EXCEPTIONS instead.
// MSVC (but not clang-cl, which also defines _MSC_VER) is treated as always
// having exceptions enabled.
#ifndef DNNL_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS \
        || (defined(_MSC_VER) && !defined(__clang__))
#define DNNL_ENABLE_EXCEPTIONS 1
#else
#define DNNL_ENABLE_EXCEPTIONS 0
#endif
#endif

// DNNL_TRAP() terminates the program abnormally. It is used to report fatal
// errors when exceptions are disabled. Compilers without a known trap
// intrinsic fall back to std::abort() (from <cstdlib>) instead of failing to
// compile.
#if defined(__GNUC__) || defined(__clang__)
#define DNNL_TRAP() __builtin_trap()
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#define DNNL_TRAP() __debugbreak()
#else
#define DNNL_TRAP() std::abort()
#endif

// DNNL_THROW_ERROR(status, msg) reports an error.
// With exceptions enabled, it throws dnnl::error(status, msg).
// Otherwise, it writes `msg` plus a newline to stderr and traps. In that
// mode `status` is not evaluated.
#if DNNL_ENABLE_EXCEPTIONS
#define DNNL_THROW_ERROR(status, msg) throw error(status, msg)
#else
#include <cstdio>
#define DNNL_THROW_ERROR(status, msg) \
    do { \
        std::fputs(msg, stderr); \
        std::fputc('\n', stderr); \
        DNNL_TRAP(); \
    } while (0)
#endif
69 
72 
74 namespace dnnl {
75 
79 
84 struct error : public std::exception {
86  const char *message;
87 
92  error(dnnl_status_t status, const char *message)
93  : status(status), message(message) {}
94 
96  const char *what() const noexcept override { return message; }
97 
103  static void wrap_c_api(dnnl_status_t status, const char *message) {
104  if (status != dnnl_success) DNNL_THROW_ERROR(status, message);
105  }
106 };
107 
109 template <typename T>
110 void validate_container_size(const T &v, const char *error_message,
111  int min_size = 1, int max_size = -1) {
112  const int size = (int)v.size();
113  if (size < min_size || (max_size >= 0 && size > max_size))
114  DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
115 }
117 
// Primary template for per-type handle traits. It is intentionally empty;
// each C handle type provides a specialization with a static `destructor`
// that handle<T> uses to release the underlying object.
template <class T>
struct handle_traits {};
121 
135 template <typename T, typename traits = handle_traits<T>>
136 struct handle {
137 private:
138  static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
139  std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0};
140 
141 protected:
142  bool operator==(const T other) const { return other == data_.get(); }
143  bool operator!=(const T other) const { return !(*this == other); }
144 
145 public:
153  handle() = default;
154 
156  handle(const handle<T, traits> &) = default;
160  handle(handle<T, traits> &&) = default;
163 
169  explicit handle(T t, bool weak = false) { reset(t, weak); }
170 
176  void reset(T t, bool weak = false) {
177  data_.reset(t, weak ? &dummy_destructor : traits::destructor);
178  }
179 
185  T get(bool allow_empty = false) const {
186  T result = data_.get();
187  if (allow_empty == false && result == nullptr)
188  DNNL_THROW_ERROR(
189  dnnl_invalid_arguments, "object is not initialized");
190  return result;
191  }
192 
197  explicit operator T() const { return get(true); }
198 
202  explicit operator bool() const { return get(true) != nullptr; }
203 
210  bool operator==(const handle<T, traits> &other) const {
211  return other.data_.get() == data_.get();
212  }
213 
220  bool operator!=(const handle &other) const { return !(*this == other); }
221 };
222 
224 template <>
225 struct handle_traits<dnnl_memory_t> {
226  static dnnl_status_t destructor(dnnl_memory_t p) {
227  return dnnl_memory_destroy(p);
228  }
229 };
230 
231 template <>
232 struct handle_traits<dnnl_primitive_desc_t> {
233  static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
234  return dnnl_primitive_desc_destroy(p);
235  }
236 };
237 
238 template <>
239 struct handle_traits<dnnl_primitive_t> {
240  static dnnl_status_t destructor(dnnl_primitive_t p) {
241  return dnnl_primitive_destroy(p);
242  }
243 };
244 
245 template <>
246 struct handle_traits<dnnl_primitive_desc_iterator_t> {
247  static dnnl_status_t destructor(dnnl_primitive_desc_iterator_t p) {
249  }
250 };
252 
254 
255 struct stream;
256 struct memory;
257 struct primitive_desc;
258 
263 
267 
// Base class for computational primitives: a dnnl::handle wrapping a C
// dnnl_primitive_t.
// NOTE(review): this listing lost lines during extraction. Most enumerators
// of `kind` are missing (gaps in the embedded numbering between 271 and 316),
// as are the declarations at listing lines 322-338. Confirm against the
// upstream header before relying on this text.
269 struct primitive : public handle<dnnl_primitive_t> {
271  enum class kind {
281  sum = dnnl_sum,
293  lrn = dnnl_lrn,
301  rnn = dnnl_rnn,
315  prelu = dnnl_prelu,
316  };
317 
318  using handle::handle;
319 
321  primitive() = default;
322 
327 
332 
338 
// Returns this primitive's kind; declared here and defined out of line.
342  inline kind get_kind() const;
343 
// Executes the primitive on `astream`. `args` maps int keys (presumably
// DNNL_ARG_* identifiers — confirm) to memory objects.
356  void execute(const stream &astream,
357  const std::unordered_map<int, memory> &args) const;
358 };
359 
365  return static_cast<dnnl_primitive_kind_t>(akind);
366 }
367 
371  "could not get a primitive descriptor from a primitive");
372  return pd;
373 }
374 
377  // TODO (Roma): the code below is only needed because get_primitive_desc
378  // returns a C type.
381  pd, dnnl_query_primitive_kind, 0, (void *)&kind),
382  "could not get a primitive kind from a primitive descriptor");
383  return static_cast<dnnl::primitive::kind>(kind);
384 }
385 
387 
399 
// Scratchpad modes. Values are cast directly to dnnl_scratchpad_mode_t by the
// following convert_to_c overload.
// NOTE(review): no enumerators are visible; listing lines 402-423 appear to
// have been lost in extraction. Confirm against the upstream header.
401 enum class scratchpad_mode {
424 };
425 
431  return static_cast<dnnl_scratchpad_mode_t>(mode);
432 }
433 
// Propagation kinds. Values are cast directly to dnnl_prop_kind_t by the
// following convert_to_c overload.
// NOTE(review): no enumerators are visible; listing lines 436-458 appear to
// have been lost in extraction. Confirm against the upstream header.
435 enum class prop_kind {
459 };
460 
466  return static_cast<dnnl_prop_kind_t>(akind);
467 }
468 
// Algorithm kinds. Values are cast directly to dnnl_alg_kind_t by the
// following convert_to_c overload.
// NOTE(review): only `undef` is visible; listing lines 473-615 appear to
// have been lost in extraction. Confirm against the upstream header.
470 enum class algorithm {
472  undef = dnnl_alg_kind_undef,
616 };
617 
622  return static_cast<dnnl_alg_kind_t>(aalgorithm);
623 }
624 
626 
629 
// Normalization flags, stored as `unsigned` bits. Values are cast directly to
// dnnl_normalization_flags_t by the following convert_to_c overload, and
// bitwise operators come from DNNL_DEFINE_BITMASK_OPS.
// NOTE(review): no enumerators are visible; listing lines 632-658 appear to
// have been lost in extraction. Confirm against the upstream header.
631 enum class normalization_flags : unsigned {
637 
646 
653 
659 };
660 
665  return static_cast<dnnl_normalization_flags_t>(flags);
666 }
667 
669 
672 
// RNN flags, stored as `unsigned` bits. Values are cast directly to
// dnnl_rnn_flags_t by the following convert_to_c overload, and bitwise
// operators come from DNNL_DEFINE_BITMASK_OPS.
// NOTE(review): no enumerators are visible; listing lines 675-676 appear to
// have been lost in extraction. Confirm against the upstream header.
674 enum class rnn_flags : unsigned {
677 };
678 
683  return static_cast<dnnl_rnn_flags_t>(flags);
684 }
685 
686 #define DNNL_DEFINE_BITMASK_OPS(enum_name) \
687  inline enum_name operator|(enum_name lhs, enum_name rhs) { \
688  return static_cast<enum_name>( \
689  static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
690  } \
691 \
692  inline enum_name operator&(enum_name lhs, enum_name rhs) { \
693  return static_cast<enum_name>( \
694  static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
695  } \
696 \
697  inline enum_name operator^(enum_name lhs, enum_name rhs) { \
698  return static_cast<enum_name>( \
699  static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
700  } \
701 \
702  inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
703  lhs = static_cast<enum_name>( \
704  static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
705  return lhs; \
706  } \
707 \
708  inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
709  lhs = static_cast<enum_name>( \
710  static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
711  return lhs; \
712  } \
713 \
714  inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
715  lhs = static_cast<enum_name>( \
716  static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
717  return lhs; \
718  } \
719 \
720  inline enum_name operator~(enum_name rhs) { \
721  return static_cast<enum_name>(~static_cast<unsigned>(rhs)); \
722  }
723 
724 DNNL_DEFINE_BITMASK_OPS(normalization_flags)
725 DNNL_DEFINE_BITMASK_OPS(rnn_flags)
726 
// RNN directions. Values are cast directly to dnnl_rnn_direction_t by the
// following convert_to_c overload.
// NOTE(review): no enumerators are visible; listing lines 728-740 appear to
// have been lost in extraction. Confirm against the upstream header.
727 enum class rnn_direction {
741 };
742 
747  return static_cast<dnnl_rnn_direction_t>(dir);
748 }
749 
751 
754 
// Query kinds used with primitive descriptors. Values are cast directly to
// dnnl_query_t by the following convert_to_c overload.
// NOTE(review): no enumerators are visible, yet engine's code below uses
// dnnl::query::engine. The listing lines holding the enumerators were
// evidently lost in extraction. Confirm against the upstream header.
761 enum class query {
764 
769 
774 
781 
786 
791 
794 
797 
832 
851 };
852 
857  return static_cast<dnnl_query_t>(aquery);
858 }
859 
861 
863 
874 
876 template <>
877 struct handle_traits<dnnl_engine_t> {
878  static dnnl_status_t destructor(dnnl_engine_t p) {
879  return dnnl_engine_destroy(p);
880  }
881 };
883 
// Execution engine: a dnnl::handle wrapping a C dnnl_engine_t.
// NOTE(review): several lines of this listing were lost in extraction. Gaps
// in the embedded numbering include 891-893, 895, 919-920, 926-930, 932-933,
// 939-940, 942-943 and 966. Confirm the missing code against the upstream
// header.
885 struct engine : public handle<dnnl_engine_t> {
886  friend struct primitive;
887  friend struct reorder;
888 
// Engine kinds; cast directly to dnnl_engine_kind_t by convert_to_c below.
890  enum class kind {
894  cpu = dnnl_cpu,
896  gpu = dnnl_gpu,
897  };
898 
899  using handle::handle;
900 
// Constructs an empty engine handle.
903  engine() = default;
904 
// Returns the number of engines of kind `akind`, as reported by
// dnnl_engine_get_count.
909  static size_t get_count(kind akind) {
910  return dnnl_engine_get_count(convert_to_c(akind));
911  }
912 
// Creates an engine of kind `akind` with index `index` via dnnl_engine_create
// and takes ownership of it (reset without the weak flag).
// Listing lines 919-920 are missing, so the local `engine` declaration is not
// visible.
918  engine(kind akind, size_t index) {
921  dnnl_engine_create(&engine, convert_to_c(akind), index),
922  "could not create an engine");
923  reset(engine);
924  }
925 
// Body of a constructor whose signature (listing lines 926-930) was lost.
// It queries dnnl::query::engine from a primitive descriptor and wraps the
// result as a weak, non-owning handle.
931  dnnl_engine_t c_engine;
934  dnnl::convert_to_c(dnnl::query::engine), 0, &c_engine),
935  "could not get an engine from a primitive_desc");
936  reset(c_engine, true);
937  }
938 
// Returns the engine kind. Listing lines 942-943 are missing, so the
// declaration of the local `kind` is not visible.
941  kind get_kind() const {
944  "could not get kind of an engine");
945  return static_cast<engine::kind>(kind);
946  }
947 
// Returns a non-owning engine for primitive descriptor `pd`, via the private
// overload below with dnnl::query::engine.
953  template <typename primitive_desc>
954  static engine query(const primitive_desc &pd) {
955  return query(pd, dnnl::query::engine);
956  }
957 
958 private:
959  static dnnl_engine_kind_t convert_to_c(kind akind) {
960  return static_cast<dnnl_engine_kind_t>(akind);
961  }
962 
// Queries `what` from `pd` and returns the result as a weak engine handle
// (engine(c_engine, true)), which never destroys the C engine.
963  template <typename primitive_desc>
964  static engine query(const primitive_desc &pd, dnnl::query what) {
965  dnnl_engine_t c_engine;
967  dnnl::convert_to_c(what), 0, &c_engine),
968  "could not get an engine from a primitive_desc");
969  return engine(c_engine, true);
970  }
971 };
972 
978  return static_cast<dnnl_engine_kind_t>(akind);
979 }
980 
982 
990 
992 template <>
993 struct handle_traits<dnnl_stream_t> {
994  static dnnl_status_t destructor(dnnl_stream_t p) {
995  return dnnl_stream_destroy(p);
996  }
997 };
999 
// Execution stream: a dnnl::handle wrapping a C dnnl_stream_t.
// NOTE(review): several lines of this listing were lost in extraction. Gaps
// in the embedded numbering include 1004, 1006, 1008-1011, 1024-1025, 1034
// and 1039-1042. Confirm the missing code against the upstream header.
1001 struct stream : public handle<dnnl_stream_t> {
1002  using handle::handle;
1003 
// Stream flags (unsigned bits); cast to dnnl_stream_flags_t when creating a
// stream. The constructor's default, flags::default_flags, is referenced
// below but its enumerator is not visible in this listing.
1005  enum class flags : unsigned {
1007  in_order = dnnl_stream_in_order,
1012  };
1013 
// Constructs an empty stream handle.
1016  stream() = default;
1017 
// Creates an owning stream for `aengine` with `aflags`. The opening of the
// creation call (listing lines 1024-1025) is missing.
1023  stream(const engine &aengine, flags aflags = flags::default_flags) {
1026  static_cast<dnnl_stream_flags_t>(aflags)),
1027  "could not create a stream");
1028  reset(stream);
1029  }
1030 
// Returns a weak (non-owning) engine handle for the engine this stream
// belongs to.
1032  engine get_engine() const {
1033  dnnl_engine_t c_engine;
1035  "could not get an engine from a stream object");
1036  return engine(c_engine, true);
1037  }
1038 
// Body of a method whose signature (listing lines 1039-1042) was lost.
// It calls dnnl_stream_wait on this stream and returns *this.
1043  dnnl_stream_wait(get()), "could not wait on a stream");
1044  return *this;
1045  }
1046 };
1047 
// Bitwise operators for stream::flags.
1048 DNNL_DEFINE_BITMASK_OPS(stream::flags)
1049 
1050 
1117 
1124 struct memory : public handle<dnnl_memory_t> {
1125  using handle::handle;
1126 
1128  typedef dnnl_dim_t dim;
1131  typedef std::vector<dim> dims;
1132 
1139  template <typename T>
1140  static void validate_dims(const std::vector<T> &v, int min_size = 0) {
1141  validate_container_size(
1142  v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
1143  }
1144 
1146  enum class data_type {
1150  f16 = dnnl_f16,
1153  bf16 = dnnl_bf16,
1155  f32 = dnnl_f32,
1157  s32 = dnnl_s32,
1159  s8 = dnnl_s8,
1161  u8 = dnnl_u8,
1162  };
1163 
1166  static size_t data_type_size(data_type adata_type) {
1167  return dnnl_data_type_size(convert_to_c(adata_type));
1168  }
1169 
1171  enum class format_kind {
1176  any = dnnl_format_kind_any,
1180  blocked = dnnl_blocked,
1182  wino = dnnl_format_kind_wino,
1184  packed = dnnl_format_kind_rnn_packed,
1185  };
1186 
1227  enum class format_tag {
1232  any = dnnl_format_tag_any,
1233 
1235  a = dnnl_a,
1236 
1238  ab = dnnl_ab,
1240  ba = dnnl_ba,
1241 
1243  abc = dnnl_abc,
1245  acb = dnnl_acb,
1247  bac = dnnl_bac,
1249  bca = dnnl_bca,
1251  cba = dnnl_cba,
1252 
1254  abcd = dnnl_abcd,
1256  abdc = dnnl_abdc,
1258  acdb = dnnl_acdb,
1260  bacd = dnnl_bacd,
1262  bcda = dnnl_bcda,
1264  cdba = dnnl_cdba,
1266  dcab = dnnl_dcab,
1267 
1269  abcde = dnnl_abcde,
1271  abdec = dnnl_abdec,
1273  acbde = dnnl_acbde,
1275  acdeb = dnnl_acdeb,
1277  bacde = dnnl_bacde,
1279  bcdea = dnnl_bcdea,
1281  cdeba = dnnl_cdeba,
1283  decab = dnnl_decab,
1285  abced = dnnl_abced,
1286 
1288  abcdef = dnnl_abcdef,
1290  abdfce = dnnl_abdfce,
1292  acbdef = dnnl_acbdef,
1294  abdefc = dnnl_abdefc,
1296  defcab = dnnl_defcab,
1298  abcdfe = dnnl_abcdfe,
1299 
1301  abcdefg = dnnl_abcdefg,
1303  abcdegf = dnnl_abcdegf,
1304 
1306  abcdefgh = dnnl_abcdefgh,
1308  abcdefhg = dnnl_abcdefhg,
1309 
1311  abcdefghi = dnnl_abcdefghi,
1313  abcdefgih = dnnl_abcdefgih,
1314 
1316  abcdefghij = dnnl_abcdefghij,
1318  abcdefghji = dnnl_abcdefghji,
1319 
1321  abcdefghijk = dnnl_abcdefghijk,
1323  abcdefghikj = dnnl_abcdefghikj,
1324 
1326  abcdefghijkl = dnnl_abcdefghijkl,
1328  abcdefghijlk = dnnl_abcdefghijlk,
1329 
1331  x = a,
1333  nc = ab,
1335  cn = ba,
1337  tn = ab,
1339  nt = ba,
1341  ncw = abc,
1343  nwc = acb,
1345  nchw = abcd,
1347  nhwc = acdb,
1349  chwn = bcda,
1351  ncdhw = abcde,
1353  ndhwc = acdeb,
1354 
1356  oi = ab,
1358  io = ba,
1360  oiw = abc,
1362  owi = acb,
1364  wio = cba,
1366  iwo = bca,
1368  oihw = abcd,
1370  hwio = cdba,
1372  ohwi = acdb,
1374  ihwo = bcda,
1376  iohw = bacd,
1378  oidhw = abcde,
1380  dhwio = cdeba,
1382  odhwi = acdeb,
1384  iodhw = bacde,
1386  idhwo = bcdea,
1387 
1389  goiw = abcd,
1391  gowi = abdc,
1393  wigo = dcab,
1395  gohwi = abdec,
1397  goihw = abcde,
1399  hwigo = decab,
1401  giohw = acbde,
1403  goidhw = abcdef,
1405  giodhw = acbdef,
1407  godhwi = abdefc,
1409  dhwigo = defcab,
1410 
1413  tnc = abc,
1416  ntc = bac,
1419  ldnc = abcd,
1427  ldigo = abcde,
1435  ldgoi = abdec,
1439  ldio = abcd,
1443  ldoi = abdc,
1451  ldgo = abcd,
1452 
1453  // Opaque blocked formats
1454 
1455  AB16b16a = dnnl_AB16b16a,
1456  AB16b32a = dnnl_AB16b32a,
1457  AB16b64a = dnnl_AB16b64a,
1458  AB8b16a2b = dnnl_AB8b16a2b,
1459  AB8b32a2b = dnnl_AB8b32a2b,
1460  AB8b64a2b = dnnl_AB8b64a2b,
1461  AB4b16a4b = dnnl_AB4b16a4b,
1462  AB4b32a4b = dnnl_AB4b32a4b,
1463  AB4b64a4b = dnnl_AB4b64a4b,
1464  AB16b16a4b = dnnl_AB16b16a4b,
1465  AB16b32a4b = dnnl_AB16b32a4b,
1466  AB16b48a4b = dnnl_AB16b48a4b,
1467  AB16b64a4b = dnnl_AB16b64a4b,
1468  AB16b16a2b = dnnl_AB16b16a2b,
1469  AB16b32a2b = dnnl_AB16b32a2b,
1470  AB16b48a2b = dnnl_AB16b48a2b,
1471  AB16b64a2b = dnnl_AB16b64a2b,
1472  Abc16a = dnnl_Abc16a,
1473  ABc16a16b = dnnl_ABc16a16b,
1474  ABc4a4b = dnnl_ABc4a4b,
1475  aBc16b = dnnl_aBc16b,
1476  aBc32b = dnnl_aBc32b,
1477  ABc16b16a = dnnl_ABc16b16a,
1478  ABc16b32a = dnnl_ABc16b32a,
1479  ABc16b64a = dnnl_ABc16b64a,
1480  Abc4a = dnnl_Abc4a,
1481  aBc4b = dnnl_aBc4b,
1482  ABc4b16a4b = dnnl_ABc4b16a4b,
1483  ABc4b32a4b = dnnl_ABc4b32a4b,
1484  ABc4b64a4b = dnnl_ABc4b64a4b,
1485  ABc2b8a4b = dnnl_ABc2b8a4b,
1486  ABc16a16b2a = dnnl_ABc16a16b2a,
1487  ABc16b16a4b = dnnl_ABc16b16a4b,
1488  ABc16b32a4b = dnnl_ABc16b32a4b,
1489  ABc16b48a4b = dnnl_ABc16b48a4b,
1490  ABc16b64a4b = dnnl_ABc16b64a4b,
1491  ABc16b16a2b = dnnl_ABc16b16a2b,
1492  ABc16b32a2b = dnnl_ABc16b32a2b,
1493  ABc16b48a2b = dnnl_ABc16b48a2b,
1494  ABc16b64a2b = dnnl_ABc16b64a2b,
1495  ABc4b4a = dnnl_ABc4b4a,
1496  ABc8a16b2a = dnnl_ABc8a16b2a,
1497  ABc8a8b = dnnl_ABc8a8b,
1498  ABc8a4b = dnnl_ABc8a4b,
1499  aBc8b = dnnl_aBc8b,
1500  ABc8b16a2b = dnnl_ABc8b16a2b,
1501  ABc8b32a2b = dnnl_ABc8b32a2b,
1502  ABc8b64a2b = dnnl_ABc8b64a2b,
1503  ABc8b8a = dnnl_ABc8b8a,
1504  Abcd8a = dnnl_Abcd8a,
1505  Abcd16a = dnnl_Abcd16a,
1506  Abcd32a = dnnl_Abcd32a,
1507  ABcd16a16b = dnnl_ABcd16a16b,
1508  aBcd16b = dnnl_aBcd16b,
1509  aBcd32b = dnnl_aBcd32b,
1510  ABcd16b16a = dnnl_ABcd16b16a,
1511  ABcd16b32a = dnnl_ABcd16b32a,
1512  ABcd16b64a = dnnl_ABcd16b64a,
1513  aBCd16b16c = dnnl_aBCd16b16c,
1514  aBCd16c16b = dnnl_aBCd16c16b,
1515  Abcd4a = dnnl_Abcd4a,
1516  aBcd4b = dnnl_aBcd4b,
1517  ABcd4b16a4b = dnnl_ABcd4b16a4b,
1518  ABcd4b32a4b = dnnl_ABcd4b32a4b,
1519  ABcd4b64a4b = dnnl_ABcd4b64a4b,
1520  ABcd2b8a4b = dnnl_ABcd2b8a4b,
1521  ABcd4b4a = dnnl_ABcd4b4a,
1522  ABcd4a4b = dnnl_ABcd4a4b,
1523  aBCd4c16b4c = dnnl_aBCd4c16b4c,
1524  aBCd2c8b4c = dnnl_aBCd2c8b4c,
1525  ABcd16a16b2a = dnnl_ABcd16a16b2a,
1526  ABcd16b16a4b = dnnl_ABcd16b16a4b,
1527  ABcd16b32a4b = dnnl_ABcd16b32a4b,
1528  ABcd16b48a4b = dnnl_ABcd16b48a4b,
1529  ABcd16b64a4b = dnnl_ABcd16b64a4b,
1530  ABcd16b16a2b = dnnl_ABcd16b16a2b,
1531  ABcd16b32a2b = dnnl_ABcd16b32a2b,
1532  ABcd16b48a2b = dnnl_ABcd16b48a2b,
1533  ABcd16b64a2b = dnnl_ABcd16b64a2b,
1534  aBCd16b16c2b = dnnl_aBCd16b16c2b,
1535  aBCd16c16b4c = dnnl_aBCd16c16b4c,
1536  aBCd16c16b2c = dnnl_aBCd16c16b2c,
1537  aBCd4c4b = dnnl_aBCd4c4b,
1538  aBCd4b4c = dnnl_aBCd4b4c,
1539  ABcd8a16b2a = dnnl_ABcd8a16b2a,
1540  ABcd8a8b = dnnl_ABcd8a8b,
1541  ABcd8a4b = dnnl_ABcd8a4b,
1543  aBcd8b = dnnl_aBcd8b,
1544  ABcd8b16a2b = dnnl_ABcd8b16a2b,
1545  ABcd8b32a2b = dnnl_ABcd8b32a2b,
1546  ABcd8b64a2b = dnnl_ABcd8b64a2b,
1547  aBCd8b16c2b = dnnl_aBCd8b16c2b,
1549  ABcd8b8a = dnnl_ABcd8b8a,
1550  aBCd8b8c = dnnl_aBCd8b8c,
1551  aBCd8b4c = dnnl_aBCd8b4c,
1552  aBCd8c16b2c = dnnl_aBCd8c16b2c,
1553  aBCd8c8b = dnnl_aBCd8c8b,
1554  Abcde16a = dnnl_Abcde16a,
1555  Abcde32a = dnnl_Abcde32a,
1556  ABcde16a16b = dnnl_ABcde16a16b,
1557  aBcde16b = dnnl_aBcde16b,
1558  aBcde32b = dnnl_aBcde32b,
1559  ABcde16b16a = dnnl_ABcde16b16a,
1560  ABcde16b32a = dnnl_ABcde16b32a,
1561  ABcde16b64a = dnnl_ABcde16b64a,
1562  aBCde16b16c = dnnl_aBCde16b16c,
1563  aBCde16c16b = dnnl_aBCde16c16b,
1564  aBCde2c8b4c = dnnl_aBCde2c8b4c,
1565  Abcde4a = dnnl_Abcde4a,
1566  aBcde4b = dnnl_aBcde4b,
1567  ABcde4b4a = dnnl_ABcde4b4a,
1568  ABcde4a4b = dnnl_ABcde4a4b,
1569  aBCde4b4c = dnnl_aBCde4b4c,
1570  aBCde4c16b4c = dnnl_aBCde4c16b4c,
1571  aBCde16b16c2b = dnnl_aBCde16b16c2b,
1572  aBCde16c16b4c = dnnl_aBCde16c16b4c,
1573  aBCde16c16b2c = dnnl_aBCde16c16b2c,
1574  aBCdef16c16b2c = dnnl_aBCdef16c16b2c,
1575  aBCde4c4b = dnnl_aBCde4c4b,
1576  Abcde8a = dnnl_Abcde8a,
1577  ABcde8a8b = dnnl_ABcde8a8b,
1578  ABcde8a4b = dnnl_ABcde8a4b,
1579  aBcde8b = dnnl_aBcde8b,
1580  ABcde8b16a2b = dnnl_ABcde8b16a2b,
1581  ABcde8b32a2b = dnnl_ABcde8b32a2b,
1582  ABcde8b64a2b = dnnl_ABcde8b64a2b,
1583  ABcde4b16a4b = dnnl_ABcde4b16a4b,
1584  ABcde4b32a4b = dnnl_ABcde4b32a4b,
1585  ABcde4b64a4b = dnnl_ABcde4b64a4b,
1586  ABcde16b16a4b = dnnl_ABcde16b16a4b,
1587  ABcde16b32a4b = dnnl_ABcde16b32a4b,
1588  ABcde16b48a4b = dnnl_ABcde16b48a4b,
1589  ABcde16b64a4b = dnnl_ABcde16b64a4b,
1590  ABcde16b16a2b = dnnl_ABcde16b16a2b,
1591  ABcde16b32a2b = dnnl_ABcde16b32a2b,
1592  ABcde16b48a2b = dnnl_ABcde16b48a2b,
1593  ABcde16b64a2b = dnnl_ABcde16b64a2b,
1594  ABcde2b8a4b = dnnl_ABcde2b8a4b,
1595  aBCde8b16c2b = dnnl_aBCde8b16c2b,
1596  ABcde8b8a = dnnl_ABcde8b8a,
1597  aBCde8b8c = dnnl_aBCde8b8c,
1598  aBCde8b4c = dnnl_aBCde8b4c,
1599  ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
1600  ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
1601  aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
1602  aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
1603  aBCde8c16b2c = dnnl_aBCde8c16b2c,
1604  aBCde8c8b = dnnl_aBCde8c8b,
1605  aBcdef16b = dnnl_aBcdef16b,
1606  aBCdef16b16c = dnnl_aBCdef16b16c,
1607  aBCdef16c16b = dnnl_aBCdef16c16b,
1608  aBcdef4b = dnnl_aBcdef4b,
1609  aBCdef2c8b4c = dnnl_aBCdef2c8b4c,
1610  aBCdef4c4b = dnnl_aBCdef4c4b,
1611  aBCdef4b4c = dnnl_aBCdef4b4c,
1612  aBCdef8b8c = dnnl_aBCdef8b8c,
1613  aBCdef8b4c = dnnl_aBCdef8b4c,
1614  aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
1615  aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
1616  aBCdef8c8b = dnnl_aBCdef8c8b,
1617  aBdc16b = dnnl_aBdc16b,
1618  aBdc4b = dnnl_aBdc4b,
1619  aBdc8b = dnnl_aBdc8b,
1620  aBdec16b = dnnl_aBdec16b,
1621  aBdec4b = dnnl_aBdec4b,
1622  aBdec8b = dnnl_aBdec8b,
1623  aBdefc16b = dnnl_aBdefc16b,
1624  aCBdef16c16b = dnnl_aCBdef16c16b,
1625  aCBdef16b16c = dnnl_aCBdef16b16c,
1626  aBdefc4b = dnnl_aBdefc4b,
1627  aBdefc8b = dnnl_aBdefc8b,
1628  Acb16a = dnnl_Acb16a,
1629  Acb4a = dnnl_Acb4a,
1630  Acb8a = dnnl_Acb8a,
1631  aCBd16b16c = dnnl_aCBd16b16c,
1632  aCBd16c16b = dnnl_aCBd16c16b,
1633  aCBde16b16c = dnnl_aCBde16b16c,
1634  aCBde16c16b = dnnl_aCBde16c16b,
1635  Acdb16a = dnnl_Acdb16a,
1636  Acdb4a = dnnl_Acdb4a,
1637  Acdb8a = dnnl_Acdb8a,
1638  Acdeb16a = dnnl_Acdeb16a,
1639  Acdeb4a = dnnl_Acdeb4a,
1640  Acdeb8a = dnnl_Acdeb8a,
1641  BAc16a16b = dnnl_BAc16a16b,
1642  BAc16b16a = dnnl_BAc16b16a,
1643  BAcd16a16b = dnnl_BAcd16a16b,
1644  BAcd16b16a = dnnl_BAcd16b16a,
1645  ABcd32a32b = dnnl_ABcd32a32b,
1646  BAcde16b16a = dnnl_BAcde16b16a,
1647  BAcde16a16b = dnnl_BAcde16a16b,
1648  aBdec32b = dnnl_aBdec32b,
1649  Abcdef16a = dnnl_Abcdef16a,
1650  Abcdef32a = dnnl_Abcdef32a,
1651  Acdb32a = dnnl_Acdb32a,
1652  aBCd2b4c2b = dnnl_aBCd2b4c2b,
1653  aBCde2b4c2b = dnnl_aBCde2b4c2b,
1654  aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
1655  aBCd2c4b2c = dnnl_aBCd2c4b2c,
1656  aBCde2c4b2c = dnnl_aBCde2c4b2c,
1657  aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
1658  aBCd4b8c2b = dnnl_aBCd4b8c2b,
1659  aBCde4b8c2b = dnnl_aBCde4b8c2b,
1660  aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
1661  aBCd4c8b2c = dnnl_aBCd4c8b2c,
1662  aBCde4c8b2c = dnnl_aBCde4c8b2c,
1663  aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
1664  AB32a32b8a4b = dnnl_AB32a32b8a4b,
1665  AB32a32b8a2b = dnnl_AB32a32b8a2b,
1666  AB8a4b = dnnl_AB8a4b,
1667  AB8a2b = dnnl_AB8a2b,
1668  abDc32d = dnnl_abDc32d,
1669  abDC32d4c = dnnl_abDC32d4c,
1670  abdEc32e = dnnl_abdEc32e,
1671  abdEC32e2c = dnnl_abdEC32e2c,
1672  abdEC32e4c = dnnl_abdEC32e4c,
1673  aBCdef16c16b4c = dnnl_aBCdef16c16b4c,
1674  aBdC16b4c = dnnl_aBdC16b4c,
1675  aBdeC16b4c = dnnl_aBdeC16b4c,
1676  AcB16a4b = dnnl_AcB16a4b,
1677  AcdB16a2b = dnnl_AcdB16a2b,
1678  aBdefC16b4c = dnnl_aBdefC16b4c,
1679  AcdeB16a4b = dnnl_AcdeB16a4b,
1680 
1681  Acb32a = dnnl_Acb32a,
1682  AcB32a2b = dnnl_AcB32a2b,
1683  AcB32a4b = dnnl_AcB32a4b,
1684  Acb48a = dnnl_Acb48a,
1685  AcB48a2b = dnnl_AcB48a2b,
1686  AcB48a4b = dnnl_AcB48a4b,
1687  Acb64a = dnnl_Acb64a,
1688  AcB64a2b = dnnl_AcB64a2b,
1689  AcB64a4b = dnnl_AcB64a4b,
1690  cBa2b = dnnl_cBa2b,
1691  cBa4b = dnnl_cBa4b,
1692  aBdc32b = dnnl_aBdc32b,
1693  aBdC32b2c = dnnl_aBdC32b2c,
1694  aBdC32b4c = dnnl_aBdC32b4c,
1695  aBdc48b = dnnl_aBdc48b,
1696  aBdC48b2c = dnnl_aBdC48b2c,
1697  aBdC48b4c = dnnl_aBdC48b4c,
1698  aBdc64b = dnnl_aBdc64b,
1699  aBdC64b2c = dnnl_aBdC64b2c,
1700  aBdC64b4c = dnnl_aBdC64b4c,
1701  adcb = dnnl_adcb,
1702  adCb2c = dnnl_adCb2c,
1703  adCb4c = dnnl_adCb4c,
1704  AcdB32a2b = dnnl_AcdB32a2b,
1705  AcdB32a4b = dnnl_AcdB32a4b,
1706  Acdb48a = dnnl_Acdb48a,
1707  AcdB48a2b = dnnl_AcdB48a2b,
1708  AcdB48a4b = dnnl_AcdB48a4b,
1709  Acdb64a = dnnl_Acdb64a,
1710  AcdB64a2b = dnnl_AcdB64a2b,
1711  AcdB64a4b = dnnl_AcdB64a4b,
1712  cdBa2b = dnnl_cdBa2b,
1713  cdBa4b = dnnl_cdBa4b,
1714  aBdeC32b2c = dnnl_aBdeC32b2c,
1715  aBdeC32b4c = dnnl_aBdeC32b4c,
1716  aBdec48b = dnnl_aBdec48b,
1717  aBdeC48b2c = dnnl_aBdeC48b2c,
1718  aBdeC48b4c = dnnl_aBdeC48b4c,
1719  aBdec64b = dnnl_aBdec64b,
1720  aBdeC64b2c = dnnl_aBdeC64b2c,
1721  aBdeC64b4c = dnnl_aBdeC64b4c,
1722  adecb = dnnl_adecb,
1723  adeCb2c = dnnl_adeCb2c,
1724  adeCb4c = dnnl_adeCb4c,
1725  Acdeb32a = dnnl_Acdeb32a,
1726  AcdeB32a2b = dnnl_AcdeB32a2b,
1727  AcdeB32a4b = dnnl_AcdeB32a4b,
1728  Acdeb48a = dnnl_Acdeb48a,
1729  AcdeB48a2b = dnnl_AcdeB48a2b,
1730  AcdeB48a4b = dnnl_AcdeB48a4b,
1731  Acdeb64a = dnnl_Acdeb64a,
1732  AcdeB64a2b = dnnl_AcdeB64a2b,
1733  AcdeB64a4b = dnnl_AcdeB64a4b,
1734  cdeBa2b = dnnl_cdeBa2b,
1735  cdeBa4b = dnnl_cdeBa4b,
1736  aBdefc32b = dnnl_aBdefc32b,
1737  aBdefC32b2c = dnnl_aBdefC32b2c,
1738  aBdefC32b4c = dnnl_aBdefC32b4c,
1739  aBdefc48b = dnnl_aBdefc48b,
1740  aBdefC48b2c = dnnl_aBdefC48b2c,
1741  aBdefC48b4c = dnnl_aBdefC48b4c,
1742  aBdefc64b = dnnl_aBdefc64b,
1743  aBdefC64b2c = dnnl_aBdefC64b2c,
1744  aBdefC64b4c = dnnl_aBdefC64b4c,
1745  adefcb = dnnl_adefcb,
1746  adefCb2c = dnnl_adefCb2c,
1747  adefCb4c = dnnl_adefCb4c,
1748 
1749  format_tag_last = dnnl_format_tag_last,
1750 
1751  nCdhw16c = dnnl_nCdhw16c,
1752  nCdhw4c = dnnl_nCdhw4c,
1753  nCdhw8c = dnnl_nCdhw8c,
1754  nChw16c = dnnl_nChw16c,
1755  nChw4c = dnnl_nChw4c,
1756  nChw8c = dnnl_nChw8c,
1757  nCw16c = dnnl_nCw16c,
1758  nCw4c = dnnl_nCw4c,
1759  nCw8c = dnnl_nCw8c,
1760  NCw16n16c = dnnl_NCw16n16c,
1761  NChw16n16c = dnnl_NChw16n16c,
1762  NCdhw16n16c = dnnl_NCdhw16n16c,
1763  NCdhw32n32c = dnnl_NCdhw32n32c,
1764  NChw32n32c = dnnl_NChw32n32c,
1765  IOhw16i16o = dnnl_IOhw16i16o,
1766  OI16i16o = dnnl_OI16i16o,
1767  OI16i32o = dnnl_OI16i32o,
1768  OI16i64o = dnnl_OI16i64o,
1769  OI8i16o2i = dnnl_OI8i16o2i,
1770  OI8i32o2i = dnnl_OI8i32o2i,
1771  OI8i64o2i = dnnl_OI8i64o2i,
1772  OI4i16o4i = dnnl_OI4i16o4i,
1773  OI4i32o4i = dnnl_OI4i32o4i,
1774  OI4i64o4i = dnnl_OI4i64o4i,
1775  Ohwi32o = dnnl_Ohwi32o,
1776  IOdhw16i16o = dnnl_IOdhw16i16o,
1777  gIOhw16i16o = dnnl_gIOhw16i16o,
1778  gOhwi32o = dnnl_gOhwi32o,
1779  Goidhw16g = dnnl_Goidhw16g,
1780  IOw16o16i = dnnl_IOw16o16i,
1781  OIw16i16o = dnnl_OIw16i16o,
1782  OIw16i32o = dnnl_OIw16i32o,
1783  OIw16i64o = dnnl_OIw16i64o,
1784  IOw16i16o = dnnl_IOw16i16o,
1785  gIOw16i16o = dnnl_gIOw16i16o,
1786  OIw16o16i = dnnl_OIw16o16i,
1787  Oiw16o = dnnl_Oiw16o,
1788  OIw4i16o4i = dnnl_OIw4i16o4i,
1789  OIw4i32o4i = dnnl_OIw4i32o4i,
1790  OIw4i64o4i = dnnl_OIw4i64o4i,
1791  OIw2i8o4i = dnnl_OIw2i8o4i,
1792  OIw4i4o = dnnl_OIw4i4o,
1793  OIw4o4i = dnnl_OIw4o4i,
1794  Oiw4o = dnnl_Oiw4o,
1795  OIw8i16o2i = dnnl_OIw8i16o2i,
1796  OIw8i32o2i = dnnl_OIw8i32o2i,
1797  OIw8i64o2i = dnnl_OIw8i64o2i,
1798  OIw8i8o = dnnl_OIw8i8o,
1799  OIw8o16i2o = dnnl_OIw8o16i2o,
1800  OIw8o8i = dnnl_OIw8o8i,
1801  OIw8o4i = dnnl_OIw8o4i,
1802  OIw16i16o4i = dnnl_OIw16i16o4i,
1803  OIw16i32o4i = dnnl_OIw16i32o4i,
1804  OIw16i48o4i = dnnl_OIw16i48o4i,
1805  OIw16i64o4i = dnnl_OIw16i64o4i,
1806  OIw16i16o2i = dnnl_OIw16i16o2i,
1807  OIw16i32o2i = dnnl_OIw16i32o2i,
1808  OIw16i48o2i = dnnl_OIw16i48o2i,
1809  OIw16i64o2i = dnnl_OIw16i64o2i,
1810  OIw16o16i2o = dnnl_OIw16o16i2o,
1811  Owi16o = dnnl_Owi16o,
1812  OwI16o2i = dnnl_OwI16o2i,
1813  Owi4o = dnnl_Owi4o,
1814  Owi8o = dnnl_Owi8o,
1815  IOhw16o16i = dnnl_IOhw16o16i,
1816  Ohwi16o = dnnl_Ohwi16o,
1817  OhwI16o2i = dnnl_OhwI16o2i,
1818  Ohwi4o = dnnl_Ohwi4o,
1819  Ohwi8o = dnnl_Ohwi8o,
1820  OIhw16i16o = dnnl_OIhw16i16o,
1821  OIhw16i32o = dnnl_OIhw16i32o,
1822  OIhw16i64o = dnnl_OIhw16i64o,
1823  OIhw16o16i = dnnl_OIhw16o16i,
1824  Oihw16o = dnnl_Oihw16o,
1825  OIhw4i16o4i = dnnl_OIhw4i16o4i,
1826  OIhw4i32o4i = dnnl_OIhw4i32o4i,
1827  OIhw4i64o4i = dnnl_OIhw4i64o4i,
1828  OIhw4i4o = dnnl_OIhw4i4o,
1829  OIhw4o4i = dnnl_OIhw4o4i,
1830  Oihw4o = dnnl_Oihw4o,
1831  OIhw8i16o2i = dnnl_OIhw8i16o2i,
1832  OIhw8i32o2i = dnnl_OIhw8i32o2i,
1833  OIhw8i64o2i = dnnl_OIhw8i64o2i,
1834  OIhw8i8o = dnnl_OIhw8i8o,
1835  OIhw8o16i2o = dnnl_OIhw8o16i2o,
1836  OIhw8o8i = dnnl_OIhw8o8i,
1837  OIhw8o4i = dnnl_OIhw8o4i,
1838  OIhw2i8o4i = dnnl_OIhw2i8o4i,
1839  IOdhw16o16i = dnnl_IOdhw16o16i,
1840  Odhwi16o = dnnl_Odhwi16o,
1841  OdhwI16o2i = dnnl_OdhwI16o2i,
1842  Odhwi4o = dnnl_Odhwi4o,
1843  Odhwi8o = dnnl_Odhwi8o,
1844  OIdhw16i16o = dnnl_OIdhw16i16o,
1845  OIdhw16i32o = dnnl_OIdhw16i32o,
1846  OIdhw16i64o = dnnl_OIdhw16i64o,
1847  OIdhw16o16i = dnnl_OIdhw16o16i,
1848  Oidhw16o = dnnl_Oidhw16o,
1849  OIdhw4i4o = dnnl_OIdhw4i4o,
1850  OIdhw4o4i = dnnl_OIdhw4o4i,
1851  Oidhw4o = dnnl_Oidhw4o,
1852  OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
1853  OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
1854  OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
1855  OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
1856  OIdhw16i16o4i = dnnl_OIdhw16i16o4i,
1857  OIdhw16i32o4i = dnnl_OIdhw16i32o4i,
1858  OIdhw16i48o4i = dnnl_OIdhw16i48o4i,
1859  OIdhw16i64o4i = dnnl_OIdhw16i64o4i,
1860  OIdhw16i16o2i = dnnl_OIdhw16i16o2i,
1861  OIdhw16i32o2i = dnnl_OIdhw16i32o2i,
1862  OIdhw16i48o2i = dnnl_OIdhw16i48o2i,
1863  OIdhw16i64o2i = dnnl_OIdhw16i64o2i,
1864  OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
1865  OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
1866  OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
1867  OIdhw8i8o = dnnl_OIdhw8i8o,
1868  OIdhw8o8i = dnnl_OIdhw8o8i,
1869  OIdhw8o4i = dnnl_OIdhw8o4i,
1870  gIOw16o16i = dnnl_gIOw16o16i,
1871  gOIw16i16o = dnnl_gOIw16i16o,
1872  gOIw16o16i = dnnl_gOIw16o16i,
1873  gOiw16o = dnnl_gOiw16o,
1874  gOIw4i16o4i = dnnl_gOIw4i16o4i,
1875  gOIw2i8o4i = dnnl_gOIw2i8o4i,
1876  gOIw4i4o = dnnl_gOIw4i4o,
1877  gOIw4o4i = dnnl_gOIw4o4i,
1878  gOiw4o = dnnl_gOiw4o,
1879  gOIw8i16o2i = dnnl_gOIw8i16o2i,
1880  gOIw8i8o = dnnl_gOIw8i8o,
1881  gOIw8o16i2o = dnnl_gOIw8o16i2o,
1882  gOIw8o8i = dnnl_gOIw8o8i,
1883  gOIw8o4i = dnnl_gOIw8o4i,
1884  gOIw16i16o4i = dnnl_gOIw16i16o4i,
1885  gOIw16i16o2i = dnnl_gOIw16i16o2i,
1886  gOIw16o16i2o = dnnl_gOIw16o16i2o,
1887  gOwi16o = dnnl_gOwi16o,
1888  gOwI16o2i = dnnl_gOwI16o2i,
1889  gOwi4o = dnnl_gOwi4o,
1890  gOwi8o = dnnl_gOwi8o,
1891  Goiw8g = dnnl_Goiw8g,
1892  Goiw16g = dnnl_Goiw16g,
1893  gIOhw16o16i = dnnl_gIOhw16o16i,
1894  gOhwi16o = dnnl_gOhwi16o,
1895  gOhwI16o2i = dnnl_gOhwI16o2i,
1896  gOhwi4o = dnnl_gOhwi4o,
1897  gOhwi8o = dnnl_gOhwi8o,
1898  Goihw16g = dnnl_Goihw16g,
1899  gOIhw16i16o = dnnl_gOIhw16i16o,
1900  gOIhw16o16i = dnnl_gOIhw16o16i,
1901  gOihw16o = dnnl_gOihw16o,
1902  gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
1903  gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
1904  gOIhw4i4o = dnnl_gOIhw4i4o,
1905  gOIhw4o4i = dnnl_gOIhw4o4i,
1906  gOihw4o = dnnl_gOihw4o,
1907  Goihw8g = dnnl_Goihw8g,
1908  gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
1909  gOIhw8i8o = dnnl_gOIhw8i8o,
1910  gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
1911  OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
1912  OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
1913  OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
1914  OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
1915  gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
1916  gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
1917  gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
1918  gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
1919  OIhw16i16o4i = dnnl_OIhw16i16o4i,
1920  OIhw16i32o4i = dnnl_OIhw16i32o4i,
1921  OIhw16i48o4i = dnnl_OIhw16i48o4i,
1922  OIhw16i64o4i = dnnl_OIhw16i64o4i,
1923  OIhw16i16o2i = dnnl_OIhw16i16o2i,
1924  OIhw16i32o2i = dnnl_OIhw16i32o2i,
1925  OIhw16i48o2i = dnnl_OIhw16i48o2i,
1926  OIhw16i64o2i = dnnl_OIhw16i64o2i,
1927  OIhw16o16i2o = dnnl_OIhw16o16i2o,
1928  gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
1929  gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
1930  gOIhw16o16i2o = dnnl_gOIhw16o16i2o,
1931  gOIhw8o8i = dnnl_gOIhw8o8i,
1932  gOIhw8o4i = dnnl_gOIhw8o4i,
1933  gIOdhw16i16o = dnnl_gIOdhw16i16o,
1934  gIOdhw16o16i = dnnl_gIOdhw16o16i,
1935  gOdhwi16o = dnnl_gOdhwi16o,
1936  gOdhwI16o2i = dnnl_gOdhwI16o2i,
1937  gOdhwi4o = dnnl_gOdhwi4o,
1938  gOdhwi8o = dnnl_gOdhwi8o,
1939  gOIdhw16i16o = dnnl_gOIdhw16i16o,
1940  gOIdhw16o16i = dnnl_gOIdhw16o16i,
1941  gOidhw16o = dnnl_gOidhw16o,
1942  gOIdhw4i4o = dnnl_gOIdhw4i4o,
1943  gOIdhw4o4i = dnnl_gOIdhw4o4i,
1944  gOidhw4o = dnnl_gOidhw4o,
1945  gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
1946  gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
1947  gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i,
1948  gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i,
1949  gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
1950  gOIdhw8i8o = dnnl_gOIdhw8i8o,
1951  gOIdhw8o8i = dnnl_gOIdhw8o8i,
1952  gOIdhw8o4i = dnnl_gOIdhw8o4i,
1953  gOIw2i4o2i = dnnl_gOIw2i4o2i,
1954  gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
1955  gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
1956  gOIw2o4i2o = dnnl_gOIw2o4i2o,
1957  gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
1958  gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
1959  gOIw4i8o2i = dnnl_gOIw4i8o2i,
1960  gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
1961  gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
1962  gOIw4o8i2o = dnnl_gOIw4o8i2o,
1963  gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
1964  gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
1965  ldOi32o = abDc32d,
1966  ldOI32o4i = abDC32d4c,
1967  ldgOi32o = abdEc32e,
1968  ldgOI32o2i = abdEC32e2c,
1969  ldgOI32o4i = abdEC32e4c,
1970  OwI16o4i = dnnl_OwI16o4i,
1971  OhwI16o4i = dnnl_OhwI16o4i,
1972  gOwI16o4i = dnnl_gOwI16o4i,
1973  gOhwI16o4i = dnnl_gOhwI16o4i,
1974  OdhwI16o4i = dnnl_OdhwI16o4i,
1975  gOdhwI16o4i = dnnl_gOdhwI16o4i,
1976 
1977  Owi32o = dnnl_Owi32o,
1978  OwI32o2i = dnnl_OwI32o2i,
1979  OwI32o4i = dnnl_OwI32o4i,
1980  Owi48o = dnnl_Owi48o,
1981  OwI48o2i = dnnl_OwI48o2i,
1982  OwI48o4i = dnnl_OwI48o4i,
1983  Owi64o = dnnl_Owi64o,
1984  OwI64o2i = dnnl_OwI64o2i,
1985  OwI64o4i = dnnl_OwI64o4i,
1986  wIo2i = dnnl_wIo2i,
1987  wIo4i = dnnl_wIo4i,
1988  gOwi32o = dnnl_gOwi32o,
1989  gOwI32o2i = dnnl_gOwI32o2i,
1990  gOwI32o4i = dnnl_gOwI32o4i,
1991  gOwi48o = dnnl_gOwi48o,
1992  gOwI48o2i = dnnl_gOwI48o2i,
1993  gOwI48o4i = dnnl_gOwI48o4i,
1994  gOwi64o = dnnl_gOwi64o,
1995  gOwI64o2i = dnnl_gOwI64o2i,
1996  gOwI64o4i = dnnl_gOwI64o4i,
1997  gwio = dnnl_gwio,
1998  gwIo2i = dnnl_gwIo2i,
1999  gwIo4i = dnnl_gwIo4i,
2000  OhwI32o = dnnl_OhwI32o,
2001  OhwI32o2i = dnnl_OhwI32o2i,
2002  OhwI32o4i = dnnl_OhwI32o4i,
2003  Ohwi48o = dnnl_Ohwi48o,
2004  OhwI48o2i = dnnl_OhwI48o2i,
2005  OhwI48o4i = dnnl_OhwI48o4i,
2006  Ohwi64o = dnnl_Ohwi64o,
2007  OhwI64o2i = dnnl_OhwI64o2i,
2008  OhwI64o4i = dnnl_OhwI64o4i,
2009  hwIo2i = dnnl_hwIo2i,
2010  hwIo4i = dnnl_hwIo4i,
2011  gOhwI32o = dnnl_gOhwI32o,
2012  gOhwI32o2i = dnnl_gOhwI32o2i,
2013  gOhwI32o4i = dnnl_gOhwI32o4i,
2014  gOhwi48o = dnnl_gOhwi48o,
2015  gOhwI48o2i = dnnl_gOhwI48o2i,
2016  gOhwI48o4i = dnnl_gOhwI48o4i,
2017  gOhwi64o = dnnl_gOhwi64o,
2018  gOhwI64o2i = dnnl_gOhwI64o2i,
2019  gOhwI64o4i = dnnl_gOhwI64o4i,
2020  ghwio = dnnl_ghwio,
2021  ghwIo2i = dnnl_ghwIo2i,
2022  ghwIo4i = dnnl_ghwIo4i,
2023  Odhwi32o = dnnl_Odhwi32o,
2024  OdhwI32o2i = dnnl_OdhwI32o2i,
2025  OdhwI32o4i = dnnl_OdhwI32o4i,
2026  Odhwi48o = dnnl_Odhwi48o,
2027  OdhwI48o2i = dnnl_OdhwI48o2i,
2028  OdhwI48o4i = dnnl_OdhwI48o4i,
2029  Odhwi64o = dnnl_Odhwi64o,
2030  OdhwI64o2i = dnnl_OdhwI64o2i,
2031  OdhwI64o4i = dnnl_OdhwI64o4i,
2032  dhwIo2i = dnnl_dhwIo2i,
2033  dhwIo4i = dnnl_dhwIo4i,
2034  gOdhwi32o = dnnl_gOdhwi32o,
2035  gOdhwI32o2i = dnnl_gOdhwI32o2i,
2036  gOdhwI32o4i = dnnl_gOdhwI32o4i,
2037  gOdhwi48o = dnnl_gOdhwi48o,
2038  gOdhwI48o2i = dnnl_gOdhwI48o2i,
2039  gOdhwI48o4i = dnnl_gOdhwI48o4i,
2040  gOdhwi64o = dnnl_gOdhwi64o,
2041  gOdhwI64o2i = dnnl_gOdhwI64o2i,
2042  gOdhwI64o4i = dnnl_gOdhwI64o4i,
2043  gdhwio = dnnl_gdhwio,
2044  gdhwIo2i = dnnl_gdhwIo2i,
2045  gdhwIo4i = dnnl_gdhwIo4i,
2046  };
2047 
2049  struct desc {
2050  friend struct memory;
2053 
/// Constructs a zero memory descriptor. The `data()` mem-initializer
/// value-initializes the underlying C struct, so `data.ndims` is 0 and
/// `is_zero()` / `operator bool` treat the descriptor as empty.
2056  desc() : data() {}
2057 
2073  desc(const dims &adims, data_type adata_type, format_tag aformat_tag,
2074  bool allow_empty = false)
2075  : data() {
2076  validate_dims(adims);
2078  (int)adims.size(), adims.data(), convert_to_c(adata_type),
2079  convert_to_c(aformat_tag));
2080  if (!allow_empty)
2082  "could not construct a memory descriptor using a "
2083  "format tag");
2084  }
2085 
2101  desc(const dims &adims, data_type adata_type, const dims &strides,
2102  bool allow_empty = false)
2103  : data() {
2104  validate_dims(adims);
2105  if (!strides.empty()) validate_dims(strides, (int)adims.size());
2107  (int)adims.size(), adims.data(), convert_to_c(adata_type),
2108  strides.empty() ? nullptr : &strides[0]);
2109  if (!allow_empty)
2111  "could not construct a memory descriptor using "
2112  "strides");
2113  }
2114 
/// Wraps an existing C memory descriptor by copying it into `data`.
/// NOTE: this constructor is not `explicit`, so a `dnnl_memory_desc_t`
/// converts implicitly to `memory::desc` wherever one is expected.
2118  desc(const dnnl_memory_desc_t &data) : data(data) {}
2119 
2122  //
2131  desc submemory_desc(const dims &adims, const dims &offsets,
2132  bool allow_empty = false) const {
2133  validate_dims(adims, data.ndims);
2134  validate_dims(offsets, data.ndims);
2137  &sub_md, &data, adims.data(), offsets.data());
2138  if (!allow_empty)
2139  error::wrap_c_api(status, "could not construct a sub-memory");
2140  return desc(sub_md);
2141  }
2142 
2187  desc reshape(const dims &adims, bool allow_empty = false) const {
2188  if (data.ndims) validate_dims(adims, 1);
2191  &out_md, &data, (int)adims.size(), adims.data());
2192  if (!allow_empty)
2194  status, "could not reshape a memory descriptor");
2195  return desc(out_md);
2196  }
2197 
2235  desc permute_axes(const std::vector<int> &permutation,
2236  bool allow_empty = false) const {
2237  validate_dims(permutation, data.ndims);
2240  &out_md, &data, permutation.data());
2241  if (!allow_empty)
2243  "could not permute axes of a memory descriptor");
2244  return desc(out_md);
2245  }
2246 
2250  return static_cast<memory::data_type>(data.data_type);
2251  }
2252 
2257  memory::dims dims() const {
2258  return memory::dims(data.dims, data.dims + data.ndims);
2259  }
2260 
2265  size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
2266 
2270  bool is_zero() const { return data.ndims == 0; }
2271 
2276  bool operator==(const desc &other) const {
2277  return dnnl_memory_desc_equal(&data, &other.data) != 0;
2278  }
2279 
2284  bool operator!=(const desc &other) const { return !operator==(other); }
2285 
2289  explicit operator bool() const { return data.ndims != 0; }
2290  };
2291 
/// Defaulted default constructor. Unlike the constructors that follow, it
/// does not call `dnnl_memory_create`.
2296  memory() = default;
2297 
2317  memory(const desc &md, const engine &aengine, void *handle) {
2318  dnnl_memory_t result;
2320  dnnl_memory_create(&result, &md.data, aengine.get(), handle),
2321  "could not create a memory object");
2322  reset(result);
2323  }
2324 
/// Constructs a memory object for descriptor `md` on engine `aengine` by
/// delegating to the (md, engine, handle) constructor with the
/// `DNNL_MEMORY_ALLOCATE` handle value. Presumably this asks the library to
/// allocate the buffer itself; confirm against the C API documentation.
2331  memory(const desc &md, const engine &aengine)
2332  : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
2333 
2335  desc get_desc() const {
2336  const dnnl_memory_desc_t *cdesc;
2338  "could not get a memory descriptor from a memory object");
2339  return desc(*cdesc);
2340  }
2341 
2343  engine get_engine() const {
2344  dnnl_engine_t c_engine;
2345  error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
2346  "could not get an engine from a memory object");
2347  return engine(c_engine, true);
2348  }
2349 
2354  void *get_data_handle() const {
2355  void *handle;
2357  "could not get a native handle from a memory object");
2358  return handle;
2359  }
2360 
2389  void set_data_handle(void *handle, const stream &astream) const {
2391  get(), handle, astream.get(true)),
2392  "could not set native handle of a memory object");
2393  }
2394 
2405  void set_data_handle(void *handle) const {
2407  dnnl_memory_set_data_handle_v2(get(), handle, nullptr),
2408  "could not set native handle of a memory object");
2409  }
2410 
2432  template <typename T = void>
2433  T *map_data() const {
2434  void *mapped_ptr;
2435  error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
2436  "could not map memory object data");
2437  return static_cast<T *>(mapped_ptr);
2438  }
2439 
2450  void unmap_data(void *mapped_ptr) const {
2451  error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr),
2452  "could not unmap memory object data");
2453  }
2454 
2455  static dnnl_data_type_t convert_to_c(data_type adata_type) {
2456  return static_cast<dnnl_data_type_t>(adata_type);
2457  }
2458  static dnnl_format_tag_t convert_to_c(format_tag format) {
2459  return static_cast<dnnl_format_tag_t>(format);
2460  }
2461 };
2462 
2463 inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
2464  return a == memory::convert_to_c(b);
2465 }
2466 inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
2467  return !(a == b);
2468 }
2469 inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
2470  return b == a;
2471 }
2472 inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
2473  return !(a == b);
2474 }
2475 
2476 inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
2477  return a == memory::convert_to_c(b);
2478 }
2479 inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
2480  return !(a == b);
2481 }
2482 inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
2483  return b == a;
2484 }
2485 inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
2486  return !(a == b);
2487 }
2488 
2490 
2498 
2500 template <>
2501 struct handle_traits<dnnl_post_ops_t> {
2502  static dnnl_status_t destructor(dnnl_post_ops_t p) {
2503  return dnnl_post_ops_destroy(p);
2504  }
2505 };
2507 
2515 struct post_ops : public handle<dnnl_post_ops_t> {
2517 
2520  dnnl_post_ops_t result;
2522  dnnl_post_ops_create(&result), "could not create post-ops");
2523  reset(result);
2524  }
2525 
2527  int len() const { return dnnl_post_ops_len(get()); }
2528 
2532  primitive::kind kind(int index) const {
2534  "post-ops index is out of range");
2535  return static_cast<primitive::kind>(
2536  dnnl_post_ops_get_kind(get(), index));
2537  }
2538 
2567  void append_sum(float scale = 1.f,
2569  if (data_type == memory::data_type::undef)
2571  "could not append a sum post-op");
2572  else
2574  memory::convert_to_c(data_type)),
2575  "could not append a sum post-op");
2576  }
2577 
2582  void get_params_sum(int index, float &scale) const {
2584  "could not get parameters of a sum post-op");
2585  }
2586 
2593  int index, float &scale, memory::data_type &data_type) const {
2594  dnnl_data_type_t c_data_type;
2596  get(), index, &scale, &c_data_type),
2597  "could not get parameters of a sum post-op");
2598  data_type = static_cast<memory::data_type>(c_data_type);
2599  }
2600 
2615  float scale, algorithm aalgorithm, float alpha, float beta) {
2617  convert_to_c(aalgorithm), alpha, beta),
2618  "could not append an elementwise post-op");
2619  }
2620 
2628  void get_params_eltwise(int index, float &scale, algorithm &aalgorithm,
2629  float &alpha, float &beta) const {
2630  dnnl_alg_kind_t c_alg;
2632  get(), index, &scale, &c_alg, &alpha, &beta),
2633  "could not get parameters of an elementwise post-op");
2634  aalgorithm = static_cast<dnnl::algorithm>(c_alg);
2635  }
2636 
2665  void append_dw_k3s1p1(memory::data_type weights_data_type,
2666  memory::data_type bias_data_type, memory::data_type dst_data_type,
2667  int mask, const std::vector<float> &scales) {
2668 
2670  memory::convert_to_c(weights_data_type),
2671  memory::convert_to_c(bias_data_type),
2672  memory::convert_to_c(dst_data_type),
2673  scales.size(), mask, &scales[0]),
2674  "could not append depthwise post-op");
2675  }
2676 
2691  void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type,
2692  memory::data_type &bias_data_type, memory::data_type &dst_data_type,
2693  int &mask, std::vector<float> &scales) const {
2694 
2695  dnnl_data_type_t c_weights_data_type;
2696  dnnl_data_type_t c_bias_data_type;
2697  dnnl_data_type_t c_dst_data_type;
2698  dnnl_dim_t count;
2699  int c_mask;
2700  const float *c_scales;
2702  &c_weights_data_type, &c_bias_data_type,
2703  &c_dst_data_type, &count, &c_mask, &c_scales),
2704  "could not get parameters of depthwise post-op");
2705 
2706  weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
2707  bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
2708  dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
2709  scales.resize(count);
2710 
2711  mask = c_mask;
2712  for (dnnl_dim_t c = 0; c < count; ++c)
2713  scales[c] = c_scales[c];
2714  return;
2715  }
2716 
2750  void append_dw_k3s2p1(memory::data_type weights_data_type,
2751  memory::data_type bias_data_type, memory::data_type dst_data_type,
2752  int mask, const std::vector<float> &scales) {
2753 
2755  memory::convert_to_c(weights_data_type),
2756  memory::convert_to_c(bias_data_type),
2757  memory::convert_to_c(dst_data_type),
2758  scales.size(), mask, &scales[0]),
2759  "could not append depthwise post-op");
2760  }
2761 
2776  void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type,
2777  memory::data_type &bias_data_type, memory::data_type &dst_data_type,
2778  int &mask, std::vector<float> &scales) const {
2779 
2780  dnnl_data_type_t c_weights_data_type;
2781  dnnl_data_type_t c_bias_data_type;
2782  dnnl_data_type_t c_dst_data_type;
2783  dnnl_dim_t count;
2784  int c_mask;
2785  const float *c_scales;
2787  &c_weights_data_type, &c_bias_data_type,
2788  &c_dst_data_type, &count, &c_mask, &c_scales),
2789  "could not get parameters of depthwise post-op");
2790 
2791  weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
2792  bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
2793  dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
2794  scales.resize(count);
2795 
2796  mask = c_mask;
2797  for (dnnl_dim_t c = 0; c < count; ++c)
2798  scales[c] = c_scales[c];
2799  return;
2800  }
2801 
2816  void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) {
2818  convert_to_c(aalgorithm), &src1_desc.data),
2819  "could not append a binary post-op");
2820  }
2821 
2828  int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
2829  dnnl_alg_kind_t c_alg;
2830  const dnnl_memory_desc_t *data;
2832  dnnl_post_ops_get_params_binary(get(), index, &c_alg, &data),
2833  "could not get parameters of a binary post-op");
2834  aalgorithm = static_cast<dnnl::algorithm>(c_alg);
2835  src1_desc.data = *data;
2836  }
2837 };
2838 
2840 template <>
2841 struct handle_traits<dnnl_primitive_attr_t> {
2842  static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
2843  return dnnl_primitive_attr_destroy(p);
2844  }
2845 };
2847 
2851 struct primitive_attr : public handle<dnnl_primitive_attr_t> {
2853 
2856  dnnl_primitive_attr_t result;
2858  "could not create primitive attribute");
2859  reset(result);
2860  }
2861 
2868  : handle<dnnl_primitive_attr_t>(attr) {}
2869 
2872  dnnl_scratchpad_mode_t result;
2875  "could not get scratchpad mode primitive attribute");
2876  return scratchpad_mode(result);
2877  }
2878 
2884  get(), dnnl::convert_to_c(mode)),
2885  "could not set scratchpad mode primitive attribute");
2886  }
2887 
2897  void get_output_scales(int &mask, std::vector<float> &scales) const {
2898  dnnl_dim_t count;
2899  int c_mask;
2900  const float *c_scales;
2902  get(), &count, &c_mask, &c_scales),
2903  "could not get output scales primitive attribute");
2904  scales.resize(count);
2905 
2906  mask = c_mask;
2907  for (dnnl_dim_t c = 0; c < count; ++c)
2908  scales[c] = c_scales[c];
2909  }
2910 
2953  void set_output_scales(int mask, const std::vector<float> &scales) {
2956  get(), (dnnl_dim_t)scales.size(), mask, scales.data()),
2957  "could not set output scales primitive attribute");
2958  }
2959 
2971  void get_scales(int arg, int &mask, std::vector<float> &scales) const {
2972  dnnl_dim_t count;
2973  int c_mask;
2974  const float *c_scales;
2976  get(), arg, &count, &c_mask, &c_scales),
2977  "could not get scales primitive attributes");
2978  scales.resize(count);
2979 
2980  mask = c_mask;
2981  for (dnnl_dim_t c = 0; c < count; ++c)
2982  scales[c] = c_scales[c];
2983  }
2984 
3001  void set_scales(int arg, int mask, const std::vector<float> &scales) {
3004  (dnnl_dim_t)scales.size(), mask, scales.data()),
3005  "could not set scales primitive attribute");
3006  }
3007 
3019  int arg, int &mask, std::vector<int32_t> &zero_points) const {
3020  dnnl_dim_t count;
3021  int c_mask;
3022  const int32_t *c_zero_points;
3024  get(), arg, &count, &c_mask, &c_zero_points),
3025  "could not get zero points primitive attribute");
3026  zero_points.resize(count);
3027 
3028  mask = c_mask;
3029  for (dnnl_dim_t c = 0; c < count; ++c)
3030  zero_points[c] = c_zero_points[c];
3031  }
3032 
3054  int arg, int mask, const std::vector<int32_t> &zero_points) {
3056  (dnnl_dim_t)zero_points.size(), mask,
3057  zero_points.data()),
3058  "could not set zero points primitive attribute");
3059  }
3060 
3064  const post_ops get_post_ops() const {
3065  post_ops result;
3066  const_dnnl_post_ops_t c_result;
3068  "could not get post-ops primitive attribute");
3069  result.reset(const_cast<dnnl_post_ops_t>(c_result), true);
3070  return result;
3071  }
3072 
3081  void set_post_ops(const post_ops ops) {
3083  "could not set post-ops primitive attribute");
3084  }
3085 
3119  void set_rnn_data_qparams(float scale, float shift) {
3122  "could not set RNN data quantization parameters primitive "
3123  "attribute");
3124  }
3125 
3135  void get_rnn_data_qparams(float &scale, float &shift) {
3136  float c_scale, c_shift;
3138  get(), &c_scale, &c_shift),
3139  "could not set RNN data quantization parameters primitive "
3140  "attribute");
3141  scale = c_scale;
3142  shift = c_shift;
3143  }
3144 
3171  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
3173  (int)scales.size(), mask, scales.data()),
3174  "could not set RNN weights quantization parameters primitive "
3175  "attribute");
3176  }
3177 
3197  void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) {
3198  dnnl_dim_t count;
3199  int c_mask;
3200  const float *c_scales;
3202  get(), &count, &c_mask, &c_scales),
3203  "could not get primitive RNN weights quantization "
3204  "parameters attributes");
3205  scales.resize(count);
3206 
3207  mask = c_mask;
3208  for (dnnl_dim_t c = 0; c < count; c++)
3209  scales[c] = c_scales[c];
3210  }
3211 
3213  // The low-precision configuration of the RNN primitives expects input
3214  // weights to use the signed 8-bit integer data type. The scaling factors
3215  // are used to quantize floating-point data to signed integer and must be
3239  int mask, const std::vector<float> &scales) {
3242  get(), (int)scales.size(), mask, scales.data()),
3243  "could not set primitive RNN weights projection quantization "
3244  "parameters attributes");
3245  }
3246 
3267  int &mask, std::vector<float> &scales) {
3268  dnnl_dim_t count;
3269  int c_mask;
3270  const float *c_scales;
3273  get(), &count, &c_mask, &c_scales),
3274  "could not get primitive RNN weights projection quantization "
3275  "parameters attributes");
3276  scales.resize(count);
3277 
3278  mask = c_mask;
3279  for (dnnl_dim_t c = 0; c < count; c++)
3280  scales[c] = c_scales[c];
3281  }
3282 };
3283 
3285 
3288 
3290 struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
3292 
3294  primitive_desc_base() = default;
3295 
3298  engine get_engine() const { return engine::query(*this); }
3299 
3302  const char *impl_info_str() const {
3303  const char *res;
3305  get(), dnnl_query_impl_info_str, 0, &res),
3306  "could not retrieve implementation info string from a "
3307  "primitive descriptor");
3308  return res;
3309  }
3310 
3315  memory::dim res;
3317  get(), dnnl::convert_to_c(what), 0, &res);
3318  return status == dnnl_success ? res : 0;
3319  }
3320 
3335  memory::desc query_md(query what, int idx = 0) const {
3336  std::vector<query> valid_q {query::src_md, query::diff_src_md,
3340  if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
3341  [=](query q) { return what == q; }))
3342  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3343  "memory descriptor query is invalid");
3344 
3346  get(), dnnl::convert_to_c(what), idx);
3347  return cdesc ? memory::desc(*cdesc) : memory::desc();
3348  }
3349 
3355  memory::desc src_desc(int idx) const {
3356  return query_md(query::src_md, idx);
3357  }
3358 
3364  memory::desc dst_desc(int idx) const {
3365  return query_md(query::dst_md, idx);
3366  }
3367 
3373  memory::desc weights_desc(int idx) const {
3374  return query_md(query::weights_md, idx);
3375  }
3376 
3382  memory::desc diff_src_desc(int idx) const {
3383  return query_md(query::diff_src_md, idx);
3384  }
3385 
3391  memory::desc diff_dst_desc(int idx) const {
3392  return query_md(query::diff_dst_md, idx);
3393  }
3394 
3401  return query_md(query::diff_weights_md, idx);
3402  }
3403 
3404  // Separate versions without the index argument for documentation
3405  // purposes.
3406 
3411  memory::desc src_desc() const { return src_desc(0); }
3412 
3417  memory::desc dst_desc() const { return dst_desc(0); }
3418 
3423  memory::desc weights_desc() const { return weights_desc(0); }
3424 
3430 
3436 
3442 
3448  return query_md(query::workspace_md, 0);
3449  }
3450 
3457  return query_md(query::scratchpad_md, 0);
3458  }
3459 
3463  dnnl_engine_t c_engine;
3466  0, &c_engine),
3467  "could not retrieve scratchpad engine from a primitive "
3468  "descriptor");
3469  return engine(c_engine, true);
3470  }
3471 
3475  const_dnnl_primitive_attr_t const_c_attr;
3477  "could not get attributes from a primitive descriptor");
3478  dnnl_primitive_attr_t c_attr;
3479  error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
3480  "could not clone primitive attributes");
3481  return primitive_attr(c_attr);
3482  }
3483 
3487  dnnl_primitive_kind_t kind;
3489  dnnl_query_primitive_kind, 0, (void *)&kind),
3490  "could not get primitive kind from a primitive descriptor");
3491  return static_cast<dnnl::primitive::kind>(kind);
3492  }
3493 
3494 protected:
3499  dnnl_primitive_desc_t new_pd;
3501  "could not clone a primitive descriptor");
3502  reset(new_pd);
3503  }
3504 
3520  : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
3521 
3534  dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
3535  : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {}
3536 
3551  dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
3552  dnnl::prop_kind prop_kind2) {
3553  // It is OK to pass an empty primitive descriptor
3554  if (pd == nullptr) return;
3555 
3556  dnnl_status_t rc;
3557 
3558  dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
3559  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
3560  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
3561 
3562  // Check that primitive kind matches
3563  dnnl_primitive_kind_t pd_kind;
3565  pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
3567  rc, "could not get primitive kind from a primitive descriptor");
3568  if (pd_kind != c_prim_kind)
3569  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3570  "primitive descriptor operation kind mismatch");
3571 
3572  // Check that propagation kind matches
3573  dnnl_prop_kind_t pd_prop_kind;
3575  pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
3576 
3577  // Something went wrong
3578  if (rc != dnnl_success && rc != dnnl_unimplemented)
3579  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3580  "could not get propagation kind from the primitive "
3581  "descriptor");
3582 
3583  // Everything is fine
3584  if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
3585  || (rc == dnnl_success
3586  && (pd_prop_kind == c_prop_kind1
3587  || pd_prop_kind == c_prop_kind2))) {
3588  reset_with_clone(pd);
3589  return;
3590  }
3591 
3592  // We could get the propagation kind but there is a mismatch
3593  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3594  "primitive descriptor propagation kind mismatch");
3595  }
3596 
3597  using base = primitive_desc_base;
3598 };
3599 
3601 
3610 
3612 struct reorder : public primitive {
3616 
/// Defaulted default constructor; creates no reorder primitive descriptor
/// through the C API.
3618  primitive_desc() = default;
3619 
3637  primitive_desc(const engine &src_engine, const memory::desc &src_md,
3638  const engine &dst_engine, const memory::desc &dst_md,
3639  const primitive_attr &attr = primitive_attr(),
3640  bool allow_empty = false) {
3641  dnnl_primitive_desc_t result;
3643  &src_md.data, src_engine.get(), &dst_md.data,
3644  dst_engine.get(), attr.get());
3645  if (!allow_empty)
3647  "could not create a primitive descriptor for a reorder "
3648  "primitive");
3650  }
3651 
3663  primitive_desc(const memory &src, const memory &dst,
3664  const primitive_attr &attr = primitive_attr(),
3665  bool allow_empty = false) {
3666  dnnl_primitive_desc_t result;
3667  auto src_md = src.get_desc();
3668  auto dst_md = dst.get_desc();
3670  &src_md.data, src.get_engine().get(), &dst_md.data,
3671  dst.get_engine().get(), attr.get());
3672  if (!allow_empty)
3674  "could not create a primitive descriptor for a reorder "
3675  "primitive");
3677  }
3678 
3685 
3690  }
3691 
3696  }
3697 
3699  memory::desc src_desc() const { return base::src_desc(0); }
3700 
3702  memory::desc dst_desc() const { return base::dst_desc(0); }
3703  };
3704 
3706  reorder() = default;
3707 
3710  reorder(const primitive_desc &pd) : primitive(pd.get()) {}
3711 
3719  reorder(const memory &src, const memory &dst,
3720  const primitive_attr &attr = primitive_attr())
3721  : primitive(primitive_desc(src, dst, attr).get()) {}
3722 
3723  using primitive::execute;
3724 
3731  void execute(const stream &astream, memory &src, memory &dst) const {
3732  primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
3733  }
3734 };
3735 
3737 
3745 
3747 inline std::vector<dnnl_memory_desc_t> convert_to_c(
3748  const std::vector<memory::desc> &mems) {
3749  std::vector<dnnl_memory_desc_t> c_mems;
3750  c_mems.reserve(mems.size());
3751  for (const auto &s : mems)
3752  c_mems.push_back(s.data);
3753  return c_mems;
3754 }
3756 
3758 struct concat : public primitive {
3762 
/// Defaulted default constructor; unlike the other constructors it does
/// not call `dnnl_concat_primitive_desc_create`.
3764  primitive_desc() = default;
3765 
3776  primitive_desc(const memory::desc &dst, int concat_dimension,
3777  const std::vector<memory::desc> &srcs, const engine &aengine,
3778  const primitive_attr &attr = primitive_attr()) {
3779  auto c_srcs = convert_to_c(srcs);
3780 
3781  dnnl_primitive_desc_t result;
3784  (int)c_srcs.size(), concat_dimension, c_srcs.data(),
3785  attr.get(), aengine.get()),
3786  "could not create a primitive descriptor for a concat "
3787  "primitive");
3788  reset(result);
3789  }
3790 
3803  primitive_desc(int concat_dimension,
3804  const std::vector<memory::desc> &srcs, const engine &aengine,
3805  const primitive_attr &attr = primitive_attr()) {
3806  auto c_api_srcs = convert_to_c(srcs);
3807 
3808  dnnl_primitive_desc_t result;
3810  dnnl_concat_primitive_desc_create(&result, nullptr,
3811  (int)c_api_srcs.size(), concat_dimension,
3812  c_api_srcs.data(), attr.get(), aengine.get()),
3813  "could not create a primitive descriptor for a concat "
3814  "primitive");
3815  reset(result);
3816  }
3817 
3824 
3826  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
3827 
3829  memory::desc dst_desc() const { return base::dst_desc(0); }
3830  };
3831 
/// Defaulted default constructor.
3833  concat() = default;
3834 
/// Creates a concat primitive from its primitive descriptor by passing the
/// descriptor's C handle to the `primitive` base constructor.
3837  concat(const primitive_desc &pd) : primitive(pd.get()) {}
3838 };
3839 
3841 
3849 
3851 struct sum : public primitive {
3855 
/// Defaulted default constructor; unlike the other constructors it does
/// not call `dnnl_sum_primitive_desc_create`.
3857  primitive_desc() = default;
3858 
3868  const std::vector<float> &scales,
3869  const std::vector<memory::desc> &srcs, const engine &aengine,
3870  const primitive_attr &attr = primitive_attr()) {
3871  validate_container_size(scales,
3872  "counts of scales and sources are not equal",
3873  (int)srcs.size(), (int)srcs.size());
3874 
3875  auto c_api_srcs = convert_to_c(srcs);
3876 
3877  dnnl_primitive_desc_t result;
3879  dnnl_sum_primitive_desc_create(&result, &dst.data,
3880  (int)c_api_srcs.size(), scales.data(),
3881  c_api_srcs.data(), attr.get(), aengine.get()),
3882  "could not create a primitive descriptor for a sum "
3883  "primitive");
3884  reset(result);
3885  }
3886 
3897  primitive_desc(const std::vector<float> &scales,
3898  const std::vector<memory::desc> &srcs, const engine &aengine,
3899  const primitive_attr &attr = primitive_attr()) {
3900  validate_container_size(scales,
3901  "counts of scales and sources are not equal",
3902  (int)srcs.size(), (int)srcs.size());
3903 
3904  auto c_api_srcs = convert_to_c(srcs);
3905  dnnl_primitive_desc_t result;
3907  dnnl_sum_primitive_desc_create(&result, nullptr,
3908  (int)c_api_srcs.size(), scales.data(),
3909  c_api_srcs.data(), attr.get(), aengine.get()),
3910  "could not create a primitive descriptor for a sum "
3911  "primitive");
3912  reset(result);
3913  }
3914 
3921 
3923  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
3924 
3926  memory::desc dst_desc() const { return base::dst_desc(0); }
3927  };
3928 
/// Defaulted default constructor.
3930  sum() = default;
3931 
/// Creates a sum primitive from its primitive descriptor by passing the
/// descriptor's C handle to the `primitive` base constructor.
3934  sum(const primitive_desc &pd) : primitive(pd.get()) {}
3935 };
3936 
3938 
3941 
3946 
/// Defaulted default constructor; `allow_empty_` keeps its in-class
/// initializer value (`false`) and no descriptor iterator is created.
3947  primitive_desc() = default;
3948 
3972  const engine &aengine, const_dnnl_primitive_desc_t hint_fwd_pd,
3973  bool allow_empty = false)
3974  : allow_empty_(allow_empty) {
3975  dnnl_primitive_desc_iterator_t iterator = nullptr;
3977  desc, attr ? attr->get() : nullptr, aengine.get(), hint_fwd_pd);
3978  if (!allow_empty)
3980  status, "could not create a primitive descriptor iterator");
3981  pd_iterator.reset(iterator);
3982  fetch_impl();
3983  }
3984 
3989  bool next_impl() {
3991  = dnnl_primitive_desc_iterator_next(pd_iterator.get());
3992  if (status == dnnl_iterator_ends) return false;
3994  status, "could not advance a primitive descriptor iterator");
3995  fetch_impl();
3996  return true;
3997  }
3998 
3999 private:
4000  bool allow_empty_ = false;
4002  void fetch_impl() {
4004  pd_iterator.get(allow_empty_));
4005  error::wrap_c_api(pd != nullptr || allow_empty_ ? dnnl_success
4007  "could not fetch a primitive descriptor from a primitive "
4008  "descriptor iterator");
4009  reset(pd);
4010  }
4011 };
4012 
4014 
4024 
4028  struct desc {
4030 
4061  desc(prop_kind aprop_kind, algorithm aalgorithm,
4062  const memory::desc &src_desc, const memory::desc &weights_desc,
4063  const memory::desc &bias_desc, const memory::desc &dst_desc,
4064  const memory::dims &strides, const memory::dims &padding_l,
4065  const memory::dims &padding_r) {
4066  memory::validate_dims(strides, src_desc.data.ndims - 2);
4067  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4068  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4071  dnnl::convert_to_c(aprop_kind),
4072  convert_to_c(aalgorithm), &src_desc.data,
4073  &weights_desc.data, &bias_desc.data, &dst_desc.data,
4074  &strides[0], &padding_l[0], &padding_r[0]),
4075  "could not create a descriptor for a convolution forward "
4076  "propagation primitive");
4077  }
4078 
4107  desc(prop_kind aprop_kind, algorithm aalgorithm,
4108  const memory::desc &src_desc, const memory::desc &weights_desc,
4109  const memory::desc &dst_desc, const memory::dims &strides,
4110  const memory::dims &padding_l, const memory::dims &padding_r) {
4111  memory::validate_dims(strides, src_desc.data.ndims - 2);
4112  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4113  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4116  dnnl::convert_to_c(aprop_kind),
4117  convert_to_c(aalgorithm), &src_desc.data,
4118  &weights_desc.data, nullptr, &dst_desc.data,
4119  &strides[0], &padding_l[0], &padding_r[0]),
4120  "could not create a descriptor for a convolution forward "
4121  "propagation primitive");
4122  }
4123 
4156  desc(prop_kind aprop_kind, algorithm aalgorithm,
4157  const memory::desc &src_desc, const memory::desc &weights_desc,
4158  const memory::desc &bias_desc, const memory::desc &dst_desc,
4159  const memory::dims &strides, const memory::dims &dilates,
4160  const memory::dims &padding_l, const memory::dims &padding_r) {
4161  memory::validate_dims(strides, src_desc.data.ndims - 2);
4162  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4163  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4164  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4166  dnnl::convert_to_c(aprop_kind),
4167  convert_to_c(aalgorithm), &src_desc.data,
4168  &weights_desc.data, &bias_desc.data,
4169  &dst_desc.data, &strides[0], &dilates[0],
4170  &padding_l[0], &padding_r[0]),
4171  "could not create a descriptor for a dilated convolution "
4172  "forward propagation primitive");
4173  }
4174 
4205  desc(prop_kind aprop_kind, algorithm aalgorithm,
4206  const memory::desc &src_desc, const memory::desc &weights_desc,
4207  const memory::desc &dst_desc, const memory::dims &strides,
4208  const memory::dims &dilates, const memory::dims &padding_l,
4209  const memory::dims &padding_r) {
4210  memory::validate_dims(strides, src_desc.data.ndims - 2);
4211  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4212  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4213  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4215  dnnl::convert_to_c(aprop_kind),
4216  convert_to_c(aalgorithm), &src_desc.data,
4217  &weights_desc.data, nullptr,
4218  &dst_desc.data, &strides[0], &dilates[0],
4219  &padding_l[0], &padding_r[0]),
4220  "could not create a descriptor for a dilated convolution "
4221  "forward propagation primitive");
4222  }
4223  };
4224 
4228  primitive_desc() = default;
4229 
4240  primitive_desc(const desc &adesc, const engine &aengine,
4241  bool allow_empty = false)
4242  : dnnl::primitive_desc(
4243  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
4244 
4256  primitive_desc(const desc &adesc, const primitive_attr &attr,
4257  const engine &aengine, bool allow_empty = false)
4258  : dnnl::primitive_desc(
4259  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
4260 
4268  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4271 
4273  memory::desc src_desc() const { return base::src_desc(0); }
4274 
4277 
4279  memory::desc dst_desc() const { return base::dst_desc(0); }
4280 
4286  };
4287 
4289  convolution_forward() = default;
4290 
4295 };
4296 
4299 
4301  struct desc {
4303 
4329  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
4330  const memory::desc &weights_desc,
4331  const memory::desc &diff_dst_desc, const memory::dims &strides,
4332  const memory::dims &padding_l, const memory::dims &padding_r) {
4333  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4334  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4335  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4338  convert_to_c(aalgorithm), &diff_src_desc.data,
4339  &weights_desc.data, &diff_dst_desc.data,
4340  &strides[0], &padding_l[0], &padding_r[0]),
4341  "could not create a descriptor for a convolution backward "
4342  "propagation primitive");
4343  }
4344 
4372  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
4373  const memory::desc &weights_desc,
4374  const memory::desc &diff_dst_desc, const memory::dims &strides,
4375  const memory::dims &dilates, const memory::dims &padding_l,
4376  const memory::dims &padding_r) {
4377  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4378  memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
4379  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4380  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4383  convert_to_c(aalgorithm), &diff_src_desc.data,
4384  &weights_desc.data, &diff_dst_desc.data,
4385  &strides[0], &dilates[0], &padding_l[0],
4386  &padding_r[0]),
4387  "could not create a descriptor for a dilated convolution "
4388  "backward propagation primitive");
4389  }
4390  };
4391 
4395  primitive_desc() = default;
4396 
4410  primitive_desc(const desc &adesc, const engine &aengine,
4411  const convolution_forward::primitive_desc &hint_fwd_pd,
4412  bool allow_empty = false)
4413  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
4414  hint_fwd_pd.get(), allow_empty) {}
4415 
4430  primitive_desc(const desc &adesc, const primitive_attr &attr,
4431  const engine &aengine,
4432  const convolution_forward::primitive_desc &hint_fwd_pd,
4433  bool allow_empty = false)
4434  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
4435  hint_fwd_pd.get(), allow_empty) {}
4436 
4444  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4446 
4449 
4452 
4455  };
4456 
4459 
4464 };
4465 
4469  struct desc {
4471 
4499  desc(algorithm aalgorithm, const memory::desc &src_desc,
4500  const memory::desc &diff_weights_desc,
4501  const memory::desc &diff_bias_desc,
4502  const memory::desc &diff_dst_desc, const memory::dims &strides,
4503  const memory::dims &padding_l, const memory::dims &padding_r) {
4504  memory::validate_dims(strides, src_desc.data.ndims - 2);
4505  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4506  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4509  convert_to_c(aalgorithm), &src_desc.data,
4510  &diff_weights_desc.data, &diff_bias_desc.data,
4511  &diff_dst_desc.data, &strides[0], &padding_l[0],
4512  &padding_r[0]),
4513  "could not create a descriptor for a convolution weights "
4514  "update primitive");
4515  }
4516 
4542  desc(algorithm aalgorithm, const memory::desc &src_desc,
4543  const memory::desc &diff_weights_desc,
4544  const memory::desc &diff_dst_desc, const memory::dims &strides,
4545  const memory::dims &padding_l, const memory::dims &padding_r) {
4546  memory::validate_dims(strides, src_desc.data.ndims - 2);
4547  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4548  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4550  convert_to_c(aalgorithm), &src_desc.data,
4551  &diff_weights_desc.data, nullptr,
4552  &diff_dst_desc.data, &strides[0],
4553  &padding_l[0], &padding_r[0]),
4554  "could not create a descriptor for a convolution weights "
4555  "update primitive");
4556  }
4557 
4587  desc(algorithm aalgorithm, const memory::desc &src_desc,
4588  const memory::desc &diff_weights_desc,
4589  const memory::desc &diff_bias_desc,
4590  const memory::desc &diff_dst_desc, const memory::dims &strides,
4591  const memory::dims &dilates, const memory::dims &padding_l,
4592  const memory::dims &padding_r) {
4593  memory::validate_dims(strides, src_desc.data.ndims - 2);
4594  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4595  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4596  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4599  convert_to_c(aalgorithm), &src_desc.data,
4600  &diff_weights_desc.data, &diff_bias_desc.data,
4601  &diff_dst_desc.data, &strides[0], &dilates[0],
4602  &padding_l[0], &padding_r[0]),
4603  "could not create a descriptor for a dilated convolution "
4604  "weights gradient primitive");
4605  }
4606 
4634  desc(algorithm aalgorithm, const memory::desc &src_desc,
4635  const memory::desc &diff_weights_desc,
4636  const memory::desc &diff_dst_desc, const memory::dims &strides,
4637  const memory::dims &dilates, const memory::dims &padding_l,
4638  const memory::dims &padding_r) {
4639  memory::validate_dims(strides, src_desc.data.ndims - 2);
4640  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4641  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4642  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4645  convert_to_c(aalgorithm), &src_desc.data,
4646  &diff_weights_desc.data, nullptr,
4647  &diff_dst_desc.data, &strides[0], &dilates[0],
4648  &padding_l[0], &padding_r[0]),
4649  "could not create a descriptor for a dilated convolution "
4650  "weights gradient primitive");
4651  }
4652  };
4653 
4657  primitive_desc() = default;
4658 
4671  primitive_desc(const desc &adesc, const engine &aengine,
4672  const convolution_forward::primitive_desc &hint_fwd_pd,
4673  bool allow_empty = false)
4674  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
4675  hint_fwd_pd.get(), allow_empty) {}
4676 
4690  primitive_desc(const desc &adesc, const primitive_attr &attr,
4691  const engine &aengine,
4692  const convolution_forward::primitive_desc &hint_fwd_pd,
4693  bool allow_empty = false)
4694  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
4695  hint_fwd_pd.get(), allow_empty) {}
4696 
4704  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4706 
4708  memory::desc src_desc() const { return base::src_desc(0); }
4709 
4712  return base::diff_weights_desc(0);
4713  }
4714 
4717 
4723  return base::diff_weights_desc(1);
4724  }
4725  };
4726 
4729 
4734 };
4735 
4737 //
4745 
4749  struct desc {
4751 
4781  desc(prop_kind aprop_kind, algorithm aalgorithm,
4782  const memory::desc &src_desc, const memory::desc &weights_desc,
4783  const memory::desc &bias_desc, const memory::desc &dst_desc,
4784  const memory::dims &strides, const memory::dims &padding_l,
4785  const memory::dims &padding_r) {
4786  memory::validate_dims(strides, src_desc.data.ndims - 2);
4787  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4788  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4791  dnnl::convert_to_c(aprop_kind),
4792  convert_to_c(aalgorithm), &src_desc.data,
4793  &weights_desc.data, &bias_desc.data, &dst_desc.data,
4794  &strides[0], &padding_l[0], &padding_r[0]),
4795  "could not create a descriptor for a deconvolution forward "
4796  "propagation primitive");
4797  }
4798 
4826  desc(prop_kind aprop_kind, algorithm aalgorithm,
4827  const memory::desc &src_desc, const memory::desc &weights_desc,
4828  const memory::desc &dst_desc, const memory::dims &strides,
4829  const memory::dims &padding_l, const memory::dims &padding_r) {
4830  memory::validate_dims(strides, src_desc.data.ndims - 2);
4831  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4832  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4835  dnnl::convert_to_c(aprop_kind),
4836  convert_to_c(aalgorithm), &src_desc.data,
4837  &weights_desc.data, nullptr, &dst_desc.data,
4838  &strides[0], &padding_l[0], &padding_r[0]),
4839  "could not create a descriptor for a deconvolution forward "
4840  "propagation primitive");
4841  }
4842 
4874  desc(prop_kind aprop_kind, algorithm aalgorithm,
4875  const memory::desc &src_desc, const memory::desc &weights_desc,
4876  const memory::desc &bias_desc, const memory::desc &dst_desc,
4877  const memory::dims &strides, const memory::dims &dilates,
4878  const memory::dims &padding_l, const memory::dims &padding_r) {
4879  memory::validate_dims(strides, src_desc.data.ndims - 2);
4880  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4881  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4882  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4884  &data, dnnl::convert_to_c(aprop_kind),
4885  convert_to_c(aalgorithm), &src_desc.data,
4886  &weights_desc.data, &bias_desc.data,
4887  &dst_desc.data, &strides[0], &dilates[0],
4888  &padding_l[0], &padding_r[0]),
4889  "could not create a descriptor for a dilated deconvolution "
4890  "forward propagation primitive");
4891  }
4892 
4922  desc(prop_kind aprop_kind, algorithm aalgorithm,
4923  const memory::desc &src_desc, const memory::desc &weights_desc,
4924  const memory::desc &dst_desc, const memory::dims &strides,
4925  const memory::dims &dilates, const memory::dims &padding_l,
4926  const memory::dims &padding_r) {
4927  memory::validate_dims(strides, src_desc.data.ndims - 2);
4928  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4929  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4930  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4932  &data, dnnl::convert_to_c(aprop_kind),
4933  convert_to_c(aalgorithm), &src_desc.data,
4934  &weights_desc.data, nullptr,
4935  &dst_desc.data, &strides[0], &dilates[0],
4936  &padding_l[0], &padding_r[0]),
4937  "could not create a descriptor for a dilated deconvolution "
4938  "forward propagation primitive");
4939  }
4940  };
4941 
4945  primitive_desc() = default;
4946 
4957  primitive_desc(const desc &adesc, const engine &aengine,
4958  bool allow_empty = false)
4959  : dnnl::primitive_desc(
4960  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
4961 
4973  primitive_desc(const desc &adesc, const primitive_attr &attr,
4974  const engine &aengine, bool allow_empty = false)
4975  : dnnl::primitive_desc(
4976  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
4977 
4985  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
4988 
4990  memory::desc src_desc() const { return base::src_desc(0); }
4991 
4994 
4996  memory::desc dst_desc() const { return base::dst_desc(0); }
4997 
5000  };
5001 
5004 
5009 };
5010 
5014  struct desc {
5016 
5041  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
5042  const memory::desc &weights_desc,
5043  const memory::desc &diff_dst_desc, const memory::dims &strides,
5044  const memory::dims &padding_l, const memory::dims &padding_r) {
5045  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5046  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5047  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5050  convert_to_c(aalgorithm), &diff_src_desc.data,
5051  &weights_desc.data, &diff_dst_desc.data,
5052  &strides[0], &padding_l[0], &padding_r[0]),
5053  "could not create a descriptor for a deconvolution "
5054  "backward propagation primitive");
5055  }
5056 
5083  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
5084  const memory::desc &weights_desc,
5085  const memory::desc &diff_dst_desc, const memory::dims &strides,
5086  const memory::dims &dilates, const memory::dims &padding_l,
5087  const memory::dims &padding_r) {
5088  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5089  memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
5090  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5091  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5094  convert_to_c(aalgorithm), &diff_src_desc.data,
5095  &weights_desc.data, &diff_dst_desc.data,
5096  &strides[0], &dilates[0], &padding_l[0],
5097  &padding_r[0]),
5098  "could not create a descriptor for a dilated deconvolution "
5099  "backward propagation primitive");
5100  }
5101  };
5102 
5106  primitive_desc() = default;
5107 
5121  primitive_desc(const desc &adesc, const engine &aengine,
5122  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5123  bool allow_empty = false)
5124  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5125  hint_fwd_pd.get(), allow_empty) {}
5126 
5141  primitive_desc(const desc &adesc, const primitive_attr &attr,
5142  const engine &aengine,
5143  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5144  bool allow_empty = false)
5145  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5146  hint_fwd_pd.get(), allow_empty) {}
5147 
5155  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5157 
5160 
5163 
5166  };
5167 
5170 
5175 };
5176 
5180  struct desc {
5182 
5209  desc(algorithm aalgorithm, const memory::desc &src_desc,
5210  const memory::desc &diff_weights_desc,
5211  const memory::desc &diff_bias_desc,
5212  const memory::desc &diff_dst_desc, const memory::dims &strides,
5213  const memory::dims &padding_l, const memory::dims &padding_r) {
5214  memory::validate_dims(strides, src_desc.data.ndims - 2);
5215  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5216  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5219  convert_to_c(aalgorithm), &src_desc.data,
5220  &diff_weights_desc.data, &diff_bias_desc.data,
5221  &diff_dst_desc.data, &strides[0], &padding_l[0],
5222  &padding_r[0]),
5223  "could not create a descriptor for a deconvolution weights "
5224  "update primitive");
5225  }
5226 
5251  desc(algorithm aalgorithm, const memory::desc &src_desc,
5252  const memory::desc &diff_weights_desc,
5253  const memory::desc &diff_dst_desc, const memory::dims &strides,
5254  const memory::dims &padding_l, const memory::dims &padding_r) {
5255  memory::validate_dims(strides, src_desc.data.ndims - 2);
5256  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5257  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5259  &data, convert_to_c(aalgorithm),
5260  &src_desc.data, &diff_weights_desc.data,
5261  nullptr, &diff_dst_desc.data, &strides[0],
5262  &padding_l[0], &padding_r[0]),
5263  "could not create a descriptor for a deconvolution weights "
5264  "update primitive");
5265  }
5266 
5295  desc(algorithm aalgorithm, const memory::desc &src_desc,
5296  const memory::desc &diff_weights_desc,
5297  const memory::desc &diff_bias_desc,
5298  const memory::desc &diff_dst_desc, const memory::dims &strides,
5299  const memory::dims &dilates, const memory::dims &padding_l,
5300  const memory::dims &padding_r) {
5301  memory::validate_dims(strides, src_desc.data.ndims - 2);
5302  memory::validate_dims(dilates, src_desc.data.ndims - 2);
5303  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5304  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5307  convert_to_c(aalgorithm), &src_desc.data,
5308  &diff_weights_desc.data, &diff_bias_desc.data,
5309  &diff_dst_desc.data, &strides[0], &dilates[0],
5310  &padding_l[0], &padding_r[0]),
5311  "could not create a descriptor for a dilated deconvolution "
5312  "weights gradient primitive");
5313  }
5314 
5341  desc(algorithm aalgorithm, const memory::desc &src_desc,
5342  const memory::desc &diff_weights_desc,
5343  const memory::desc &diff_dst_desc, const memory::dims &strides,
5344  const memory::dims &dilates, const memory::dims &padding_l,
5345  const memory::dims &padding_r) {
5346  memory::validate_dims(strides, src_desc.data.ndims - 2);
5347  memory::validate_dims(dilates, src_desc.data.ndims - 2);
5348  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5349  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5352  convert_to_c(aalgorithm), &src_desc.data,
5353  &diff_weights_desc.data, nullptr,
5354  &diff_dst_desc.data, &strides[0], &dilates[0],
5355  &padding_l[0], &padding_r[0]),
5356  "could not create a descriptor for a dilated deconvolution "
5357  "weights gradient primitive");
5358  }
5359  };
5360 
5364  primitive_desc() = default;
5365 
5379  primitive_desc(const desc &adesc, const engine &aengine,
5380  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5381  bool allow_empty = false)
5382  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5383  hint_fwd_pd.get(), allow_empty) {}
5384 
5399  primitive_desc(const desc &adesc, const primitive_attr &attr,
5400  const engine &aengine,
5401  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5402  bool allow_empty = false)
5403  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5404  hint_fwd_pd.get(), allow_empty) {}
5405 
5413  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5415 
5417  memory::desc src_desc() const { return base::src_desc(0); }
5418 
5421  return base::diff_weights_desc(0);
5422  }
5423 
5426 
5429  return base::diff_weights_desc(1);
5430  }
5431  };
5432 
5435 
5440 };
5441 
5443 
5452 
5454 struct lrn_forward : public primitive {
5456  struct desc {
5457  dnnl_lrn_desc_t data;
5458 
5472  desc(prop_kind aprop_kind, algorithm aalgorithm,
5473  const memory::desc &data_desc, memory::dim local_size,
5474  float alpha, float beta, float k = 1.f) {
5476  dnnl::convert_to_c(aprop_kind),
5477  convert_to_c(aalgorithm), &data_desc.data,
5478  local_size, alpha, beta, k),
5479  "could not create a descriptor for a lrn forward "
5480  "propagation primitive");
5481  }
5482  };
5483 
5487  primitive_desc() = default;
5488 
5498  primitive_desc(const desc &adesc, const engine &aengine,
5499  bool allow_empty = false)
5500  : dnnl::primitive_desc(
5501  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5502 
5513  primitive_desc(const desc &adesc, const primitive_attr &attr,
5514  const engine &aengine, bool allow_empty = false)
5515  : dnnl::primitive_desc(
5516  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5517 
5525  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
5528 
5530  memory::desc src_desc() const { return base::src_desc(0); }
5531 
5533  memory::desc dst_desc() const { return base::dst_desc(0); }
5534 
5537  };
5538 
5540  lrn_forward() = default;
5541 
5546 };
5547 
5549 struct lrn_backward : public primitive {
5551  struct desc {
5552  dnnl_lrn_desc_t data;
5553 
5566  desc(algorithm aalgorithm, const memory::desc &data_desc,
5567  const memory::desc &diff_data_desc, memory::dim local_size,
5568  float alpha, float beta, float k = 1.f) {
5570  dnnl_lrn_backward_desc_init(&data, convert_to_c(aalgorithm),
5571  &diff_data_desc.data, &data_desc.data, local_size,
5572  alpha, beta, k),
5573  "could not create a descriptor for a lrn backward "
5574  "propagation primitive");
5575  }
5576  };
5577 
5581  primitive_desc() = default;
5582 
5595  primitive_desc(const desc &adesc, const engine &aengine,
5596  const lrn_forward::primitive_desc &hint_fwd_pd,
5597  bool allow_empty = false)
5598  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5599  hint_fwd_pd.get(), allow_empty) {}
5600 
5614  primitive_desc(const desc &adesc, const primitive_attr &attr,
5615  const engine &aengine,
5616  const lrn_forward::primitive_desc &hint_fwd_pd,
5617  bool allow_empty = false)
5618  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5619  hint_fwd_pd.get(), allow_empty) {}
5620 
5628  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
5630 
5633 
5636 
5639  };
5640 
5642  lrn_backward() = default;
5643 
5648 };
5649 
5651 
5659 
5661 struct pooling_forward : public primitive {
5663  struct desc {
5664  dnnl_pooling_desc_t data;
5665 
5690  desc(prop_kind aprop_kind, algorithm aalgorithm,
5691  const memory::desc &src_desc, const memory::desc &dst_desc,
5692  const memory::dims &strides, const memory::dims &kernel,
5693  const memory::dims &padding_l, const memory::dims &padding_r) {
5694  memory::validate_dims(strides, src_desc.data.ndims - 2);
5695  memory::validate_dims(kernel, src_desc.data.ndims - 2);
5696  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5697  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5699  dnnl::convert_to_c(aprop_kind),
5700  convert_to_c(aalgorithm), &src_desc.data,
5701  &dst_desc.data, &strides[0], &kernel[0],
5702  &padding_l[0], &padding_r[0]),
5703  "could not create a descriptor for a pooling forward "
5704  "propagation primitive");
5705  }
5706  };
5707 
5711  primitive_desc() = default;
5712 
5722  primitive_desc(const desc &adesc, const engine &aengine,
5723  bool allow_empty = false)
5724  : dnnl::primitive_desc(
5725  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5726 
5737  primitive_desc(const desc &adesc, const primitive_attr &attr,
5738  const engine &aengine, bool allow_empty = false)
5739  : dnnl::primitive_desc(
5740  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5741 
5749  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
5752 
5754  memory::desc src_desc() const { return base::src_desc(0); }
5755 
5757  memory::desc dst_desc() const { return base::dst_desc(0); }
5758 
5761  };
5762 
5764  pooling_forward() = default;
5765 
5770 };
5771 
5773 struct pooling_backward : public primitive {
5775  struct desc {
5776  dnnl_pooling_desc_t data;
5777 
5799  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
5800  const memory::desc &diff_dst_desc, const memory::dims &strides,
5801  const memory::dims &kernel, const memory::dims &padding_l,
5802  const memory::dims &padding_r) {
5803  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5804  memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
5805  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5806  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5809  convert_to_c(aalgorithm), &diff_src_desc.data,
5810  &diff_dst_desc.data, &strides[0], &kernel[0],
5811  &padding_l[0], &padding_r[0]),
5812  "could not create a descriptor for a pooling backward "
5813  "propagation primitive");
5814  }
5815  };
5816 
5820  primitive_desc() = default;
5821 
5834  primitive_desc(const desc &adesc, const engine &aengine,
5835  const pooling_forward::primitive_desc &hint_fwd_pd,
5836  bool allow_empty = false)
5837  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5838  hint_fwd_pd.get(), allow_empty) {}
5839 
5853  primitive_desc(const desc &adesc, const primitive_attr &attr,
5854  const engine &aengine,
5855  const pooling_forward::primitive_desc &hint_fwd_pd,
5856  bool allow_empty = false)
5857  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5858  hint_fwd_pd.get(), allow_empty) {}
5859 
5867  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
5869 
5872 
5875 
5878  };
5879 
5881  pooling_backward() = default;
5882 
5887 };
5888 
5890 
5911 
5913 struct eltwise_forward : public primitive {
5915  struct desc {
5916  dnnl_eltwise_desc_t data;
5917 
5930  desc(prop_kind aprop_kind, algorithm aalgorithm,
5931  const memory::desc &data_desc, float alpha = 0,
5932  float beta = 0) {
5934  dnnl::convert_to_c(aprop_kind),
5935  dnnl::convert_to_c(aalgorithm),
5936  &data_desc.data, alpha, beta),
5937  "could not create a descriptor for an eltwise forward "
5938  "propagation primitive");
5939  }
5940  };
5941 
5945  primitive_desc() = default;
5946 
5957  primitive_desc(const desc &adesc, const engine &aengine,
5958  bool allow_empty = false)
5959  : dnnl::primitive_desc(
5960  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5961 
5973  primitive_desc(const desc &adesc, const primitive_attr &attr,
5974  const engine &aengine, bool allow_empty = false)
5975  : dnnl::primitive_desc(
5976  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5977 
5985  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
5988 
5990  memory::desc src_desc() const { return base::src_desc(0); }
5991 
5993  memory::desc dst_desc() const { return base::dst_desc(0); }
5994  };
5995 
5997  eltwise_forward() = default;
5998 
6003 };
6004 
6006 struct eltwise_backward : public primitive {
6008  struct desc {
6009  dnnl_eltwise_desc_t data;
6010 
6022  desc(algorithm aalgorithm, const memory::desc &diff_data_desc,
6023  const memory::desc &data_desc, float alpha = 0,
6024  float beta = 0) {
6027  dnnl::convert_to_c(aalgorithm),
6028  &diff_data_desc.data, &data_desc.data, alpha, beta),
6029  "could not create a descriptor for an eltwise backward "
6030  "propagation primitive");
6031  }
6032  };
6033 
6037  primitive_desc() = default;
6038 
6052  primitive_desc(const desc &adesc, const engine &aengine,
6053  const eltwise_forward::primitive_desc &hint_fwd_pd,
6054  bool allow_empty = false)
6055  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6056  hint_fwd_pd.get(), allow_empty) {}
6057 
6072  primitive_desc(const desc &adesc, const primitive_attr &attr,
6073  const engine &aengine,
6074  const eltwise_forward::primitive_desc &hint_fwd_pd,
6075  bool allow_empty = false)
6076  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6077  hint_fwd_pd.get(), allow_empty) {}
6078 
6086  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
6088 
6090  memory::desc src_desc() const { return base::src_desc(0); }
6091 
6094 
6097  };
6098 
6100  eltwise_backward() = default;
6101 
6106 };
6107 
6109 
6117 
6119 struct softmax_forward : public primitive {
6121  struct desc {
6122  dnnl_softmax_desc_t data;
6123 
6125  desc() = default;
6126 
6135  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6136  int softmax_axis) {
6138  dnnl::convert_to_c(aprop_kind),
6139  &data_desc.data, softmax_axis),
6140  "could not create a descriptor for a softmax forward "
6141  "propagation primitive");
6142  }
6143  };
6144 
6148  primitive_desc() = default;
6149 
6160  primitive_desc(const desc &adesc, const engine &aengine,
6161  bool allow_empty = false)
6162  : dnnl::primitive_desc(
6163  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6164 
6176  primitive_desc(const desc &adesc, const primitive_attr &attr,
6177  const engine &aengine, bool allow_empty = false)
6178  : dnnl::primitive_desc(
6179  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6180 
6188  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
6191 
6193  memory::desc src_desc() const { return base::src_desc(0); }
6194 
6196  memory::desc dst_desc() const { return base::dst_desc(0); }
6197  };
6198 
6200  softmax_forward() = default;
6201 
6206 };
6207 
6209 struct softmax_backward : public primitive {
6211  struct desc {
6212  dnnl_softmax_desc_t data;
6213 
6215  desc() = default;
6216 
6224  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
6225  int softmax_axis) {
6227  dnnl_softmax_backward_desc_init(&data, &diff_data_desc.data,
6228  &data_desc.data, softmax_axis),
6229  "could not create a descriptor for a softmax backward "
6230  "propagation primitive");
6231  }
6232  };
6233 
6237  primitive_desc() = default;
6238 
6252  primitive_desc(const desc &adesc, const engine &aengine,
6253  const softmax_forward::primitive_desc &hint_fwd_pd,
6254  bool allow_empty = false)
6255  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6256  hint_fwd_pd.get(), allow_empty) {}
6257 
6272  primitive_desc(const desc &adesc, const primitive_attr &attr,
6273  const engine &aengine,
6274  const softmax_forward::primitive_desc &hint_fwd_pd,
6275  bool allow_empty = false)
6276  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6277  hint_fwd_pd.get(), allow_empty) {}
6278 
6286  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
6288 
6290  memory::desc dst_desc() const { return base::dst_desc(0); }
6291 
6294 
6297  };
6298 
6300  softmax_backward() = default;
6301 
6306 };
6307 
6309 
6317 
6321  struct desc {
6323 
6325  desc() = default;
6326 
6335  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6336  int logsoftmax_axis) {
6338  dnnl::convert_to_c(aprop_kind),
6339  &data_desc.data, logsoftmax_axis),
6340  "could not create a descriptor for a logsoftmax forward "
6341  "propagation primitive");
6342  }
6343  };
6344 
6348  primitive_desc() = default;
6349 
6360  primitive_desc(const desc &adesc, const engine &aengine,
6361  bool allow_empty = false)
6362  : dnnl::primitive_desc(
6363  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6364 
6376  primitive_desc(const desc &adesc, const primitive_attr &attr,
6377  const engine &aengine, bool allow_empty = false)
6378  : dnnl::primitive_desc(
6379  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6380 
6388  : dnnl::primitive_desc(pd,
6389  // Logsoftmax and softmax share the implementation and
6390  // currently report the same primitive kind. Hence this
6391  // must be softmax and not logsoftmax.
6392  dnnl::primitive::kind::softmax,
6395 
6397  memory::desc src_desc() const { return base::src_desc(0); }
6398 
6400  memory::desc dst_desc() const { return base::dst_desc(0); }
6401  };
6402 
6404  logsoftmax_forward() = default;
6405 
6410 };
6411 
6415  struct desc {
6417 
6419  desc() = default;
6420 
6428  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
6429  int logsoftmax_axis) {
6431  &diff_data_desc.data, &data_desc.data,
6432  logsoftmax_axis),
6433  "could not create a descriptor for a logsoftmax backward "
6434  "propagation primitive");
6435  }
6436  };
6437 
6441  primitive_desc() = default;
6442 
6456  primitive_desc(const desc &adesc, const engine &aengine,
6457  const logsoftmax_forward::primitive_desc &hint_fwd_pd,
6458  bool allow_empty = false)
6459  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6460  hint_fwd_pd.get(), allow_empty) {}
6461 
6476  primitive_desc(const desc &adesc, const primitive_attr &attr,
6477  const engine &aengine,
6478  const logsoftmax_forward::primitive_desc &hint_fwd_pd,
6479  bool allow_empty = false)
6480  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6481  hint_fwd_pd.get(), allow_empty) {}
6482 
6490  : dnnl::primitive_desc(pd,
6491  // Logsoftmax and softmax share the implementation and
6492  // currently report the same primitive kind. Hence this
6493  // must be softmax and not logsoftmax.
6494  dnnl::primitive::kind::softmax,
6496 
6498  memory::desc dst_desc() const { return base::dst_desc(0); }
6499 
6502 
6505  };
6506 
6508  logsoftmax_backward() = default;
6509 
6514 };
6515 
6517 
6537 
6541  struct desc {
6543 
6558  desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon,
6559  normalization_flags flags) {
6562  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6563  epsilon, convert_to_c(flags)),
6564  "could not create a descriptor for a batch normalization "
6565  "forward propagation primitive");
6566  }
6567  };
6568 
6573  primitive_desc() = default;  // Defaulted default constructor.
6574 
6585  primitive_desc(const desc &adesc, const engine &aengine,  // Forward pd without attributes and without a hint pd (both passed as nullptr).
6586  bool allow_empty = false)  // Forwarded unchanged to the base constructor.
6587  : dnnl::primitive_desc(
6588  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6589 
6601  primitive_desc(const desc &adesc, const primitive_attr &attr,  // Forward pd with caller-supplied attributes; no hint pd (nullptr).
6602  const engine &aengine, bool allow_empty = false)
6603  : dnnl::primitive_desc(
6604  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6605 
6613  : dnnl::primitive_desc(pd,
6614  dnnl::primitive::kind::batch_normalization,
6617 
6619  memory::desc src_desc() const { return base::src_desc(0); }  // Source memory descriptor at index 0, via the base-class query.
6620 
6622  memory::desc dst_desc() const { return base::dst_desc(0); }  // Destination memory descriptor at index 0, via the base-class query.
6623 
6626 
6629 
6632  memory::desc mean_desc() const { return stat_desc(mean); }  // Mean memory descriptor, obtained through stat_desc(mean).
6633 
6636  memory::desc variance_desc() const { return stat_desc(var); }  // Variance memory descriptor, obtained through stat_desc(var).
6637 
6638  private:
6639  enum {  // Index values that mean_desc()/variance_desc() pass to stat_desc(), which forwards them to query_md().
6640  mean = 1,
6641  var = 2,
6642  };
6643  memory::desc stat_desc(int kind) const {
6648  &p),
6649  "could not retrieve a descriptor from a primitive "
6650  "descriptor for batch normalization forward propagation "
6651  "primitive");
6652  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
6653  : query::dst_md,
6654  kind);
6655  }
6656  };
6657 
6660 
6665 };
6666 
6670  struct desc {
6672 
6685  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
6686  const memory::desc &data_desc, float epsilon,
6687  normalization_flags flags) {
6689  dnnl::convert_to_c(aprop_kind),
6690  &diff_data_desc.data, &data_desc.data,
6691  epsilon, convert_to_c(flags)),
6692  "could not create a descriptor for a batch normalization "
6693  "backward propagation primitive");
6694  }
6695  };
6696 
6701  primitive_desc() = default;  // Defaulted default constructor.
6702 
6716  primitive_desc(const desc &adesc, const engine &aengine,
6718  bool allow_empty = false)
6719  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6720  hint_fwd_pd.get(), allow_empty) {}
6721 
6736  primitive_desc(const desc &adesc, const primitive_attr &attr,
6737  const engine &aengine,
6739  bool allow_empty = false)
6740  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6741  hint_fwd_pd.get(), allow_empty) {}
6742 
6750  : dnnl::primitive_desc(pd,
6751  dnnl::primitive::kind::batch_normalization,
6753  }
6754 
6756  memory::desc src_desc() const { return base::src_desc(0); }  // Source memory descriptor at index 0, via the base-class query.
6757 
6760 
6762  memory::desc dst_desc() const { return base::dst_desc(0); }  // Destination memory descriptor at index 0, via the base-class query.
6763 
6766 
6769 
6772  return base::diff_weights_desc(0);
6773  }
6774 
6777 
6780  return query_md(query::src_md, 2);
6781  }
6782 
6785  };
6786 
6789 
6794 };
6795 
6797 
6819 
6823  struct desc {
6825 
6837  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6838  const memory::desc &stat_desc, float epsilon,
6839  normalization_flags flags) {
6842  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6843  &stat_desc.data, epsilon, convert_to_c(flags)),
6844  "could not create a descriptor for a layer normalization "
6845  "forward propagation primitive");
6846  }
6847 
6858  desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon,
6859  normalization_flags flags) {
6862  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6863  nullptr, epsilon, convert_to_c(flags)),
6864  "could not create a descriptor for a layer normalization "
6865  "forward propagation primitive");
6866  }
6867  };
6868 
6873  primitive_desc() = default;  // Defaulted default constructor.
6874 
6885  primitive_desc(const desc &adesc, const engine &aengine,  // Forward pd without attributes and without a hint pd (both passed as nullptr).
6886  bool allow_empty = false)  // Forwarded unchanged to the base constructor.
6887  : dnnl::primitive_desc(
6888  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6889 
6901  primitive_desc(const desc &adesc, const primitive_attr &attr,  // Forward pd with caller-supplied attributes; no hint pd (nullptr).
6902  const engine &aengine, bool allow_empty = false)
6903  : dnnl::primitive_desc(
6904  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6905 
6913  : dnnl::primitive_desc(pd,
6914  dnnl::primitive::kind::layer_normalization,
6917 
6919  memory::desc src_desc() const { return base::src_desc(0); }  // Source memory descriptor at index 0, via the base-class query.
6920 
6922  memory::desc dst_desc() const { return base::dst_desc(0); }  // Destination memory descriptor at index 0, via the base-class query.
6923 
6926 
6929 
6931  memory::desc mean_desc() const { return stat_desc(mean); }  // Mean memory descriptor, obtained through stat_desc(mean).
6932 
6934  memory::desc variance_desc() const { return stat_desc(var); }  // Variance memory descriptor, obtained through stat_desc(var).
6935 
6936  private:
6937  enum {  // Index values that mean_desc()/variance_desc() pass to stat_desc(), which forwards them to query_md().
6938  mean = 1,
6939  var = 2,
6940  };
6941  memory::desc stat_desc(int kind) const {
6946  &p),
6947  "could not retrieve a descriptor from a primitive "
6948  "descriptor for layer normalization forward propagation "
6949  "primitive");
6950  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
6951  : query::dst_md,
6952  kind);
6953  }
6954  };
6955 
6958 
6963 };
6964 
6968  struct desc {  // Operation descriptor for layer normalization backward propagation.
6970 
6984  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
6985  const memory::desc &data_desc, const memory::desc &stat_desc,
6986  float epsilon, normalization_flags flags) {  // Overload taking an explicit statistics memory descriptor.
6989  dnnl::convert_to_c(aprop_kind),
6990  &diff_data_desc.data, &data_desc.data,
6991  &stat_desc.data, epsilon, convert_to_c(flags)),
6992  "could not create a descriptor for a layer normalization "
6993  "backward propagation primitive");
6994  }
6995 
7008  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
7009  const memory::desc &data_desc, float epsilon,
7010  normalization_flags flags) {  // Overload without a statistics descriptor (nullptr is passed in its place).
7012  dnnl::convert_to_c(aprop_kind),
7013  &diff_data_desc.data, &data_desc.data,
7014  nullptr, epsilon, convert_to_c(flags)),
7015  "could not create a descriptor for a layer normalization "
7016  "backward propagation primitive");
7017  }
7018  };
7019 
7024  primitive_desc() = default;  // Defaulted default constructor.
7025 
7039  primitive_desc(const desc &adesc, const engine &aengine,
7041  bool allow_empty = false)
7042  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
7043  hint_fwd_pd.get(), allow_empty) {}
7044 
7059  primitive_desc(const desc &adesc, const primitive_attr &attr,
7060  const engine &aengine,
7062  bool allow_empty = false)
7063  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
7064  hint_fwd_pd.get(), allow_empty) {}
7065 
7073  : dnnl::primitive_desc(pd,
7074  dnnl::primitive::kind::layer_normalization,
7076  }
7077 
7079  memory::desc src_desc() const { return base::src_desc(0); }  // Source memory descriptor at index 0, via the base-class query.
7080 
7083 
7085  memory::desc dst_desc() const { return base::dst_desc(0); }  // Destination memory descriptor at index 0, via the base-class query.
7086 
7089 
7092 
7095  return base::diff_weights_desc(0);
7096  }
7097 
7100 
7103  return query_md(query::src_md, 2);
7104  }
7105 
7108  };
7109 
7112 
7117 };
7118 
7120 
7128 
7132  struct desc {
7134 
7149  desc(prop_kind aprop_kind, const memory::desc &src_desc,
7150  const memory::desc &weights_desc, const memory::desc &bias_desc,
7151  const memory::desc &dst_desc) {
7153  dnnl::convert_to_c(aprop_kind),
7154  &src_desc.data, &weights_desc.data,
7155  &bias_desc.data, &dst_desc.data),
7156  "could not create a descriptor for an inner product "
7157  "forward propagation primitive");
7158  }
7159 
7173  desc(prop_kind aprop_kind, const memory::desc &src_desc,
7174  const memory::desc &weights_desc,
7175  const memory::desc &dst_desc) {
7178  dnnl::convert_to_c(aprop_kind), &src_desc.data,
7179  &weights_desc.data, nullptr, &dst_desc.data),
7180  "could not create a descriptor for an inner product "
7181  "forward propagation primitive");
7182  }
7183  };
7184 
7188  primitive_desc() = default;  // Defaulted default constructor.
7189 
7200  primitive_desc(const desc &adesc, const engine &aengine,  // Forward pd without attributes and without a hint pd (both passed as nullptr).
7201  bool allow_empty = false)  // Forwarded unchanged to the base constructor.
7202  : dnnl::primitive_desc(
7203  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
7204 
7216  primitive_desc(const desc &adesc, const primitive_attr &attr,  // Forward pd with caller-supplied attributes; no hint pd (nullptr).
7217  const engine &aengine, bool allow_empty = false)
7218  : dnnl::primitive_desc(
7219  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
7220 
7228  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7231 
7233  memory::desc src_desc() const { return base::src_desc(0); }  // Source memory descriptor at index 0, via the base-class query.
7234 
7237 
7239  memory::desc dst_desc() const { return base::dst_desc(0); }  // Destination memory descriptor at index 0, via the base-class query.
7240 
7243  };
7244 
7247 
7252 };
7253 
7257  struct desc {
7259 
7270  desc(const memory::desc &diff_src_desc,
7271  const memory::desc &weights_desc,
7272  const memory::desc &diff_dst_desc) {
7274  &diff_src_desc.data, &weights_desc.data,
7275  &diff_dst_desc.data),
7276  "could not create a descriptor for an inner product "
7277  "backward propagation primitive");
7278  }
7279  };
7280 
7285  primitive_desc() = default;  // Defaulted default constructor.
7286 
7300  primitive_desc(const desc &adesc, const engine &aengine,  // Backward-data pd without attributes (nullptr is passed as attr).
7301  const inner_product_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
7302  bool allow_empty = false)  // Forwarded unchanged to the base constructor.
7303  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
7304  hint_fwd_pd.get(), allow_empty) {}
7305 
7320  primitive_desc(const desc &adesc, const primitive_attr &attr,  // Backward-data pd with caller-supplied attributes (&attr is passed to the base).
7321  const engine &aengine,
7322  const inner_product_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
7323  bool allow_empty = false)  // Forwarded unchanged to the base constructor.
7324  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
7325  hint_fwd_pd.get(), allow_empty) {}
7326 
7334  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7336 
7339 
7342 
7345  };
7346 
7349 
7354 };
7355 
7359  struct desc {
7361 
7373  desc(const memory::desc &src_desc,
7374  const memory::desc &diff_weights_desc,
7375  const memory::desc &diff_bias_desc,
7376  const memory::desc &diff_dst_desc) {
7379  &src_desc.data, &diff_weights_desc.data,
7380  &diff_bias_desc.data, &diff_dst_desc.data),
7381  "could not create a descriptor for an inner product "
7382  "weights gradient primitive");
7383  }
7384 
7395  desc(const memory::desc &src_desc,
7396  const memory::desc &diff_weights_desc,
7397  const memory::desc &diff_dst_desc) {
7400  &src_desc.data, &diff_weights_desc.data, nullptr,
7401  &diff_dst_desc.data),
7402  "could not create a descriptor for an inner product "
7403  "weights gradient primitive");
7404  }
7405  };
7406 
7410  primitive_desc() = default;  // Defaulted default constructor.
7411 
7425  primitive_desc(const desc &adesc, const engine &aengine,  // Backward-weights pd without attributes (nullptr is passed as attr).
7426  const inner_product_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
7427  bool allow_empty = false)  // Forwarded unchanged to the base constructor.
7428  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
7429  hint_fwd_pd.get(), allow_empty) {}
7430 
7445  primitive_desc(const desc &adesc, const primitive_attr &attr,  // Backward-weights pd with caller-supplied attributes (&attr is passed to the base).
7446  const engine &aengine,
7447  const inner_product_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
7448  bool allow_empty = false)  // Forwarded unchanged to the base constructor.
7449  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
7450  hint_fwd_pd.get(), allow_empty) {}
7451 
7459  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7461 
7463  memory::desc src_desc() const { return base::src_desc(0); }  // Source memory descriptor at index 0, via the base-class query.
7464 
7467  return base::diff_weights_desc(0);
7468  }
7469 
7472 
7475  return base::diff_weights_desc(1);
7476  }
7477  };
7478 
7481 
7486 };
7487 
7489 
7497 
7500  using primitive_desc::primitive_desc;  // Inherits the constructors of primitive_desc.
7501 
7504 
7513  dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
7514  : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {}
7515 
7520  }
7521 
7528  }
7529 
7534  }
7535 
7540  }
7541 
7546  }
7547 
7552  }
7553 
7558  }
7559 
7566  }
7567 
7572  }
7573 
7580  }
7581 
7586  }
7587 
7592  }
7593 
7600  }
7601 
7606  }
7607 
7612  }
7613 
7618  }
7619 
7623  return base::query_md(
7625  }
7626 
7630  return base::query_md(
7632  }
7633 
7640  }
7641 
7646  }
7647 
7654  }
7655 
7660  }
7661 
7662 protected:
7663  using rnn_base = rnn_primitive_desc_base;  // Short alias used by the RNN primitive descriptors, e.g. rnn_base::src_layer_desc().
7664 
7665  // (Deliberately not using doxygen comments)
7666  //
7667  // Constructs an RNN primitive descriptor base from a C API primitive
7668  // descriptor while checking that it actually describes the expected
7669  // primitive by comparing propagation and primitive kinds. Caller can
7670  // pass two options propagation kinds. This is typically used to check
7671  // that propagation kind is inference or training forward propagation.
7672  //
7673  // @param pd C API primitive descriptor.
7674  // @param prop_kind1 Expected propagation kind.
7675  // @param prop_kind2 Expected propagation kind.
7676  // @param cell_kind Expected cell kind.
7678  dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
7679  dnnl::algorithm cell_kind) {
7681  dnnl_status_t rc;
7682  rc = dnnl_primitive_desc_query(pd, dnnl_query_rnn_d, 0, &rnn_d);
7683  error::wrap_c_api(rc,
7684  "could not retrieve a descriptor from a primitive descriptor "
7685  "for an RNN primitive");
7686 
7687  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
7688  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
7689  dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
7690 
7691  bool ok = rnn_d->primitive_kind == dnnl_rnn
7692  && (rnn_d->prop_kind == c_prop_kind1
7693  || rnn_d->prop_kind == c_prop_kind2)
7694  && rnn_d->cell_kind == c_cell_kind;
7695 
7696  if (!ok)
7697  DNNL_THROW_ERROR(dnnl_invalid_arguments,
7698  "mismatch between expected and provided descriptors for an "
7699  "RNN primitive");
7700 
7701  reset_with_clone(pd);
7702  }
7703 };
7704 
7708  struct desc {
7709  dnnl_rnn_desc_t data;
7710 
7751  desc(prop_kind aprop_kind, algorithm activation,
7752  rnn_direction direction, const memory::desc &src_layer_desc,
7753  const memory::desc &src_iter_desc,
7754  const memory::desc &weights_layer_desc,
7755  const memory::desc &weights_iter_desc,
7756  const memory::desc &bias_desc,
7757  const memory::desc &dst_layer_desc,
7758  const memory::desc &dst_iter_desc,
7759  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
7760  float beta = 0.0f) {
7763  dnnl::convert_to_c(aprop_kind),
7764  dnnl::convert_to_c(activation),
7765  dnnl::convert_to_c(direction), &src_layer_desc.data,
7766  &src_iter_desc.data, &weights_layer_desc.data,
7767  &weights_iter_desc.data, &bias_desc.data,
7768  &dst_layer_desc.data, &dst_iter_desc.data,
7769  dnnl::convert_to_c(flags), alpha, beta),
7770  "could not create a descriptor for a vanilla RNN forward "
7771  "propagation primitive");
7772  }
7773  };
7774 
7778  primitive_desc() = default;  // Defaulted default constructor.
7779 
7790  primitive_desc(const desc &adesc, const engine &aengine,
7791  bool allow_empty = false)
7793  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
7794 
7806  primitive_desc(const desc &adesc, const primitive_attr &attr,
7807  const engine &aengine, bool allow_empty = false)
7809  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
7810 
7820  dnnl::algorithm::vanilla_rnn) {}
7821 
7824  return rnn_base::src_layer_desc();
7825  }
7826 
7829 
7833  }
7834 
7837  return rnn_base::weights_iter_desc();
7838  }
7839 
7842 
7845  return rnn_base::dst_layer_desc();
7846  }
7847 
7850 
7853  return rnn_base::workspace_desc();
7854  }
7855  };
7856 
7858  vanilla_rnn_forward() = default;  // Defaulted default constructor.
7859 
7864 };
7865 
7869  struct desc {
7870  dnnl_rnn_desc_t data;
7871 
7924  desc(prop_kind aprop_kind, algorithm activation,
7925  rnn_direction direction, const memory::desc &src_layer_desc,
7926  const memory::desc &src_iter_desc,
7927  const memory::desc &weights_layer_desc,
7928  const memory::desc &weights_iter_desc,
7929  const memory::desc &bias_desc,
7930  const memory::desc &dst_layer_desc,
7931  const memory::desc &dst_iter_desc,
7932  const memory::desc &diff_src_layer_desc,
7933  const memory::desc &diff_src_iter_desc,
7934  const memory::desc &diff_weights_layer_desc,
7935  const memory::desc &diff_weights_iter_desc,
7936  const memory::desc &diff_bias_desc,
7937  const memory::desc &diff_dst_layer_desc,
7938  const memory::desc &diff_dst_iter_desc,
7939  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
7940  float beta = 0.0f) {
7943  dnnl::convert_to_c(aprop_kind),
7944  dnnl::convert_to_c(activation),
7945  dnnl::convert_to_c(direction), &src_layer_desc.data,
7946  &src_iter_desc.data, &weights_layer_desc.data,
7947  &weights_iter_desc.data, &bias_desc.data,
7948  &dst_layer_desc.data, &dst_iter_desc.data,
7949  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
7950  &diff_weights_layer_desc.data,
7951  &diff_weights_iter_desc.data, &diff_bias_desc.data,
7952  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
7953  dnnl::convert_to_c(flags), alpha, beta),
7954  "could not create a descriptor for a vanilla RNN backward "
7955  "propagation primitive");
7956  }
7957  };
7958 
7962  primitive_desc() = default;  // Defaulted default constructor.
7963 
7977  primitive_desc(const desc &adesc, const engine &aengine,  // Vanilla RNN backward pd without attributes (nullptr is passed as attr).
7978  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
7979  bool allow_empty = false)  // Forwarded unchanged to rnn_primitive_desc_base.
7980  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
7981  hint_fwd_pd.get(), allow_empty) {}
7982 
7997  primitive_desc(const desc &adesc, const primitive_attr &attr,  // Vanilla RNN backward pd with caller-supplied attributes.
7998  const engine &aengine,
7999  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
8000  bool allow_empty = false)  // Forwarded unchanged to rnn_primitive_desc_base.
8001  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
8002  hint_fwd_pd.get(), allow_empty) {}
8003 
8012  dnnl::algorithm::vanilla_rnn) {}
8013 
8016  return rnn_base::src_layer_desc();
8017  }
8018 
8021 
8025  }
8026 
8029  return rnn_base::weights_iter_desc();
8030  }
8031 
8034 
8037  return rnn_base::dst_layer_desc();
8038  }
8039 
8042 
8045  return rnn_base::workspace_desc();
8046  }
8047 
8051  }
8052 
8056  }
8057 
8061  }
8062 
8066  }
8067 
8070  return rnn_base::diff_bias_desc();
8071  }
8072 
8076  }
8077 
8081  }
8082  };
8083 
8086 
8091 };
8092 
8094 struct lstm_forward : public primitive {
8096  struct desc {
8097  dnnl_rnn_desc_t data;
8098 
8147  desc(prop_kind aprop_kind, rnn_direction direction,
8148  const memory::desc &src_layer_desc,
8149  const memory::desc &src_iter_desc,
8150  const memory::desc &src_iter_c_desc,
8151  const memory::desc &weights_layer_desc,
8152  const memory::desc &weights_iter_desc,
8153  const memory::desc &weights_peephole_desc,
8154  const memory::desc &weights_projection_desc,
8155  const memory::desc &bias_desc,
8156  const memory::desc &dst_layer_desc,
8157  const memory::desc &dst_iter_desc,
8158  const memory::desc &dst_iter_c_desc,
8159  rnn_flags flags = rnn_flags::undef) {
8162  dnnl::convert_to_c(aprop_kind),
8163  dnnl::convert_to_c(direction), &src_layer_desc.data,
8164  &src_iter_desc.data, &src_iter_c_desc.data,
8165  &weights_layer_desc.data, &weights_iter_desc.data,
8166  &weights_peephole_desc.data,
8167  &weights_projection_desc.data, &bias_desc.data,
8168  &dst_layer_desc.data, &dst_iter_desc.data,
8169  &dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
8170  "could not create a descriptor for an LSTM forward "
8171  "propagation primitive");
8172  }
8173 
8215  desc(prop_kind aprop_kind, rnn_direction direction,
8216  const memory::desc &src_layer_desc,
8217  const memory::desc &src_iter_desc,
8218  const memory::desc &src_iter_c_desc,
8219  const memory::desc &weights_layer_desc,
8220  const memory::desc &weights_iter_desc,
8221  const memory::desc &weights_peephole_desc,
8222  const memory::desc &bias_desc,
8223  const memory::desc &dst_layer_desc,
8224  const memory::desc &dst_iter_desc,
8225  const memory::desc &dst_iter_c_desc,
8226  rnn_flags flags = rnn_flags::undef) {
8229  dnnl::convert_to_c(aprop_kind),
8230  dnnl::convert_to_c(direction), &src_layer_desc.data,
8231  &src_iter_desc.data, &src_iter_c_desc.data,
8232  &weights_layer_desc.data, &weights_iter_desc.data,
8233  &weights_peephole_desc.data, &bias_desc.data,
8234  &dst_layer_desc.data, &dst_iter_desc.data,
8235  &dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
8236  "could not create a descriptor for an LSTM forward "
8237  "propagation primitive");
8238  }
8239 
8276  desc(prop_kind aprop_kind, rnn_direction direction,
8277  const memory::desc &src_layer_desc,
8278  const memory::desc &src_iter_desc,
8279  const memory::desc &src_iter_c_desc,
8280  const memory::desc &weights_layer_desc,
8281  const memory::desc &weights_iter_desc,
8282  const memory::desc &bias_desc,
8283  const memory::desc &dst_layer_desc,
8284  const memory::desc &dst_iter_desc,
8285  const memory::desc &dst_iter_c_desc,
8286  rnn_flags flags = rnn_flags::undef) {
8289  dnnl::convert_to_c(aprop_kind),
8290  dnnl::convert_to_c(direction), &src_layer_desc.data,
8291  &src_iter_desc.data, &src_iter_c_desc.data,
8292  &weights_layer_desc.data, &weights_iter_desc.data,
8293  &bias_desc.data, &dst_layer_desc.data,
8294  &dst_iter_desc.data, &dst_iter_c_desc.data,
8295  dnnl::convert_to_c(flags)),
8296  "could not create a descriptor for an LSTM forward "
8297  "propagation primitive");
8298  }
8299  };
8300 
8304  primitive_desc() = default;  // Defaulted default constructor.
8305 
8315  primitive_desc(const desc &adesc, const engine &aengine,
8316  bool allow_empty = false)
8318  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
8319 
8330  primitive_desc(const desc &adesc, const primitive_attr &attr,
8331  const engine &aengine, bool allow_empty = false)
8333  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
8334 
8345 
8348  return rnn_base::src_layer_desc();
8349  }
8350 
8353 
8356  return rnn_base::src_iter_c_desc();
8357  }
8358 
8362  }
8363 
8366  return rnn_base::weights_iter_desc();
8367  }
8368 
8372  }
8373 
8377  }
8378 
8381 
8384  return rnn_base::dst_layer_desc();
8385  }
8386 
8389 
8392  return rnn_base::dst_iter_c_desc();
8393  }
8394 
8397  return rnn_base::workspace_desc();
8398  }
8399  };
8400 
8402  lstm_forward() = default;  // Defaulted default constructor.
8403 
8408 };
8409 
8411 struct lstm_backward : public primitive {
8413  struct desc {
8414  dnnl_rnn_desc_t data;
8415 
8491  desc(prop_kind aprop_kind, rnn_direction direction,
8492  const memory::desc &src_layer_desc,
8493  const memory::desc &src_iter_desc,
8494  const memory::desc &src_iter_c_desc,
8495  const memory::desc &weights_layer_desc,
8496  const memory::desc &weights_iter_desc,
8497  const memory::desc &weights_peephole_desc,
8498  const memory::desc &weights_projection_desc,
8499  const memory::desc &bias_desc,
8500  const memory::desc &dst_layer_desc,
8501  const memory::desc &dst_iter_desc,
8502  const memory::desc &dst_iter_c_desc,
8503  const memory::desc &diff_src_layer_desc,
8504  const memory::desc &diff_src_iter_desc,
8505  const memory::desc &diff_src_iter_c_desc,
8506  const memory::desc &diff_weights_layer_desc,
8507  const memory::desc &diff_weights_iter_desc,
8508  const memory::desc &diff_weights_peephole_desc,
8509  const memory::desc &diff_weights_projection_desc,
8510  const memory::desc &diff_bias_desc,
8511  const memory::desc &diff_dst_layer_desc,
8512  const memory::desc &diff_dst_iter_desc,
8513  const memory::desc &diff_dst_iter_c_desc,
8514  rnn_flags flags = rnn_flags::undef) {
8517  dnnl::convert_to_c(aprop_kind),
8518  dnnl::convert_to_c(direction), &src_layer_desc.data,
8519  &src_iter_desc.data, &src_iter_c_desc.data,
8520  &weights_layer_desc.data, &weights_iter_desc.data,
8521  &weights_peephole_desc.data,
8522  &weights_projection_desc.data, &bias_desc.data,
8523  &dst_layer_desc.data, &dst_iter_desc.data,
8524  &dst_iter_c_desc.data, &diff_src_layer_desc.data,
8525  &diff_src_iter_desc.data,
8526  &diff_src_iter_c_desc.data,
8527  &diff_weights_layer_desc.data,
8528  &diff_weights_iter_desc.data,
8529  &diff_weights_peephole_desc.data,
8530  &diff_weights_projection_desc.data,
8531  &diff_bias_desc.data, &diff_dst_layer_desc.data,
8532  &diff_dst_iter_desc.data,
8533  &diff_dst_iter_c_desc.data,
8534  dnnl::convert_to_c(flags)),
8535  "could not create a descriptor for an LSTM backward "
8536  "propagation primitive");
8537  }
8538 
8603  desc(prop_kind aprop_kind, rnn_direction direction,
8604  const memory::desc &src_layer_desc,
8605  const memory::desc &src_iter_desc,
8606  const memory::desc &src_iter_c_desc,
8607  const memory::desc &weights_layer_desc,
8608  const memory::desc &weights_iter_desc,
8609  const memory::desc &weights_peephole_desc,
8610  const memory::desc &bias_desc,
8611  const memory::desc &dst_layer_desc,
8612  const memory::desc &dst_iter_desc,
8613  const memory::desc &dst_iter_c_desc,
8614  const memory::desc &diff_src_layer_desc,
8615  const memory::desc &diff_src_iter_desc,
8616  const memory::desc &diff_src_iter_c_desc,
8617  const memory::desc &diff_weights_layer_desc,
8618  const memory::desc &diff_weights_iter_desc,
8619  const memory::desc &diff_weights_peephole_desc,
8620  const memory::desc &diff_bias_desc,
8621  const memory::desc &diff_dst_layer_desc,
8622  const memory::desc &diff_dst_iter_desc,
8623  const memory::desc &diff_dst_iter_c_desc,
8624  rnn_flags flags = rnn_flags::undef) {
8627  dnnl::convert_to_c(aprop_kind),
8628  dnnl::convert_to_c(direction), &src_layer_desc.data,
8629  &src_iter_desc.data, &src_iter_c_desc.data,
8630  &weights_layer_desc.data, &weights_iter_desc.data,
8631  &weights_peephole_desc.data, &bias_desc.data,
8632  &dst_layer_desc.data, &dst_iter_desc.data,
8633  &dst_iter_c_desc.data, &diff_src_layer_desc.data,
8634  &diff_src_iter_desc.data,
8635  &diff_src_iter_c_desc.data,
8636  &diff_weights_layer_desc.data,
8637  &diff_weights_iter_desc.data,
8638  &diff_weights_peephole_desc.data,
8639  &diff_bias_desc.data, &diff_dst_layer_desc.data,
8640  &diff_dst_iter_desc.data,
8641  &diff_dst_iter_c_desc.data,
8642  dnnl::convert_to_c(flags)),
8643  "could not create a descriptor for an LSTM backward "
8644  "propagation primitive");
8645  }
8646 
8702  desc(prop_kind aprop_kind, rnn_direction direction,
8703  const memory::desc &src_layer_desc,
8704  const memory::desc &src_iter_desc,
8705  const memory::desc &src_iter_c_desc,
8706  const memory::desc &weights_layer_desc,
8707  const memory::desc &weights_iter_desc,
8708  const memory::desc &bias_desc,
8709  const memory::desc &dst_layer_desc,
8710  const memory::desc &dst_iter_desc,
8711  const memory::desc &dst_iter_c_desc,
8712  const memory::desc &diff_src_layer_desc,
8713  const memory::desc &diff_src_iter_desc,
8714  const memory::desc &diff_src_iter_c_desc,
8715  const memory::desc &diff_weights_layer_desc,
8716  const memory::desc &diff_weights_iter_desc,
8717  const memory::desc &diff_bias_desc,
8718  const memory::desc &diff_dst_layer_desc,
8719  const memory::desc &diff_dst_iter_desc,
8720  const memory::desc &diff_dst_iter_c_desc,
8721  rnn_flags flags = rnn_flags::undef) {
8724  dnnl::convert_to_c(aprop_kind),
8725  dnnl::convert_to_c(direction), &src_layer_desc.data,
8726  &src_iter_desc.data, &src_iter_c_desc.data,
8727  &weights_layer_desc.data, &weights_iter_desc.data,
8728  &bias_desc.data, &dst_layer_desc.data,
8729  &dst_iter_desc.data, &dst_iter_c_desc.data,
8730  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
8731  &diff_src_iter_c_desc.data,
8732  &diff_weights_layer_desc.data,
8733  &diff_weights_iter_desc.data, &diff_bias_desc.data,
8734  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
8735  &diff_dst_iter_c_desc.data,
8736  dnnl::convert_to_c(flags)),
8737  "could not create a descriptor for an LSTM backward "
8738  "propagation primitive");
8739  }
8740  };
8741 
8745  primitive_desc() = default;  // Defaulted default constructor.
8746 
8759  primitive_desc(const desc &adesc, const engine &aengine,  // LSTM backward pd without attributes (nullptr is passed as attr).
8760  const lstm_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
8761  bool allow_empty = false)  // Forwarded unchanged to rnn_primitive_desc_base.
8762  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
8763  hint_fwd_pd.get(), allow_empty) {}
8764 
8778  primitive_desc(const desc &adesc, const primitive_attr &attr,  // LSTM backward pd with caller-supplied attributes.
8779  const engine &aengine,
8780  const lstm_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
8781  bool allow_empty = false)  // Forwarded unchanged to rnn_primitive_desc_base.
8782  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
8783  hint_fwd_pd.get(), allow_empty) {}
8784 
8794 
8797  return rnn_base::src_layer_desc();
8798  }
8799 
8802 
8805  return rnn_base::src_iter_c_desc();
8806  }
8807 
8811  }
8812 
8815  return rnn_base::weights_iter_desc();
8816  }
8817 
8821  }
8822 
8826  }
8827 
8830 
8833  return rnn_base::dst_layer_desc();
8834  }
8835 
8838 
8841  return rnn_base::dst_iter_c_desc();
8842  }
8843 
8846  return rnn_base::workspace_desc();
8847  }
8848 
8852  }
8853 
8857  }
8858 
8862  }
8863 
8867  }
8868 
8872  }
8873 
8877  }
8878 
8882  }
8883 
8886  return rnn_base::diff_bias_desc();
8887  }
8888 
8892  }
8893 
8897  }
8898 
8902  }
8903  };
8904 
8906  lstm_backward() = default;  // Defaulted default constructor.
8907 
8912 };
8913 
8915 struct gru_forward : public primitive {
8917  struct desc {
8918  dnnl_rnn_desc_t data;
8919 
8952  desc(prop_kind aprop_kind, rnn_direction direction,
8953  const memory::desc &src_layer_desc,
8954  const memory::desc &src_iter_desc,
8955  const memory::desc &weights_layer_desc,
8956  const memory::desc &weights_iter_desc,
8957  const memory::desc &bias_desc,
8958  const memory::desc &dst_layer_desc,
8959  const memory::desc &dst_iter_desc,
8960  rnn_flags flags = rnn_flags::undef) {
8963  dnnl::convert_to_c(aprop_kind),
8964  dnnl::convert_to_c(direction), &src_layer_desc.data,
8965  &src_iter_desc.data, &weights_layer_desc.data,
8966  &weights_iter_desc.data, &bias_desc.data,
8967  &dst_layer_desc.data, &dst_iter_desc.data,
8968  dnnl::convert_to_c(flags)),
8969  "could not create a descriptor for a GRU forward "
8970  "propagation primitive");
8971  }
8972  };
8973 
8977  primitive_desc() = default;  // Defaulted default constructor.
8978 
8988  primitive_desc(const desc &adesc, const engine &aengine,
8989  bool allow_empty = false)
8991  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
8992 
9003  primitive_desc(const desc &adesc, const primitive_attr &attr,
9004  const engine &aengine, bool allow_empty = false)
9006  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9007 
9017  dnnl::algorithm::vanilla_gru) {}
9018 
9021  return rnn_base::src_layer_desc();
9022  }
9023 
9026 
9030  }
9031 
9034  return rnn_base::weights_iter_desc();
9035  }
9036 
9039 
9042  return rnn_base::dst_layer_desc();
9043  }
9044 
9047 
9050  return rnn_base::workspace_desc();
9051  }
9052  };
9053 
9055  gru_forward() = default;  // Defaulted default constructor.
9056 
9061 };
9062 
9064 struct gru_backward : public primitive {
9066  struct desc {
9067  dnnl_rnn_desc_t data;
9068 
9113  desc(prop_kind aprop_kind, rnn_direction direction,
9114  const memory::desc &src_layer_desc,
9115  const memory::desc &src_iter_desc,
9116  const memory::desc &weights_layer_desc,
9117  const memory::desc &weights_iter_desc,
9118  const memory::desc &bias_desc,
9119  const memory::desc &dst_layer_desc,
9120  const memory::desc &dst_iter_desc,
9121  const memory::desc &diff_src_layer_desc,
9122  const memory::desc &diff_src_iter_desc,
9123  const memory::desc &diff_weights_layer_desc,
9124  const memory::desc &diff_weights_iter_desc,
9125  const memory::desc &diff_bias_desc,
9126  const memory::desc &diff_dst_layer_desc,
9127  const memory::desc &diff_dst_iter_desc,
9128  rnn_flags flags = rnn_flags::undef) {
9131  dnnl::convert_to_c(aprop_kind),
9132  dnnl::convert_to_c(direction), &src_layer_desc.data,
9133  &src_iter_desc.data, &weights_layer_desc.data,
9134  &weights_iter_desc.data, &bias_desc.data,
9135  &dst_layer_desc.data, &dst_iter_desc.data,
9136  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
9137  &diff_weights_layer_desc.data,
9138  &diff_weights_iter_desc.data, &diff_bias_desc.data,
9139  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
9140  dnnl::convert_to_c(flags)),
9141  "could not create a descriptor for a GRU backward "
9142  "propagation primitive");
9143  }
9144  };
9145 
9149  primitive_desc() = default;  // Defaulted default constructor.
9150 
9163  primitive_desc(const desc &adesc, const engine &aengine,  // GRU backward pd without attributes (nullptr is passed as attr).
9164  const gru_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
9165  bool allow_empty = false)  // Forwarded unchanged to rnn_primitive_desc_base.
9166  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
9167  hint_fwd_pd.get(), allow_empty) {}
9168 
9182  primitive_desc(const desc &adesc, const primitive_attr &attr,  // GRU backward pd with caller-supplied attributes.
9183  const engine &aengine,
9184  const gru_forward::primitive_desc &hint_fwd_pd,  // Forward pd; its C handle (get()) is handed to the base as the hint.
9185  bool allow_empty = false)  // Forwarded unchanged to rnn_primitive_desc_base.
9186  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
9187  hint_fwd_pd.get(), allow_empty) {}
9188 
9197  dnnl::algorithm::vanilla_gru) {}
9198 
9201  return rnn_base::src_layer_desc();
9202  }
9203 
9206 
9210  }
9211 
9214  return rnn_base::weights_iter_desc();
9215  }
9216 
9219 
9222  return rnn_base::dst_layer_desc();
9223  }
9224 
9227 
9230  return rnn_base::workspace_desc();
9231  }
9232 
9236  }
9237 
9241  }
9242 
9246  }
9247 
9251  }
9252 
9255  return rnn_base::diff_bias_desc();
9256  }
9257 
9261  }
9262 
9266  }
9267  };
9268 
9270  gru_backward() = default;  // Defaulted default constructor.
9271 
9276 };
9277 
9279 struct lbr_gru_forward : public primitive {
9281  struct desc {
9282  dnnl_rnn_desc_t data;
9283 
9317  desc(prop_kind aprop_kind, rnn_direction direction,
9318  const memory::desc &src_layer_desc,
9319  const memory::desc &src_iter_desc,
9320  const memory::desc &weights_layer_desc,
9321  const memory::desc &weights_iter_desc,
9322  const memory::desc &bias_desc,
9323  const memory::desc &dst_layer_desc,
9324  const memory::desc &dst_iter_desc,
9325  rnn_flags flags = rnn_flags::undef) {
9328  dnnl::convert_to_c(aprop_kind),
9329  dnnl::convert_to_c(direction), &src_layer_desc.data,
9330  &src_iter_desc.data, &weights_layer_desc.data,
9331  &weights_iter_desc.data, &bias_desc.data,
9332  &dst_layer_desc.data, &dst_iter_desc.data,
9333  dnnl::convert_to_c(flags)),
9334  "could not create a descriptor for an LBR GRU forward "
9335  "propagation primitive");
9336  }
9337  };
9338 
9342  primitive_desc() = default;
9343 
9354  primitive_desc(const desc &adesc, const engine &aengine,
9355  bool allow_empty = false)
9357  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9358 
9370  primitive_desc(const desc &adesc, const primitive_attr &attr,
9371  const engine &aengine, bool allow_empty = false)
9373  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9374 
9384  dnnl::algorithm::lbr_gru) {}
9385 
9388  return rnn_base::src_layer_desc();
9389  }
9390 
9393 
9397  }
9398 
9401  return rnn_base::weights_iter_desc();
9402  }
9403 
9406 
9409  return rnn_base::dst_layer_desc();
9410  }
9411 
9414 
9417  return rnn_base::workspace_desc();
9418  }
9419  };
9420 
9422  lbr_gru_forward() = default;
9423 
9428 };
9429 
9431 struct lbr_gru_backward : public primitive {
9433  struct desc {
9434  dnnl_rnn_desc_t data;
9435 
9481  desc(prop_kind aprop_kind, rnn_direction direction,
9482  const memory::desc &src_layer_desc,
9483  const memory::desc &src_iter_desc,
9484  const memory::desc &weights_layer_desc,
9485  const memory::desc &weights_iter_desc,
9486  const memory::desc &bias_desc,
9487  const memory::desc &dst_layer_desc,
9488  const memory::desc &dst_iter_desc,
9489  const memory::desc &diff_src_layer_desc,
9490  const memory::desc &diff_src_iter_desc,
9491  const memory::desc &diff_weights_layer_desc,
9492  const memory::desc &diff_weights_iter_desc,
9493  const memory::desc &diff_bias_desc,
9494  const memory::desc &diff_dst_layer_desc,
9495  const memory::desc &diff_dst_iter_desc,
9496  rnn_flags flags = rnn_flags::undef) {
9499  dnnl::convert_to_c(aprop_kind),
9500  dnnl::convert_to_c(direction), &src_layer_desc.data,
9501  &src_iter_desc.data, &weights_layer_desc.data,
9502  &weights_iter_desc.data, &bias_desc.data,
9503  &dst_layer_desc.data, &dst_iter_desc.data,
9504  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
9505  &diff_weights_layer_desc.data,
9506  &diff_weights_iter_desc.data, &diff_bias_desc.data,
9507  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
9508  dnnl::convert_to_c(flags)),
9509  "could not create a descriptor for an LBR GRU backward "
9510  "propagation primitive");
9511  }
9512  };
9513 
9517  primitive_desc() = default;
9518 
9532  primitive_desc(const desc &adesc, const engine &aengine,
9533  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
9534  bool allow_empty = false)
9535  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
9536  hint_fwd_pd.get(), allow_empty) {}
9537 
9552  primitive_desc(const desc &adesc, const primitive_attr &attr,
9553  const engine &aengine,
9554  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
9555  bool allow_empty = false)
9556  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
9557  hint_fwd_pd.get(), allow_empty) {}
9558 
9568 
9571  return rnn_base::src_layer_desc();
9572  }
9573 
9576 
9580  }
9581 
9584  return rnn_base::weights_iter_desc();
9585  }
9586 
9589 
9592  return rnn_base::dst_layer_desc();
9593  }
9594 
9597 
9600  return rnn_base::workspace_desc();
9601  }
9602 
9606  }
9607 
9611  }
9612 
9616  }
9617 
9621  }
9622 
9625  return rnn_base::diff_bias_desc();
9626  }
9627 
9631  }
9632 
9636  }
9637  };
9638 
9640  lbr_gru_backward() = default;
9641 
9646 };
9647 
9649 
9657 
9659 struct shuffle_forward : public primitive {
9661  struct desc {
9662  dnnl_shuffle_desc_t data;
9663 
9673  desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis,
9674  int group_size) {
9676  dnnl::convert_to_c(aprop_kind),
9677  &data_desc.data, axis, group_size),
9678  "could not create a descriptor for a shuffle forward "
9679  "propagation primitive");
9680  }
9681  };
9682 
9686  primitive_desc() = default;
9687 
9699  primitive_desc(const desc &adesc, const engine &aengine,
9700  const primitive_attr &attr = primitive_attr(),
9701  bool allow_empty = false)
9702  : dnnl::primitive_desc(
9703  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9704 
9712  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
9715 
9717  memory::desc src_desc() const { return base::src_desc(0); }
9718 
9720  memory::desc dst_desc() const { return base::dst_desc(0); }
9721  };
9722 
9724  shuffle_forward() = default;
9725 
9730 };
9731 
9733 struct shuffle_backward : public primitive {
9736  struct desc {
9737  dnnl_shuffle_desc_t data;
9738 
9746  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
9748  &diff_data_desc.data, axis, group_size),
9749  "could not create a descriptor for a shuffle backward "
9750  "propagation primitive");
9751  }
9752  };
9753 
9757  primitive_desc() = default;
9758 
9773  primitive_desc(const desc &adesc, const engine &aengine,
9774  const shuffle_forward::primitive_desc &hint_fwd_pd,
9775  const primitive_attr &attr = primitive_attr(),
9776  bool allow_empty = false)
9777  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
9778  hint_fwd_pd.get(), allow_empty) {}
9779 
9787  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
9789 
9792 
9795  };
9796 
9798  shuffle_backward() = default;
9799 
9804 };
9805 
9807 
9815 
9817 struct binary : public primitive {
9819  struct desc {
9822 
9824  desc() = default;
9825 
9833  desc(algorithm aalgorithm, const memory::desc &src0,
9834  const memory::desc &src1, const memory::desc &dst) {
9837  &src0.data, &src1.data, &dst.data),
9838  "could not create a descriptor for a binary operation "
9839  "primitive");
9840  }
9841  };
9842 
9846  primitive_desc() = default;
9847 
9857  primitive_desc(const desc &adesc, const engine &aengine,
9858  bool allow_empty = false)
9859  : dnnl::primitive_desc(
9860  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9861 
9872  primitive_desc(const desc &adesc, const primitive_attr &attr,
9873  const engine &aengine, bool allow_empty = false)
9874  : dnnl::primitive_desc(
9875  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9876 
9883 
9885  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
9886 
9888  memory::desc src0_desc() const { return base::src_desc(0); }
9889 
9891  memory::desc src1_desc() const { return base::src_desc(1); }
9892 
9894  memory::desc dst_desc() const { return base::dst_desc(0); }
9895  };
9896 
9898  binary() = default;
9899 
9903  binary(const primitive_desc &pd) : primitive(pd) {}
9904 };
9905 
9907 
9917 
9919 struct matmul : public primitive {
9921  struct desc {
9922  dnnl_matmul_desc_t data;
9923 
9929  desc(const memory::desc &src_desc, const memory::desc &weights_desc,
9930  const memory::desc &dst_desc) {
9932  dnnl_matmul_desc_init(&data, &src_desc.data,
9933  &weights_desc.data, nullptr, &dst_desc.data),
9934  "could not create a descriptor for a matmul primitive");
9935  }
9936 
9943  desc(const memory::desc &src_desc, const memory::desc &weights_desc,
9944  const memory::desc &bias_desc, const memory::desc &dst_desc) {
9945  error::wrap_c_api(dnnl_matmul_desc_init(&data, &src_desc.data,
9946  &weights_desc.data, &bias_desc.data,
9947  &dst_desc.data),
9948  "could not create a descriptor for a matmul primitive");
9949  }
9950  };
9951 
9955  primitive_desc() = default;
9956 
9965  primitive_desc(const desc &adesc, const engine &aengine,
9966  bool allow_empty = false)
9967  : dnnl::primitive_desc(
9968  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9969 
9979  primitive_desc(const desc &adesc, const primitive_attr &attr,
9980  const engine &aengine, bool allow_empty = false)
9981  : dnnl::primitive_desc(
9982  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9983 
9990 
9993 
9996  return query_md(query::weights_md, 0);
9997  }
9998 
10001  return query_md(query::weights_md, 1);
10002  }
10003 
10006  };
10007 
10009  matmul() = default;
10010 
10013  matmul(const primitive_desc &pd) : primitive(pd) {}
10014 };
10015 
10017 
10027 
10031  struct desc {
10033 
10049  desc(prop_kind aprop_kind, algorithm aalgorithm,
10050  const memory::desc &src_desc, const memory::desc &dst_desc) {
10052  dnnl::convert_to_c(aprop_kind),
10053  convert_to_c(aalgorithm), nullptr,
10054  &src_desc.data, &dst_desc.data),
10055  "could not create a resampling forward descriptor");
10056  }
10057 
10069  desc(prop_kind aprop_kind, algorithm aalgorithm,
10070  const std::vector<float> &factors,
10071  const memory::desc &src_desc) {
10072  memory::validate_dims(factors, src_desc.data.ndims - 2);
10074  dnnl::convert_to_c(aprop_kind),
10075  convert_to_c(aalgorithm), &factors[0],
10076  &src_desc.data, nullptr),
10077  "could not create a resampling forward descriptor");
10078  }
10079 
10096  desc(prop_kind aprop_kind, algorithm aalgorithm,
10097  const std::vector<float> &factors, const memory::desc &src_desc,
10098  const memory::desc &dst_desc) {
10099  if (!factors.empty())
10100  memory::validate_dims(factors, src_desc.data.ndims - 2);
10102  dnnl::convert_to_c(aprop_kind),
10103  convert_to_c(aalgorithm), factors.data(),
10104  &src_desc.data, &dst_desc.data),
10105  "could not create a resampling forward descriptor");
10106  }
10107  };
10108 
10112  primitive_desc() = default;
10113 
10124  primitive_desc(const desc &adesc, const engine &aengine,
10125  bool allow_empty = false)
10126  : dnnl::primitive_desc(
10127  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10128 
10140  primitive_desc(const desc &adesc, const primitive_attr &attr,
10141  const engine &aengine, bool allow_empty = false)
10142  : dnnl::primitive_desc(
10143  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10144 
10152  : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
10155 
10157  memory::desc src_desc() const { return base::src_desc(0); }
10158 
10160  memory::desc dst_desc() const { return base::dst_desc(0); }
10161  };
10162 
10164  resampling_forward() = default;
10165 
10170 };
10171 
10175  struct desc {
10177 
10186  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
10187  const memory::desc &diff_dst_desc) {
10189  convert_to_c(aalgorithm), nullptr,
10190  &diff_src_desc.data, &diff_dst_desc.data),
10191  "could not create a resampling backward data descriptor");
10192  }
10193 
10203  desc(algorithm aalgorithm, const std::vector<float> &factors,
10204  const memory::desc &diff_src_desc,
10205  const memory::desc &diff_dst_desc) {
10206  if (!factors.empty())
10207  memory::validate_dims(factors, diff_src_desc.data.ndims - 2);
10209  convert_to_c(aalgorithm), factors.data(),
10210  &diff_src_desc.data, &diff_dst_desc.data),
10211  "could not create a resampling backward data descriptor");
10212  }
10213  };
10214 
10218  primitive_desc() = default;
10219 
10233  primitive_desc(const desc &adesc, const engine &aengine,
10234  const resampling_forward::primitive_desc &hint_fwd_pd,
10235  bool allow_empty = false)
10236  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10237  hint_fwd_pd.get(), allow_empty) {}
10238 
10253  primitive_desc(const desc &adesc, const primitive_attr &attr,
10254  const engine &aengine,
10255  const resampling_forward::primitive_desc &hint_fwd_pd,
10256  bool allow_empty = false)
10257  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10258  hint_fwd_pd.get(), allow_empty) {}
10259 
10267  : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
10269 
10272 
10275  };
10276 
10278  resampling_backward() = default;
10279 
10284 };
10285 
10287 
10295 
10299  struct desc {
10301 
10328  desc(prop_kind aprop_kind, algorithm aalgorithm,
10329  const memory::desc &src_desc, const memory::desc &dst_desc,
10330  const memory::dims &strides, const memory::dims &kernel,
10331  const memory::dims &dilation, const memory::dims &padding_l,
10332  const memory::dims &padding_r) {
10333  memory::validate_dims(strides, src_desc.data.ndims - 2);
10334  memory::validate_dims(kernel, src_desc.data.ndims - 2);
10335  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
10336  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
10337  memory::validate_dims(dilation, src_desc.data.ndims - 2);
10340  dnnl::convert_to_c(aprop_kind),
10341  convert_to_c(aalgorithm), &src_desc.data,
10342  &dst_desc.data, &strides[0], &kernel[0],
10343  &dilation[0], &padding_l[0], &padding_r[0]),
10344  "could not create a descriptor for a pooling forward "
10345  "propagation primitive");
10346  }
10347  };
10348 
10352  primitive_desc() = default;
10353 
10364  primitive_desc(const desc &adesc, const engine &aengine,
10365  bool allow_empty = false)
10366  : dnnl::primitive_desc(
10367  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10368 
10380  primitive_desc(const desc &adesc, const primitive_attr &attr,
10381  const engine &aengine, bool allow_empty = false)
10382  : dnnl::primitive_desc(
10383  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10384 
10393  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling_v2,
10396 
10398  memory::desc src_desc() const { return base::src_desc(0); }
10399 
10401  memory::desc dst_desc() const { return base::dst_desc(0); }
10402 
10405  };
10406 
10408  pooling_v2_forward() = default;
10409 
10415 };
10416 
10420  struct desc {
10422 
10446  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
10447  const memory::desc &diff_dst_desc, const memory::dims &strides,
10448  const memory::dims &kernel, const memory::dims &dilation,
10449  const memory::dims &padding_l, const memory::dims &padding_r) {
10450  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
10451  memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
10452  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
10453  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
10454  memory::validate_dims(dilation, diff_src_desc.data.ndims - 2);
10457  convert_to_c(aalgorithm), &diff_src_desc.data,
10458  &diff_dst_desc.data, &strides[0], &kernel[0],
10459  &dilation[0], &padding_l[0], &padding_r[0]),
10460  "could not create a descriptor for a pooling backward "
10461  "propagation primitive");
10462  }
10463  };
10464 
10469  primitive_desc() = default;
10470 
10484  primitive_desc(const desc &adesc, const engine &aengine,
10485  const pooling_v2_forward::primitive_desc &hint_fwd_pd,
10486  bool allow_empty = false)
10487  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10488  hint_fwd_pd.get(), allow_empty) {}
10489 
10504  primitive_desc(const desc &adesc, const primitive_attr &attr,
10505  const engine &aengine,
10506  const pooling_v2_forward::primitive_desc &hint_fwd_pd,
10507  bool allow_empty = false)
10508  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10509  hint_fwd_pd.get(), allow_empty) {}
10510 
10519  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling_v2,
10521 
10524 
10527 
10530  };
10531 
10533  pooling_v2_backward() = default;
10534 
10540 };
10541 
10543 
10552 
10554 struct prelu_forward : public primitive {
10556  struct desc {
10557  dnnl_prelu_desc_t data;
10558 
10567  desc(prop_kind aprop_kind, const memory::desc &data_desc,
10568  const memory::desc &weight_desc) {
10570  dnnl::convert_to_c(aprop_kind),
10571  &data_desc.data, &weight_desc.data),
10572  "could not create a descriptor for a prelu forward "
10573  "propagation primitive");
10574  }
10575  };
10576 
10580  primitive_desc() = default;
10581 
10592  primitive_desc(const desc &adesc, const engine &aengine,
10593  bool allow_empty = false)
10594  : dnnl::primitive_desc(
10595  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10596 
10608  primitive_desc(const desc &adesc, const primitive_attr &attr,
10609  const engine &aengine, bool allow_empty = false)
10610  : dnnl::primitive_desc(
10611  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10612 
10620  : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
10623 
10625  memory::desc src_desc() const { return base::src_desc(0); }
10626 
10628  memory::desc dst_desc() const { return base::dst_desc(0); }
10629  };
10630 
10632  prelu_forward() = default;
10633 
10638 };
10639 
10641 struct prelu_backward : public primitive {
10643  struct desc {
10644  dnnl_prelu_desc_t data;
10645 
10654  desc(const memory::desc &data_desc, const memory::desc &weight_desc,
10655  const memory::desc &diff_data_desc,
10656  const memory::desc &diff_weights_desc) {
10658  dnnl_prelu_backward_desc_init(&data, &data_desc.data,
10659  &weight_desc.data, &diff_data_desc.data,
10660  &diff_weights_desc.data),
10661  "could not create a descriptor for a prelu backward "
10662  "propagation primitive");
10663  }
10664  };
10665 
10669  primitive_desc() = default;
10670 
10684  primitive_desc(const desc &adesc, const engine &aengine,
10685  const prelu_forward::primitive_desc &hint_fwd_pd,
10686  bool allow_empty = false)
10687  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10688  hint_fwd_pd.get(), allow_empty) {}
10689 
10704  primitive_desc(const desc &adesc, const primitive_attr &attr,
10705  const engine &aengine,
10706  const prelu_forward::primitive_desc &hint_fwd_pd,
10707  bool allow_empty = false)
10708  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10709  hint_fwd_pd.get(), allow_empty) {}
10710 
10718  : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
10720 
10722  memory::desc src_desc() const { return base::src_desc(0); }
10723 
10726 
10729  };
10730 
10732  prelu_backward() = default;
10733 
10738 };
10739 
10741 
10750 
10752 struct reduction : public primitive {
10754  struct desc {
10755  dnnl_reduction_desc_t data;
10756 
10758  desc() = default;
10759 
10777  desc(algorithm aalgorithm, const memory::desc &src_desc,
10778  const memory::desc &dst_desc, float p, float eps) {
10780  dnnl_reduction_desc_init(&data, convert_to_c(aalgorithm),
10781  &src_desc.data, &dst_desc.data, p, eps),
10782  "could not create a reduction descriptor");
10783  }
10784  };
10785 
10789  primitive_desc() = default;
10790 
10799  primitive_desc(const desc &adesc, const engine &aengine,
10800  bool allow_empty = false)
10801  : dnnl::primitive_desc(
10802  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10803 
10813  primitive_desc(const desc &adesc, const primitive_attr &attr,
10814  const engine &aengine, bool allow_empty = false)
10815  : dnnl::primitive_desc(
10816  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10817 
10824 
10826  memory::desc src_desc() const { return base::src_desc(0); }
10827 
10829  memory::desc dst_desc() const { return base::dst_desc(0); }
10830  };
10831 
10833  reduction() = default;
10834 
10837  reduction(const primitive_desc &pd) : primitive(pd) {}
10838 };
10839 
10841 
10843 
10849 
10852 
10854 enum class status {
10869 };
10870 
10872 inline status set_verbose(int level) {
10873  return static_cast<status>(dnnl_set_verbose(level));
10874 }
10875 
10877 inline const version_t *version() {
10878  return dnnl_version();
10879 }
10880 
10882 inline status set_jit_dump(int enable) {
10883  return static_cast<status>(dnnl_set_jit_dump(enable));
10884 }
10885 
10887 inline status set_jit_profiling_flags(unsigned flags) {
10888  return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
10889 }
10890 
10892 inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
10893  return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
10894 }
10895 
10897 enum class cpu_isa {
10920 };
10921 
10924  return static_cast<status>(
10925  dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
10926 }
10927 
10930  return static_cast<cpu_isa>(dnnl_get_effective_cpu_isa());
10931 }
10932 
10934 enum class cpu_isa_hints {
10939 };
10940 
10943  return static_cast<status>(dnnl_set_cpu_isa_hints(
10944  static_cast<dnnl_cpu_isa_hints_t>(isa_hints)));
10945 }
10946 
10949  return static_cast<cpu_isa_hints>(dnnl_get_cpu_isa_hints());
10950 }
10951 
10953 
10959 
10963  int result = 0;
10965  "could not get primitive cache capacity");
10966  return result;
10967 }
10968 
10970 inline void set_primitive_cache_capacity(int capacity) {
10972  "could not set primitive cache capacity");
10973 }
10974 
10976 
10983 
10985 inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
10986  dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
10987  const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
10988  return static_cast<status>(dnnl_sgemm(
10989  transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
10990 }
10991 
10993 inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
10994  dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
10995  dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
10996  float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
10997  return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
10998  K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
10999 }
11000 
11002 inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
11003  dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
11004  dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
11005  float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
11006  return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
11007  K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
11008 }
11009 
11011 
11012 // implementation section
11013 
11016  dnnl_primitive_t result;
11018  "could not create a primitive");
11019  reset(result);
11020 }
11021 
11022 inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
11023 
11024 inline void primitive::execute(const stream &astream,
11025  const std::unordered_map<int, memory> &args) const {
11026  std::vector<dnnl_exec_arg_t> c_args;
11027  c_args.reserve(args.size());
11028  for (const auto &a : args)
11029  c_args.push_back({a.first, a.second.get(true)});
11030 
11031  error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
11032  (int)c_args.size(), c_args.data()),
11033  "could not execute a primitive");
11034 }
11035 
11037 
11038 #undef DNNL_DEFINE_BITMASK_OPS
11039 
11040 } // namespace dnnl
11041 
11043 
// Makes the C++ API reachable as oneapi::dnnl in addition to ::dnnl:
// oneapi::dnnl is declared below as a namespace alias of ::dnnl.
11046 namespace oneapi {
11047 // Note: without this guard, doxygen warns of potentially recursive namespace
11048 #ifndef DOXYGEN_SHOULD_SKIP_THIS
11050 namespace dnnl = ::dnnl;
11051 #endif
11052 } // namespace oneapi
11053 
11055 
11056 #endif /* ONEAPI_DNNL_DNNL_HPP */
algorithm
Kinds of algorithms.
Definition: dnnl.hpp:470
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(dnnl_primitive_attr_t attr, const float scale, const float shift)
Set quantization scale and shift parameters for RNN data tensors.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v2(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_data_type_t *data_type)
Returns the parameters of an accumulation (sum) post-op with a data type parameter.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s2p1(const_dnnl_post_ops_t post_ops, int index, dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask, const float **scales)
Returns the parameters of a depthwise post-op with stride 2.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode)
Sets primitive attributes scratchpad mode.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops)
Returns primitive attributes post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns the quantization scaling factors for RNN weights tensors.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s1p1(dnnl_post_ops_t post_ops, dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type, dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask, const float *scales)
Appends a depthwise post-op convolution with stride 1.
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops)
Destroys post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask, const int32_t *zero_points)
Sets primitive attributes zero points for primitive operations for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops)
Sets primitive attributes post-ops.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops, float scale)
Appends an accumulation (sum) to post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets quantization scaling factors for RNN weights tensors.
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr)
Destroys primitive attributes.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v2(dnnl_post_ops_t post_ops, float scale, dnnl_data_type_t data_type)
Appends an accumulation v2 (sum) to post-ops.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops)
Returns the length of post-ops.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src1_desc)
Appends a binary post-op.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s2p1(dnnl_post_ops_t post_ops, dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type, dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask, const float *scales)
Appends a depthwise post-op convolution with stride 2.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns the quantization scaling factors for RNN projection weights tensors.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_alg_kind_t *alg_kind, float *alpha, float *beta)
Returns the parameters of an elementwise post-op.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops)
Creates empty post-ops sequence.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask, const float *scales)
Sets primitive attributes scaling factors for primitive operations for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode)
Returns the primitive attributes scratchpad mode.
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr)
Clones primitive attributes.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s1p1(const_dnnl_post_ops_t post_ops, int index, dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask, const float **scales)
Returns the parameters of a depthwise post-op with stride 1.
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(const_dnnl_post_ops_t post_ops, int index)
Returns the kind of a post-op entry.
scratchpad_mode
Scratchpad mode.
Definition: dnnl.hpp:401
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets quantization scaling factors for RNN projection weights tensors.
prop_kind
Propagation kind.
Definition: dnnl.hpp:435
dnnl_scratchpad_mode_t
Scratchpad mode.
Definition: dnnl_types.h:2270
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops, float scale, dnnl_alg_kind_t alg_kind, float alpha, float beta)
Appends an elementwise post-op.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_zero_points(const_dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask, const int32_t **zero_points)
Returns count, correspondence zero point mask, and a pointer to a constant int32_t array of zero_points for a given attr and memory argument.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(const_dnnl_post_ops_t post_ops, int index, float *scale)
Returns the parameters of an accumulation (sum) post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind, const dnnl_memory_desc_t **src1_desc)
Returns the parameters of a binary post-op.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(const_dnnl_primitive_attr_t attr, float *scale, float *shift)
Returns the quantization scale and shift parameters for RNN data tensors.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets output scaling factors correspondence mask and values.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scales(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask, const float **scales)
Returns primitive attributes scaling factors correspondence mask and values for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr)
Creates an empty (default) primitive attributes with all the parameters set to their default values.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns primitive attributes output scaling factors correspondence mask and values.
@ eltwise_mish
Elementwise: mish.
@ resampling_linear
Linear (Bilinear, Trilinear) resampling method.
@ binary_mul
Binary mul.
@ resampling_nearest
Nearest Neighbor resampling method.
@ eltwise_elu_use_dst_for_bwd
Elementwise: exponential linear unit (ELU) (dst for backward)
@ eltwise_tanh_use_dst_for_bwd
Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
@ reduction_norm_lp_power_p_sum
Reduction using norm_lp_power_p_sum operation.
@ eltwise_linear
Elementwise: linear.
@ eltwise_clip_v2
Elementwise: clip version 2.
@ eltwise_soft_relu
Elementwise: soft_relu.
@ vanilla_gru
GRU cell.
@ eltwise_logistic
Elementwise: logistic.
@ binary_div
Binary div.
@ eltwise_clip
Elementwise: clip.
@ binary_ge
Binary greater than or equal.
@ eltwise_abs
Elementwise: abs.
@ eltwise_pow
Elementwise: pow.
@ eltwise_tanh
Elementwise: hyperbolic tangent non-linearity (tanh)
@ eltwise_logistic_use_dst_for_bwd
Elementwise: logistic (dst for backward)
@ eltwise_bounded_relu
Elementwise: bounded_relu.
@ reduction_norm_lp_power_p_max
Reduction using norm_lp_power_p_max operation.
@ reduction_max
Reduction using max operation.
@ eltwise_clip_v2_use_dst_for_bwd
Elementwise: clip version 2 (dst for backward)
@ eltwise_square
Elementwise: square.
@ binary_max
Binary max.
@ convolution_direct
Direct convolution.
@ eltwise_exp
Elementwise: exponent.
@ binary_gt
Binary greater than.
@ reduction_norm_lp_max
Reduction using norm_lp_max operation.
@ eltwise_elu
Elementwise: exponential linear unit (ELU)
@ convolution_winograd
Winograd convolution.
@ vanilla_lstm
LSTM cell.
@ deconvolution_direct
Direct deconvolution.
@ pooling_avg
Average pooling exclude padding, alias for dnnl::algorithm::pooling_avg_exclude_padding.
@ lbr_gru
GRU cell with linear before reset.
@ binary_eq
Binary equal.
@ pooling_avg_exclude_padding
Average pooling exclude padding.
@ eltwise_gelu
Elementwise: gelu, alias for dnnl::algorithm::eltwise_gelu_tanh.
@ eltwise_sqrt
Elementwise: square root.
@ pooling_max
Max pooling.
@ reduction_min
Reduction using min operation.
@ eltwise_gelu_erf
Elementwise: erf-based gelu.
@ eltwise_swish
Elementwise: swish (\(x \cdot \mathrm{sigmoid}(\alpha \cdot x)\))
@ binary_sub
Binary sub.
@ binary_ne
Binary not equal.
@ lrn_within_channel
LRN within a single channel.
@ binary_le
Binary less than or equal.
@ eltwise_hardswish
Elementwise: hardswish.
@ reduction_mul
Reduction using mul operation.
@ vanilla_rnn
RNN cell.
@ binary_add
Binary add.
@ lrn_across_channels
Local response normalization (LRN) across multiple channels.
@ eltwise_relu
Elementwise: rectified linear unit (ReLU)
@ eltwise_gelu_tanh
Elementwise: tanh-based gelu.
@ eltwise_relu_use_dst_for_bwd
Elementwise: rectified linear unit (ReLU) (dst for backward)
@ eltwise_logsigmoid
Elementwise: logsigmoid.
@ convolution_auto
Convolution algorithm that is chosen to be either direct or Winograd automatically.
@ binary_min
Binary min.
@ eltwise_exp_use_dst_for_bwd
Elementwise: exponent (dst for backward)
@ eltwise_round
Elementwise: round.
@ eltwise_sqrt_use_dst_for_bwd
Elementwise: square root (dst for backward)
@ pooling_avg_include_padding
Average pooling include padding.
@ reduction_norm_lp_sum
Reduction using norm_lp_sum operation.
@ reduction_mean
Reduction using mean operation.
@ deconvolution_winograd
Winograd deconvolution.
@ eltwise_log
Elementwise: natural logarithm.
@ undef
Undefined algorithm.
@ binary_lt
Binary less than.
@ reduction_sum
Reduction using sum operation.
@ library
The library manages the scratchpad allocation according to the policy specified by the DNNL_ENABLE_CONCURRENT_EXEC build option (default).
@ user
The user manages the scratchpad allocation by querying and providing the scratchpad memory to primitives.
@ backward
Backward propagation (with respect to all parameters).
@ backward_weights
Backward weights propagation.
@ forward_training
Forward data propagation (training mode).
@ forward_inference
Forward data propagation (inference mode).
@ forward_scoring
Forward data propagation, alias for dnnl::prop_kind::forward_inference.
@ forward
Forward data propagation, alias for dnnl::prop_kind::forward_training.
@ backward_data
Backward data propagation.
@ backward_bias
Backward bias propagation.
@ undef
Undefined propagation kind.
@ dnnl_scratchpad_mode_user
The user manages the scratchpad allocation by querying and providing the scratchpad memory to primitives.
Definition: dnnl_types.h:2292
@ dnnl_scratchpad_mode_library
The library manages the scratchpad allocation according to the policy specified by the DNNL_ENABLE_CONCURRENT_EXEC build option (default).
Definition: dnnl_types.h:2287
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a descriptor for a batch normalization backward propagation primitive.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a descriptor for a batch normalization forward propagation primitive.
dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc, const dnnl_memory_desc_t *src1_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a descriptor for a binary primitive.
dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
Definition: dnnl.hpp:10993
status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
Definition: dnnl.hpp:11002
dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
Definition: dnnl.hpp:10985
dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(dnnl_primitive_desc_t *concat_primitive_desc, const dnnl_memory_desc_t *dst_desc, int n, int concat_dimension, const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates a primitive descriptor for an out-of-place concatenation primitive.
dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes a descriptor for eltwise forward propagation primitive.
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes a descriptor for eltwise backward propagation primitive.
dnnl_engine_kind_t
Kinds of engines.
Definition: dnnl_types.h:2216
dnnl_status_t DNNL_API dnnl_engine_get_kind(dnnl_engine_t engine, dnnl_engine_kind_t *kind)
Returns the kind of an engine.
dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine)
Destroys an engine.
dnnl_status_t DNNL_API dnnl_engine_create(dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index)
Creates an engine.
size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind)
Returns the number of engines of a particular kind.
dnnl_engine_kind_t convert_to_c(engine::kind akind)
Converts engine kind enum value from C++ API to C API type.
Definition: dnnl.hpp:977
@ dnnl_gpu
GPU engine.
Definition: dnnl_types.h:2222
@ dnnl_cpu
CPU engine.
Definition: dnnl_types.h:2220
@ dnnl_any_engine
An unspecified engine.
Definition: dnnl_types.h:2218
dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init(dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes descriptor for inner product forward propagation.
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes descriptor for inner product weights gradient primitive.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes descriptor for inner product backward propagation.
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a descriptor for a layer normalization backward propagation primitive.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a descriptor for layer normalization forward propagation primitive.
dnnl_status_t DNNL_API dnnl_logsoftmax_forward_desc_init(dnnl_logsoftmax_desc_t *logsoftmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int logsoftmax_axis)
Initializes a descriptor for logsoftmax forward propagation primitive.
dnnl_status_t DNNL_API dnnl_logsoftmax_backward_desc_init(dnnl_logsoftmax_desc_t *logsoftmax_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, int logsoftmax_axis)
Initializes a descriptor for logsoftmax backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes a descriptor for LRN backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes a descriptor for LRN forward propagation primitive.
dnnl_status_t DNNL_API dnnl_matmul_desc_init(dnnl_matmul_desc_t *matmul_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a matrix multiplication descriptor.
dnnl_data_type_t
Data type specification.
Definition: dnnl_types.h:62
size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type)
Returns the size of data type.
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(dnnl_memory_desc_t *memory_desc, const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims, const dnnl_dims_t offsets)
Initializes a memory descriptor for a region inside an area described by an existing memory descriptor.
dnnl_format_tag_t
Memory format tag specification.
Definition: dnnl_types.h:164
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, const int *permutation)
Initializes a memory descriptor by permuting axes in an existing one.
dnnl_status_t DNNL_API dnnl_memory_unmap_data(const_dnnl_memory_t memory, void *mapped_ptr)
Unmaps a memory object and writes back any changes made to the previously mapped memory buffer.
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory, const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine, void *handle)
Creates a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_engine(const_dnnl_memory_t memory, dnnl_engine_t *engine)
Returns the engine of a memory object.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, int ndims, const dnnl_dims_t dims)
Initializes a memory descriptor by reshaping an existing one.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc)
Returns the memory descriptor for a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(const_dnnl_memory_t memory, void **handle)
Returns memory object's data handle.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(dnnl_memory_t memory, void *handle, dnnl_stream_t stream)
Sets the underlying memory buffer.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, const dnnl_dims_t strides)
Initializes a memory descriptor using dimensions and strides.
int64_t dnnl_dim_t
A type to describe tensor dimension.
Definition: dnnl_types.h:1399
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory)
Destroys a memory object.
int DNNL_API dnnl_memory_desc_equal(const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs)
Compares two memory descriptors.
#define DNNL_MAX_NDIMS
Maximum number of dimensions a tensor can have.
Definition: dnnl_types.h:1367
dnnl_status_t DNNL_API dnnl_memory_map_data(const_dnnl_memory_t memory, void **mapped_ptr)
Maps a memory object and returns a host-side pointer to a memory buffer with a copy of its contents.
size_t DNNL_API dnnl_memory_desc_get_size(const dnnl_memory_desc_t *memory_desc)
Returns the size of a memory descriptor.
#define DNNL_MEMORY_ALLOCATE
Special pointer value that indicates that the library needs to allocate an underlying buffer for a memory object.
Definition: dnnl_types.h:1576
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, dnnl_format_tag_t tag)
Initializes a memory descriptor using dimensions and memory format tag.
@ dnnl_f16
16-bit/half-precision floating point.
Definition: dnnl_types.h:66
@ dnnl_bf16
non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point.
Definition: dnnl_types.h:68
@ dnnl_f32
32-bit/single-precision floating point.
Definition: dnnl_types.h:70
@ dnnl_data_type_undef
Undefined data type, used for empty memory descriptors.
Definition: dnnl_types.h:64
@ dnnl_s8
8-bit signed integer.
Definition: dnnl_types.h:74
@ dnnl_s32
32-bit signed integer.
Definition: dnnl_types.h:72
@ dnnl_u8
8-bit unsigned integer.
Definition: dnnl_types.h:76
@ dnnl_abcdefhg
permuted 8D tensor
Definition: dnnl_types.h:216
@ dnnl_aBCdef2b4c2b
6D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:362
@ dnnl_abcdefghi
plain 9D tensor
Definition: dnnl_types.h:186
@ dnnl_acdeb
permuted 5D tensor
Definition: dnnl_types.h:199
@ dnnl_abcdefgh
plain 8D tensor
Definition: dnnl_types.h:185
@ dnnl_abcdefghikj
permuted 11D tensor
Definition: dnnl_types.h:219
@ dnnl_ab
plain 2D tensor
Definition: dnnl_types.h:178
@ dnnl_ABcd8b8a
4D tensor blocked by 1st and 2nd dimension with block size 8
Definition: dnnl_types.h:288
@ dnnl_cdba
permuted 4D tensor
Definition: dnnl_types.h:208
@ dnnl_abcdefghijkl
plain 12D tensor
Definition: dnnl_types.h:189
@ dnnl_aBcdef4b
6D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:364
@ dnnl_abcdegf
permuted 7D tensor
Definition: dnnl_types.h:215
@ dnnl_abcdfe
permuted 6D tensor
Definition: dnnl_types.h:214
@ dnnl_aBcd4b
4D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:263
@ dnnl_nCdhw16c
5D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcde16b
Definition: dnnl_types.h:707
@ dnnl_abcde
plain 5D tensor
Definition: dnnl_types.h:182
@ dnnl_decab
permuted 5D tensor
Definition: dnnl_types.h:211
@ dnnl_bca
permuted 3D tensor
Definition: dnnl_types.h:204
@ dnnl_aBcde4b
5D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:315
@ dnnl_aBc16b
3D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:229
@ dnnl_aBcdef16b
6D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:354
@ dnnl_aBCde2b4c2b
5D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:352
@ dnnl_aBc4b
3D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:235
@ dnnl_abcdefghijk
plain 11D tensor
Definition: dnnl_types.h:188
@ dnnl_bacde
permuted 5D tensor
Definition: dnnl_types.h:203
@ dnnl_aBcd16b
4D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:255
@ dnnl_cba
permuted 3D tensor
Definition: dnnl_types.h:207
@ dnnl_ba
permuted 2D tensor
Definition: dnnl_types.h:200
@ dnnl_ABcde2b8a4b
5D tensor blocked by 1st dimension with block size 8
Definition: dnnl_types.h:304
@ dnnl_abcd
plain 4D tensor
Definition: dnnl_types.h:180
@ dnnl_format_tag_undef
Undefined memory format tag.
Definition: dnnl_types.h:166
@ dnnl_nCdhw4c
5D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcde4b
Definition: dnnl_types.h:710
@ dnnl_defcab
permuted 6D tensor
Definition: dnnl_types.h:212
@ dnnl_abcdef
plain 6D tensor
Definition: dnnl_types.h:183
@ dnnl_nChw8c
4D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcd8b
Definition: dnnl_types.h:725
@ dnnl_a
plain 1D tensor
Definition: dnnl_types.h:177
@ dnnl_nChw4c
4D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcd4b
Definition: dnnl_types.h:722
@ dnnl_acbdef
permuted 6D tensor
Definition: dnnl_types.h:197
@ dnnl_acdb
permuted 4D tensor
Definition: dnnl_types.h:198
@ dnnl_aBcd8b
4D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:282
@ dnnl_aBc8b
3D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:245
@ dnnl_nCw4c
3D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBc4b
Definition: dnnl_types.h:734
@ dnnl_abcdefg
plain 7D tensor
Definition: dnnl_types.h:184
@ dnnl_aBcde8b
5D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:330
@ dnnl_nChw16c
4D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcd16b
Definition: dnnl_types.h:719
@ dnnl_abdfce
permuted 6D tensor
Definition: dnnl_types.h:424
@ dnnl_abdec
permuted 5D tensor
Definition: dnnl_types.h:194
@ dnnl_bacd
permuted 4D tensor
Definition: dnnl_types.h:202
@ dnnl_nCdhw8c
5D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcde8b
Definition: dnnl_types.h:713
@ dnnl_aBcde32b
5D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:313
@ dnnl_abced
permuted 5D tensor
Definition: dnnl_types.h:213
@ dnnl_bcda
permuted 4D tensor
Definition: dnnl_types.h:205
@ dnnl_acbde
permuted 5D tensor
Definition: dnnl_types.h:196
@ dnnl_aBCd2b4c2b
4D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:300
@ dnnl_abcdefgih
permuted 9D tensor
Definition: dnnl_types.h:217
@ dnnl_bcdea
permuted 5D tensor
Definition: dnnl_types.h:206
@ dnnl_abdefc
permuted 6D tensor
Definition: dnnl_types.h:425
@ dnnl_aBcde16b
5D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:306
@ dnnl_nCw8c
3D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBc8b
Definition: dnnl_types.h:737
@ dnnl_abdc
permuted 4D tensor
Definition: dnnl_types.h:193
@ dnnl_ABcde4b16a4b
5D tensor blocked by 1st dimension with block size 16
Definition: dnnl_types.h:302
@ dnnl_aBcd32b
4D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:261
@ dnnl_abcdefghijlk
permuted 12D tensor
Definition: dnnl_types.h:220
@ dnnl_format_tag_last
Just a sentinel, not real memory format tag.
Definition: dnnl_types.h:568
@ dnnl_abc
plain 3D tensor
Definition: dnnl_types.h:179
@ dnnl_bac
permuted 3D tensor
Definition: dnnl_types.h:201
@ dnnl_dcab
permuted 4D tensor
Definition: dnnl_types.h:209
@ dnnl_cdeba
permuted 5D tensor
Definition: dnnl_types.h:210
@ dnnl_acb
permuted 3D tensor
Definition: dnnl_types.h:195
@ dnnl_aBc32b
3D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:233
@ dnnl_abcdefghji
permuted 10D tensor
Definition: dnnl_types.h:218
@ dnnl_nCw16c
3D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBc16b
Definition: dnnl_types.h:731
@ dnnl_aBCdef2c8b4c
6D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:359
@ dnnl_abcdefghij
plain 10D tensor
Definition: dnnl_types.h:187
@ dnnl_format_tag_any
Undefined memory format tag. The primitive selects a format automatically.
Definition: dnnl_types.h:169
@ dnnl_blocked
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: dnnl_types.h:89
@ dnnl_format_kind_wino
Weights format used in 8bit Winograd convolution.
Definition: dnnl_types.h:91
@ dnnl_format_kind_any
Unspecified format kind.
Definition: dnnl_types.h:85
@ dnnl_format_kind_undef
Undefined memory format kind, used for empty memory descriptors.
Definition: dnnl_types.h:82
@ dnnl_format_kind_rnn_packed
Packed weights format used in RNN.
Definition: dnnl_types.h:93
dnnl_status_t DNNL_API dnnl_pooling_v2_backward_desc_init(dnnl_pooling_v2_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t dilation, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling v2 (pooling with dilation support) backward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_v2_forward_desc_init(dnnl_pooling_v2_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t dilation, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling v2 (pooling with dilation support) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling forward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling backward propagation primitive.
dnnl_status_t DNNL_API dnnl_prelu_forward_desc_init(dnnl_prelu_desc_t *prelu_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *weights_desc)
Initializes a descriptor for PReLU (leaky ReLU with trainable alpha parameter) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_prelu_backward_desc_init(dnnl_prelu_desc_t *prelu_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *diff_weights_desc)
Initializes a descriptor for PReLU (leaky ReLU with trainable alpha parameter) backward propagation primitive.
void set_primitive_cache_capacity(int capacity)
Sets a number of primitives that can be held in the primitive cache at a time.
Definition: dnnl.hpp:10970
dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity)
Sets a number of primitives that can be held in the primitive cache at a time.
dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity)
Returns the number of primitives that can be held in the primitive cache at the same time.
int get_primitive_cache_capacity()
Returns the number of primitives that can be held in the primitive cache at the same time.
Definition: dnnl.hpp:10962
dnnl_status_t DNNL_API dnnl_primitive_desc_query(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index, void *result)
Queries a primitive descriptor for various pieces of information.
#define DNNL_ARG_DST_ITER
A special mnemonic for RNN output recurrent hidden state vector.
Definition: dnnl_types.h:2387
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(dnnl_primitive_desc_iterator_t iterator)
Destroys a primitive descriptor iterator.
#define DNNL_ARG_WEIGHTS_LAYER
A special mnemonic for RNN weights applied to the layer input.
Definition: dnnl_types.h:2405
#define DNNL_ARG_DIFF_BIAS
Gradient (diff) of the bias tensor argument.
Definition: dnnl_types.h:2512
#define DNNL_ARG_DIFF_SRC_ITER_C
A special mnemonic for gradient (diff) of RNN input recurrent cell state vector.
Definition: dnnl_types.h:2458
#define DNNL_ARG_DIFF_SRC_LAYER
A special mnemonic for gradient (diff) of RNN input vector.
Definition: dnnl_types.h:2446
#define DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE
A special mnemonic for diff of RNN weights applied to the peephole weights.
Definition: dnnl_types.h:2503
#define DNNL_ARG_WEIGHTS_PROJECTION
A special mnemonic for RNN weights applied to the projection weights.
Definition: dnnl_types.h:2423
dnnl_normalization_flags_t
Flags for normalization primitives.
Definition: dnnl_types.h:1307
#define DNNL_ARG_DIFF_WEIGHTS_PROJECTION
A special mnemonic for diff of RNN weights applied to the projection weights.
Definition: dnnl_types.h:2509
const dnnl_memory_desc_t DNNL_API * dnnl_primitive_desc_query_md(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index)
Queries primitive descriptor for a memory descriptor.
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(const_dnnl_primitive_desc_t primitive_desc, const_dnnl_primitive_attr_t *attr)
Returns a constant reference to the attributes of a primitive descriptor.
#define DNNL_ARG_DIFF_WEIGHTS_ITER
A special mnemonic for diff of RNN weights applied to the recurrent input.
Definition: dnnl_types.h:2497
#define DNNL_ARG_DIFF_SRC_ITER
A special mnemonic for gradient (diff) of RNN input recurrent hidden state vector.
Definition: dnnl_types.h:2452
#define DNNL_ARG_DIFF_DST_ITER_C
A special mnemonic for gradient (diff) of RNN output recurrent cell state vector.
Definition: dnnl_types.h:2479
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args)
Executes a primitive.
#define DNNL_ARG_WEIGHTS_ITER
A special mnemonic for RNN weights applied to the recurrent input.
Definition: dnnl_types.h:2411
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(dnnl_primitive_desc_iterator_t iterator)
Advances the primitive descriptor iterator to point to the next available implementation.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(dnnl_primitive_desc_t primitive_desc)
Destroys a primitive descriptor.
const void * const_dnnl_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: dnnl_types.h:1588
const_dnnl_primitive_desc_t get_primitive_desc() const
Returns the C API primitive descriptor of the underlying C API primitive.
Definition: dnnl.hpp:368
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(const_dnnl_primitive_t primitive, const_dnnl_primitive_desc_t *primitive_desc)
Retrieves a constant reference to the primitive descriptor of a given primitive.
#define DNNL_ARG_DST_ITER_C
A special mnemonic for LSTM output recurrent cell state vector.
Definition: dnnl_types.h:2393
#define DNNL_ARG_SRC_ITER_C
A special mnemonic for RNN input recurrent cell state vector.
Definition: dnnl_types.h:2370
query
Primitive descriptor query specification.
Definition: dnnl.hpp:761
#define DNNL_ARG_FROM
A special mnemonic for reorder source argument.
Definition: dnnl_types.h:2358
dnnl_alg_kind_t
Kinds of algorithms.
Definition: dnnl_types.h:1157
dnnl_primitive_kind_t
Kinds of primitives.
Definition: dnnl_types.h:1103
dnnl_query_t
Primitive descriptor query specification.
Definition: dnnl_types.h:2583
dnnl_primitive_kind_t convert_to_c(primitive::kind akind)
Converts primitive kind enum value from C++ API to C API type.
Definition: dnnl.hpp:364
struct dnnl_primitive_desc * dnnl_primitive_desc_t
A primitive descriptor handle.
Definition: dnnl_types.h:2259
#define DNNL_ARG_WEIGHTS_PEEPHOLE
A special mnemonic for RNN weights applied to the peephole weights.
Definition: dnnl_types.h:2417
kind get_kind() const
Returns the kind of the primitive.
Definition: dnnl.hpp:375
#define DNNL_ARG_SRC_LAYER
A special mnemonic for RNN input vector.
Definition: dnnl_types.h:2355
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive)
Destroys a primitive.
#define DNNL_ARG_DIFF_WEIGHTS_LAYER
A special mnemonic for diff of RNN weights applied to the layer input.
Definition: dnnl_types.h:2491
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine, const_dnnl_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator.
#define DNNL_ARG_DST_LAYER
A special mnemonic for RNN output vector. An alias for DNNL_ARG_DST_0.
Definition: dnnl_types.h:2381
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc)
Creates a primitive.
#define DNNL_ARG_BIAS
Bias tensor argument.
Definition: dnnl_types.h:2426
normalization_flags
Flags for normalization primitives.
Definition: dnnl.hpp:631
#define DNNL_ARG_DIFF_DST_ITER
A special mnemonic for gradient (diff) of RNN output recurrent hidden state vector.
Definition: dnnl_types.h:2473
dnnl_prop_kind_t
Kinds of propagation.
Definition: dnnl_types.h:1076
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(dnnl_primitive_desc_t *primitive_desc, const_dnnl_primitive_desc_t existing_primitive_desc)
Clones a primitive descriptor.
#define DNNL_ARG_SRC_ITER
A special mnemonic for RNN input recurrent hidden state vector.
Definition: dnnl_types.h:2364
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(const_dnnl_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor from a primitive descriptor iterator.
#define DNNL_ARG_TO
A special mnemonic for reorder destination argument.
Definition: dnnl_types.h:2379
#define DNNL_ARG_DIFF_DST_LAYER
A special mnemonic for gradient (diff) of RNN output vector.
Definition: dnnl_types.h:2467
@ dnnl_fuse_norm_relu
Fuse with ReLU.
Definition: dnnl_types.h:1355
@ dnnl_normalization_flags_none
Use no normalization flags.
Definition: dnnl_types.h:1316
@ dnnl_use_scaleshift
Use scale and shift parameters.
Definition: dnnl_types.h:1342
@ dnnl_use_global_stats
Use global statistics.
Definition: dnnl_types.h:1329
@ batch_normalization_d
batch normalization descriptor
@ weights_md
weights memory desc
@ memory_consumption_s64
memory required for scratchpad (bytes)
@ shuffle_d
shuffle descriptor
@ deconvolution_d
deconvolution descriptor
@ impl_info_str
implementation name
@ diff_weights_md
weights gradient (diff) memory desc
@ workspace_md
workspace memory desc
@ reduction_d
reduction descriptor
@ eltwise_d
eltwise descriptor
@ matmul_d
matmul descriptor
@ rnn_d
rnn descriptor
@ softmax_d
softmax descriptor
@ num_of_outputs_s32
number of outputs expected
@ primitive_kind
primitive kind
@ dst_md
destination memory desc
@ scratchpad_engine
scratchpad engine
@ reorder_src_engine
reorder source engine
@ op_d
operation descriptor
@ layer_normalization_d
layer normalization descriptor
@ logsoftmax_d
logsoftmax descriptor
@ pooling_d
pooling descriptor
@ num_of_inputs_s32
number of inputs expected
@ diff_src_md
source gradient (diff) memory desc
@ src_md
source memory desc
@ scratchpad_md
scratchpad memory desc
@ reorder_dst_engine
reorder destination engine
@ engine
execution engine
@ convolution_d
convolution descriptor
@ time_estimate_f64
runtime estimation (seconds), unimplemented
@ binary_d
binary descriptor
@ diff_dst_md
destination gradient (diff) memory desc
@ exec_arg_md
memory desc of an execute argument
@ inner_product_d
inner product descriptor
@ lrn_d
lrn descriptor
@ undef
no query
@ resampling_d
resampling descriptor
@ dnnl_pooling_avg_exclude_padding
Average pooling exclude padding.
Definition: dnnl_types.h:1237
@ dnnl_eltwise_clip
Eltwise: clip.
Definition: dnnl_types.h:1203
@ dnnl_eltwise_tanh_use_dst_for_bwd
Eltwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
Definition: dnnl_types.h:1221
@ dnnl_eltwise_logsigmoid
Eltwise: logsigmoid.
Definition: dnnl_types.h:1213
@ dnnl_pooling_avg
Average pooling (alias for dnnl_pooling_avg_exclude_padding)
Definition: dnnl_types.h:1239
@ dnnl_eltwise_gelu_tanh
Eltwise: tanh-based gelu.
Definition: dnnl_types.h:1195
@ dnnl_resampling_linear
Linear Resampling Method.
Definition: dnnl_types.h:1285
@ dnnl_eltwise_sqrt
Eltwise: square root.
Definition: dnnl_types.h:1180
@ dnnl_binary_min
Binary min.
Definition: dnnl_types.h:1265
@ dnnl_reduction_norm_lp_sum
Reduction using lp norm.
Definition: dnnl_types.h:1299
@ dnnl_eltwise_abs
Eltwise: abs.
Definition: dnnl_types.h:1178
@ dnnl_reduction_norm_lp_power_p_max
Reduction using lp norm without final pth-root.
Definition: dnnl_types.h:1301
@ dnnl_reduction_min
Reduction using min.
Definition: dnnl_types.h:1289
@ dnnl_binary_ne
Binary not equal.
Definition: dnnl_types.h:1281
@ dnnl_eltwise_sqrt_use_dst_for_bwd
Eltwise: square root (dst for backward)
Definition: dnnl_types.h:1225
@ dnnl_eltwise_exp
Eltwise: exponent.
Definition: dnnl_types.h:1190
@ dnnl_eltwise_square
Eltwise: square.
Definition: dnnl_types.h:1176
@ dnnl_eltwise_gelu
Eltwise: tanh-based gelu (alias for dnnl_eltwise_gelu_tanh)
Definition: dnnl_types.h:1197
@ dnnl_convolution_winograd
Winograd convolution.
Definition: dnnl_types.h:1162
@ dnnl_eltwise_clip_v2_use_dst_for_bwd
Eltwise: clip version 2 (dst for backward)
Definition: dnnl_types.h:1231
@ dnnl_lrn_across_channels
Local response normalization (LRN) across multiple channels.
Definition: dnnl_types.h:1241
@ dnnl_binary_sub
Binary sub.
Definition: dnnl_types.h:1269
@ dnnl_deconvolution_direct
Direct deconvolution.
Definition: dnnl_types.h:1166
@ dnnl_binary_eq
Binary equal.
Definition: dnnl_types.h:1279
@ dnnl_eltwise_relu
Eltwise: ReLU.
Definition: dnnl_types.h:1170
@ dnnl_convolution_auto
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition: dnnl_types.h:1164
@ dnnl_eltwise_swish
Eltwise: swish.
Definition: dnnl_types.h:1199
@ dnnl_vanilla_rnn
RNN cell.
Definition: dnnl_types.h:1245
@ dnnl_eltwise_gelu_erf
Eltwise: erf-based gelu.
Definition: dnnl_types.h:1209
@ dnnl_vanilla_lstm
LSTM cell.
Definition: dnnl_types.h:1247
@ dnnl_eltwise_elu
Eltwise: exponential linear unit (elu)
Definition: dnnl_types.h:1174
@ dnnl_vanilla_gru
GRU cell.
Definition: dnnl_types.h:1249
@ dnnl_lbr_gru
GRU cell with linear before reset.
Definition: dnnl_types.h:1257
@ dnnl_eltwise_tanh
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: dnnl_types.h:1172
@ dnnl_convolution_direct
Direct convolution.
Definition: dnnl_types.h:1160
@ dnnl_eltwise_soft_relu
Eltwise: soft_relu.
Definition: dnnl_types.h:1186
@ dnnl_binary_ge
Binary greater or equal.
Definition: dnnl_types.h:1271
@ dnnl_eltwise_log
Eltwise: natural logarithm.
Definition: dnnl_types.h:1201
@ dnnl_eltwise_clip_v2
Eltwise: clip version 2.
Definition: dnnl_types.h:1205
@ dnnl_lrn_within_channel
LRN within a single channel.
Definition: dnnl_types.h:1243
@ dnnl_eltwise_elu_use_dst_for_bwd
Eltwise: exponential linear unit (elu) (dst for backward)
Definition: dnnl_types.h:1223
@ dnnl_deconvolution_winograd
Winograd deconvolution.
Definition: dnnl_types.h:1168
@ dnnl_eltwise_hardswish
Eltwise: hardswish.
Definition: dnnl_types.h:1217
@ dnnl_reduction_mul
Reduction using mul.
Definition: dnnl_types.h:1293
@ dnnl_eltwise_pow
Eltwise: pow.
Definition: dnnl_types.h:1207
@ dnnl_eltwise_relu_use_dst_for_bwd
Eltwise: ReLU (dst for backward)
Definition: dnnl_types.h:1219
@ dnnl_binary_gt
Binary greater than.
Definition: dnnl_types.h:1273
@ dnnl_reduction_max
Reduction using max.
Definition: dnnl_types.h:1287
@ dnnl_eltwise_logistic
Eltwise: logistic.
Definition: dnnl_types.h:1188
@ dnnl_binary_lt
Binary less than.
Definition: dnnl_types.h:1277
@ dnnl_pooling_avg_include_padding
Average pooling include padding.
Definition: dnnl_types.h:1235
@ dnnl_reduction_mean
Reduction using mean.
Definition: dnnl_types.h:1295
@ dnnl_binary_le
Binary less or equal.
Definition: dnnl_types.h:1275
@ dnnl_pooling_max
Max pooling.
Definition: dnnl_types.h:1233
@ dnnl_eltwise_logistic_use_dst_for_bwd
Eltwise: logistic (dst for backward)
Definition: dnnl_types.h:1227
@ dnnl_binary_add
Binary add.
Definition: dnnl_types.h:1259
@ dnnl_binary_div
Binary div.
Definition: dnnl_types.h:1267
@ dnnl_reduction_norm_lp_max
Reduction using lp norm.
Definition: dnnl_types.h:1297
@ dnnl_reduction_norm_lp_power_p_sum
Reduction using lp norm without final pth-root.
Definition: dnnl_types.h:1303
@ dnnl_eltwise_round
Eltwise: round.
Definition: dnnl_types.h:1211
@ dnnl_binary_mul
Binary mul.
Definition: dnnl_types.h:1261
@ dnnl_eltwise_mish
Eltwise: mish.
Definition: dnnl_types.h:1215
@ dnnl_reduction_sum
Reduction using sum.
Definition: dnnl_types.h:1291
@ dnnl_eltwise_exp_use_dst_for_bwd
Eltwise: exp (dst for backward)
Definition: dnnl_types.h:1229
@ dnnl_eltwise_bounded_relu
Eltwise: bounded_relu.
Definition: dnnl_types.h:1184
@ dnnl_eltwise_linear
Eltwise: linear.
Definition: dnnl_types.h:1182
@ dnnl_resampling_nearest
Nearest Neighbor Resampling Method.
Definition: dnnl_types.h:1283
@ dnnl_binary_max
Binary max.
Definition: dnnl_types.h:1263
@ dnnl_binary
A binary primitive.
Definition: dnnl_types.h:1137
@ dnnl_concat
A (out-of-place) concat primitive.
Definition: dnnl_types.h:1111
@ dnnl_reorder
A reorder primitive.
Definition: dnnl_types.h:1107
@ dnnl_convolution
A convolution primitive.
Definition: dnnl_types.h:1115
@ dnnl_inner_product
An inner product primitive.
Definition: dnnl_types.h:1131
@ dnnl_resampling
A resampling primitive.
Definition: dnnl_types.h:1143
@ dnnl_batch_normalization
A batch normalization primitive.
Definition: dnnl_types.h:1127
@ dnnl_undefined_primitive
Undefined primitive.
Definition: dnnl_types.h:1105
@ dnnl_sum
A sum primitive.
Definition: dnnl_types.h:1113
@ dnnl_pooling_v2
A pooling version 2 primitive (pooling with dilation support).
Definition: dnnl_types.h:1145
@ dnnl_layer_normalization
A layer normalization primitive.
Definition: dnnl_types.h:1129
@ dnnl_prelu
A PReLU primitive.
Definition: dnnl_types.h:1149
@ dnnl_eltwise
An element-wise primitive.
Definition: dnnl_types.h:1119
@ dnnl_matmul
A matrix multiplication primitive.
Definition: dnnl_types.h:1141
@ dnnl_shuffle
A shuffle primitive.
Definition: dnnl_types.h:1109
@ dnnl_logsoftmax
A logsoftmax primitive.
Definition: dnnl_types.h:1139
@ dnnl_pooling
A pooling primitive.
Definition: dnnl_types.h:1123
@ dnnl_deconvolution
A deconvolution primitive.
Definition: dnnl_types.h:1117
@ dnnl_softmax
A softmax primitive.
Definition: dnnl_types.h:1121
@ dnnl_rnn
An RNN primitive.
Definition: dnnl_types.h:1133
@ dnnl_reduction
A reduction primitive.
Definition: dnnl_types.h:1147
@ dnnl_lrn
An LRN primitive.
Definition: dnnl_types.h:1125
@ dnnl_query_resampling_d
resampling descriptor
Definition: dnnl_types.h:2626
@ dnnl_query_num_of_outputs_s32
number of outputs expected
Definition: dnnl_types.h:2590
@ dnnl_query_convolution_d
convolution descriptor
Definition: dnnl_types.h:2611
@ dnnl_query_weights_md
weights memory desc
Definition: dnnl_types.h:2635
@ dnnl_query_src_md
source memory desc
Definition: dnnl_types.h:2633
@ dnnl_query_softmax_d
softmax descriptor
Definition: dnnl_types.h:2615
@ dnnl_query_binary_d
binary descriptor
Definition: dnnl_types.h:2623
@ dnnl_query_workspace_md
workspace memory desc
Definition: dnnl_types.h:2639
@ dnnl_query_matmul_d
matrix multiplication (matmul) descriptor
Definition: dnnl_types.h:2625
@ dnnl_query_num_of_inputs_s32
number of inputs expected
Definition: dnnl_types.h:2589
@ dnnl_query_op_d
op descriptor
Definition: dnnl_types.h:2610
@ dnnl_query_diff_src_md
source gradient memory desc
Definition: dnnl_types.h:2634
@ dnnl_query_scratchpad_md
scratchpad memory desc
Definition: dnnl_types.h:2640
@ dnnl_query_shuffle_d
shuffle descriptor
Definition: dnnl_types.h:2613
@ dnnl_query_memory_consumption_s64
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes)
Definition: dnnl_types.h:2593
@ dnnl_query_inner_product_d
inner product descriptor
Definition: dnnl_types.h:2620
@ dnnl_query_deconvolution_d
deconvolution descriptor
Definition: dnnl_types.h:2612
@ dnnl_query_primitive_kind
primitive kind
Definition: dnnl_types.h:2587
@ dnnl_query_batch_normalization_d
batch normalization descriptor
Definition: dnnl_types.h:2618
@ dnnl_query_impl_info_str
implementation name
Definition: dnnl_types.h:2601
@ dnnl_query_time_estimate_f64
runtime estimation (seconds)
Definition: dnnl_types.h:2592
@ dnnl_query_eltwise_d
eltwise descriptor
Definition: dnnl_types.h:2614
@ dnnl_query_diff_weights_md
weights gradient (diff) memory desc
Definition: dnnl_types.h:2636
@ dnnl_query_reduction_d
reduction descriptor
Definition: dnnl_types.h:2628
@ dnnl_query_reorder_dst_engine
destination engine
Definition: dnnl_types.h:2604
@ dnnl_query_reorder_src_engine
source engine
Definition: dnnl_types.h:2603
@ dnnl_query_scratchpad_engine
scratchpad engine – engine to be used for creating scratchpad memory
Definition: dnnl_types.h:2598
@ dnnl_query_undef
no query
Definition: dnnl_types.h:2584
@ dnnl_query_prop_kind
propagation kind
Definition: dnnl_types.h:2606
@ dnnl_query_pooling_d
pooling descriptor
Definition: dnnl_types.h:2616
@ dnnl_query_exec_arg_md
memory desc of an execute argument
Definition: dnnl_types.h:2641
@ dnnl_query_engine
execution engine
Definition: dnnl_types.h:2586
@ dnnl_query_rnn_d
rnn descriptor
Definition: dnnl_types.h:2621
@ dnnl_query_layer_normalization_d
layer normalization descriptor
Definition: dnnl_types.h:2619
@ dnnl_query_lrn_d
lrn descriptor
Definition: dnnl_types.h:2617
@ dnnl_query_dst_md
destination memory desc
Definition: dnnl_types.h:2637
@ dnnl_query_diff_dst_md
destination gradient (diff) memory desc
Definition: dnnl_types.h:2638
@ dnnl_query_logsoftmax_d
logsoftmax descriptor
Definition: dnnl_types.h:2624
@ use_scale_shift
Use scale and shift parameters.
@ none
Use no normalization flags.
@ fuse_norm_relu
Fuse normalization with ReLU.
@ use_global_stats
Use global statistics.
@ dnnl_backward_weights
Backward weights propagation.
Definition: dnnl_types.h:1096
@ dnnl_forward_inference
Forward data propagation (inference mode).
Definition: dnnl_types.h:1086
@ dnnl_backward
Backward propagation (with respect to all parameters).
Definition: dnnl_types.h:1092
@ dnnl_backward_data
Backward data propagation.
Definition: dnnl_types.h:1094
@ dnnl_prop_kind_undef
Undefined propagation type.
Definition: dnnl_types.h:1079
@ dnnl_forward
Forward data propagation (alias for dnnl_forward_training).
Definition: dnnl_types.h:1090
@ dnnl_forward_training
Forward data propagation (training mode).
Definition: dnnl_types.h:1082
@ dnnl_backward_bias
Backward bias propagation.
Definition: dnnl_types.h:1098
@ dnnl_forward_scoring
Forward data propagation (alias for dnnl_forward_inference).
Definition: dnnl_types.h:1088
dnnl_status_t DNNL_API dnnl_reduction_desc_init(dnnl_reduction_desc_t *desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, float p, float eps)
Initializes a descriptor for a reduction primitive.
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(dnnl_primitive_desc_t *reorder_primitive_desc, const dnnl_memory_desc_t *src_desc, dnnl_engine_t src_engine, const dnnl_memory_desc_t *dst_desc, dnnl_engine_t dst_engine, const_dnnl_primitive_attr_t attr)
Creates a primitive descriptor for a reorder primitive.
dnnl_status_t DNNL_API dnnl_resampling_backward_desc_init(dnnl_resampling_desc_t *resampling_desc, dnnl_alg_kind_t alg_kind, const float *factors, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes a descriptor for resampling backward propagation primitive.
dnnl_status_t DNNL_API dnnl_resampling_forward_desc_init(dnnl_resampling_desc_t *resampling_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const float *factors, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a descriptor for a resampling forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a descriptor for LBR GRU backward propagation primitive.
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a descriptor for GRU forward propagation primitive.
rnn_direction
A direction of RNN primitive execution.
Definition: dnnl.hpp:728
dnnl_rnn_flags_t
Flags for RNN cell.
Definition: dnnl_types.h:2000
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes a descriptor for vanilla RNN forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM backward propagation primitive.
dnnl_rnn_direction_t
A direction of RNN primitive execution.
Definition: dnnl_types.h:2006
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes a descriptor for vanilla RNN backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *weights_projection_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_weights_peephole_desc, const dnnl_memory_desc_t *diff_weights_projection_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole and with or without recurrent project...
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_weights_peephole_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole) backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for LSTM forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a descriptor for LBR GRU forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *weights_projection_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole and with or without recurrent projecti...
rnn_flags
RNN cell flags.
Definition: dnnl.hpp:674
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a descriptor for GRU backward propagation primitive.
@ unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
@ unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
@ bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
@ unidirectional
Alias for dnnl::rnn_direction::unidirectional_left2right.
@ bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
@ dnnl_rnn_flags_undef
Undefined RNN flags.
Definition: dnnl_types.h:2002
@ dnnl_unidirectional
Alias for dnnl_unidirectional_left2right.
Definition: dnnl_types.h:2018
@ dnnl_bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
Definition: dnnl_types.h:2013
@ dnnl_bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
Definition: dnnl_types.h:2016
@ dnnl_unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
Definition: dnnl_types.h:2008
@ dnnl_unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
Definition: dnnl_types.h:2010
@ undef
Undefined RNN flags.
dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable)
Configures dumping of JIT-generated code.
status set_max_cpu_isa(cpu_isa isa)
Sets the maximal ISA the library can dispatch to on the CPU.
Definition: dnnl.hpp:10923
dnnl_status_t DNNL_API dnnl_set_verbose(int level)
Configures verbose output to stdout.
status set_jit_dump(int enable)
Configures dumping of JIT-generated code.
Definition: dnnl.hpp:10882
status set_cpu_isa_hints(cpu_isa_hints isa_hints)
Sets the hints flag for the CPU ISA.
Definition: dnnl.hpp:10942
dnnl_cpu_isa_t
CPU instruction set flags.
Definition: dnnl_types.h:2733
status set_verbose(int level)
Configures verbose output to stdout.
Definition: dnnl.hpp:10872
cpu_isa get_effective_cpu_isa()
Gets the maximal ISA the library can dispatch to on the CPU.
Definition: dnnl.hpp:10929
dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa)
Sets the maximal ISA the library can dispatch to on the CPU.
dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags)
Sets library profiling flags.
status set_jit_profiling_jitdumpdir(const std::string &dir)
Sets JIT dump output path.
Definition: dnnl.hpp:10892
const dnnl_version_t DNNL_API * dnnl_version(void)
Returns library version information.
status
Status values returned by the library functions.
Definition: dnnl.hpp:10854
cpu_isa_hints get_cpu_isa_hints()
Gets the ISA-specific hints that the library can follow.
Definition: dnnl.hpp:10948
status set_jit_profiling_flags(unsigned flags)
Sets library profiling flags.
Definition: dnnl.hpp:10887
const version_t * version()
Returns library version information.
Definition: dnnl.hpp:10877
cpu_isa
CPU instruction set flags.
Definition: dnnl.hpp:10897
dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void)
Gets the maximal ISA the library can dispatch to on the CPU.
dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints)
Sets the hints flag for the CPU ISA.
dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void)
Gets the ISA-specific hints that the library can follow.
dnnl_cpu_isa_hints_t
CPU ISA hints flags.
Definition: dnnl_types.h:2779
cpu_isa_hints
CPU ISA hints flags.
Definition: dnnl.hpp:10934
dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir)
Sets JIT dump output path.
@ dnnl_cpu_isa_avx512_mic
Intel Advanced Vector Extensions 512 (Intel AVX-512) subset for Intel Xeon Phi processors x200 Series...
Definition: dnnl_types.h:2748
@ dnnl_cpu_isa_avx
Intel Advanced Vector Extensions (Intel AVX)
Definition: dnnl_types.h:2741
@ dnnl_cpu_isa_avx512_core_amx
Intel AVX-512, Intel DL Boost and bfloat16 support and Intel AMX with 8-bit integer and bfloat16 supp...
Definition: dnnl_types.h:2771
@ dnnl_cpu_isa_avx512_core_vnni
Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support for Intel Xeon Scalable processo...
Definition: dnnl_types.h:2761
@ dnnl_cpu_isa_avx2
Intel Advanced Vector Extensions 2 (Intel AVX2)
Definition: dnnl_types.h:2744
@ dnnl_cpu_isa_all
Any ISA (excluding those listed as initial support)
Definition: dnnl_types.h:2735
@ dnnl_cpu_isa_avx512_core
Intel AVX-512 subset for Intel Xeon Scalable processor family and Intel Core processor family.
Definition: dnnl_types.h:2756
@ dnnl_cpu_isa_sse41
Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1)
Definition: dnnl_types.h:2738
@ dnnl_cpu_isa_avx2_vnni
Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) support.
Definition: dnnl_types.h:2774
@ dnnl_cpu_isa_avx512_core_bf16
Intel AVX-512, Intel DL Boost and bfloat16 support for Intel Xeon Scalable processor family and Intel...
Definition: dnnl_types.h:2766
@ dnnl_cpu_isa_avx512_mic_4ops
Intel AVX-512 subset for Intel Xeon Phi processors 7235, 7285, 7295 Series.
Definition: dnnl_types.h:2752
@ not_required
Queried element is not required for given primitive.
@ invalid_arguments
The operation failed because of incorrect function arguments.
@ success
The operation was successful.
@ unimplemented
The operation failed because requested functionality is not implemented.
@ runtime_error
Primitive or engine failed on execution.
@ out_of_memory
The operation failed due to an out-of-memory condition.
@ iterator_ends
Primitive iterator passed over last primitive descriptor.
@ avx512_mic
Intel Advanced Vector Extensions 512 (Intel AVX-512) subset for Intel Xeon Phi processors x200 Series...
@ avx2
Intel Advanced Vector Extensions 2 (Intel AVX2)
@ avx2_vnni
Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) support.
@ avx
Intel Advanced Vector Extensions (Intel AVX)
@ all
Any ISA (except those listed as initial support)
@ avx512_core
Intel AVX-512 subset for Intel Xeon Scalable processor family and Intel Core processor family.
@ avx512_mic_4ops
Intel AVX-512 subset for Intel Xeon Phi processors 7235, 7285, 7295 Series.
@ sse41
Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1)
@ avx512_core_vnni
Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support for Intel Xeon Scalable processo...
@ avx512_core_amx
Intel AVX-512, Intel DL Boost and bfloat16 support and Intel AMX with 8-bit integer and bfloat16 supp...
@ avx512_core_bf16
Intel AVX-512, Intel DL Boost and bfloat16 support for Intel Xeon Scalable processor family and Intel...
@ dnnl_cpu_isa_no_hints
No hints (use default features)
Definition: dnnl_types.h:2781
@ dnnl_cpu_isa_prefer_ymm
Prefer to exclusively use Ymm registers for computations.
Definition: dnnl_types.h:2784
@ no_hints
No hints (use default features)
@ prefer_ymm
Prefer to exclusively use Ymm registers for computations.
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size)
Initializes a descriptor for shuffle forward propagation primitive.
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, const dnnl_memory_desc_t *diff_data_desc, int axis, dnnl_dim_t group_size)
Initializes a descriptor for shuffle backward propagation primitive.
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(dnnl_softmax_desc_t *softmax_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a descriptor for softmax backward propagation primitive.
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a descriptor for softmax forward propagation primitive.
dnnl_stream_flags_t
Stream flags.
Definition: dnnl_types.h:2655
dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream)
Waits for all primitives in the execution stream to finish computations.
dnnl_status_t DNNL_API dnnl_stream_get_engine(const_dnnl_stream_t stream, dnnl_engine_t *engine)
Returns the engine of a stream object.
dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream)
Destroys an execution stream.
dnnl_status_t DNNL_API dnnl_stream_create(dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags)
Creates an execution stream.
@ dnnl_stream_out_of_order
Out-of-order execution.
Definition: dnnl_types.h:2659
@ dnnl_stream_default_flags
Default stream configuration.
Definition: dnnl_types.h:2661
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(dnnl_primitive_desc_t *sum_primitive_desc, const dnnl_memory_desc_t *dst_desc, int n, const float *scales, const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates a primitive descriptor for an (out-of-place) sum primitive.
dnnl_status_t
Status values returned by the library functions.
Definition: dnnl_types.h:39
@ dnnl_iterator_ends
Primitive iterator passed over last primitive descriptor.
Definition: dnnl_types.h:49
@ dnnl_runtime_error
Primitive or engine failed on execution.
Definition: dnnl_types.h:51
@ dnnl_unimplemented
The operation failed because requested functionality is not implemented.
Definition: dnnl_types.h:47
@ dnnl_out_of_memory
The operation failed due to an out-of-memory condition.
Definition: dnnl_types.h:43
@ dnnl_success
The operation was successful.
Definition: dnnl_types.h:41
@ dnnl_invalid_arguments
The operation failed because of incorrect function arguments.
Definition: dnnl_types.h:45
@ dnnl_not_required
Queried element is not required for given primitive.
Definition: dnnl_types.h:53
oneDNN namespace
Definition: dnnl.hpp:74
oneAPI namespace
Definition: dnnl.hpp:11046
C API.
Descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6670
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a batch normalization descriptor for backward propagation.
Definition: dnnl.hpp:6685
Primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6699
primitive_desc(const desc &adesc, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6716
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6759
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a batch normalization backward propagation primitive from a C A...
Definition: dnnl.hpp:6749
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6784
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6765
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6779
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6736
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6762
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6756
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6768
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:6771
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6776
Batch normalization backward propagation primitive.
Definition: dnnl.hpp:6668
batch_normalization_backward()=default
Default constructor. Produces an empty object.
batch_normalization_backward(const primitive_desc &pd)
Constructs a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6793
Descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6541
desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a batch normalization descriptor for forward propagation.
Definition: dnnl.hpp:6558
Primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6571
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6619
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6625
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6632
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6585
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6601
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6628
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6636
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a batch normalization forward propagation primitive from a C AP...
Definition: dnnl.hpp:6612
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6622
Batch normalization forward propagation primitive.
Definition: dnnl.hpp:6539
batch_normalization_forward()=default
Default constructor. Produces an empty object.
batch_normalization_forward(const primitive_desc &pd)
Constructs a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6664
Descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9819
desc()=default
Default constructor. Produces an empty object.
dnnl_binary_desc_t data
Underlying C operation descriptor.
Definition: dnnl.hpp:9821
desc(algorithm aalgorithm, const memory::desc &src0, const memory::desc &src1, const memory::desc &dst)
Constructs a descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9833
Primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9844
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9872
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:9885
memory::desc src0_desc() const
Returns the memory descriptor for source #0.
Definition: dnnl.hpp:9888
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a binary primitive from a C API primitive descriptor that must ...
Definition: dnnl.hpp:9881
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:9894
memory::desc src1_desc() const
Returns the memory descriptor for source #1.
Definition: dnnl.hpp:9891
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9857
Elementwise binary operator primitive.
Definition: dnnl.hpp:9817
binary()=default
Default constructor. Produces an empty object.
binary(const primitive_desc &pd)
Constructs an elementwise binary operation primitive.
Definition: dnnl.hpp:9903
Primitive descriptor for a concat primitive.
Definition: dnnl.hpp:3760
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3829
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for concat primitive from a C API primitive descriptor which must h...
Definition: dnnl.hpp:3822
primitive_desc(const memory::desc &dst, int concat_dimension, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for an out-of-place concatenation primitive.
Definition: dnnl.hpp:3776
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(int concat_dimension, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for an out-of-place concatenation primitive.
Definition: dnnl.hpp:3803
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3826
Tensor concatenation (concat) primitive.
Definition: dnnl.hpp:3758
concat()=default
Default constructor. Produces an empty object.
concat(const primitive_desc &pd)
Constructs a concatenation primitive.
Definition: dnnl.hpp:3837
Descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4301
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for dilated convolution backward propagation primitive.
Definition: dnnl.hpp:4372
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4329
Primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4393
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:4451
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:4454
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution backward propagation primitive from a C API primi...
Definition: dnnl.hpp:4443
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4410
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:4448
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4430
Convolution backward propagation primitive.
Definition: dnnl.hpp:4298
convolution_backward_data()=default
Default constructor. Produces an empty object.
convolution_backward_data(const primitive_desc &pd)
Constructs a convolution backward propagation primitive.
Definition: dnnl.hpp:4463
Descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4469
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution weights gradient primitive without bias.
Definition: dnnl.hpp:4542
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution weights gradient primitive with bias.
Definition: dnnl.hpp:4587
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution weights gradient primitive without bias.
Definition: dnnl.hpp:4634
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution weights gradient primitive with bias.
Definition: dnnl.hpp:4499
Primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4655
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:4722
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:4711
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4690
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution weights gradient primitive from a C API primitive...
Definition: dnnl.hpp:4703
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:4708
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4671
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:4716
Convolution weights gradient primitive.
Definition: dnnl.hpp:4467
convolution_backward_weights()=default
Default constructor. Produces an empty object.
convolution_backward_weights(const primitive_desc &pd)
Constructs a convolution weights gradient primitive.
Definition: dnnl.hpp:4733
Descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4028
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution forward propagation primitive with bias.
Definition: dnnl.hpp:4156
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution forward propagation primitive without bias.
Definition: dnnl.hpp:4107
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution forward propagation primitive without bias.
Definition: dnnl.hpp:4205
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution forward propagation primitive with bias.
Definition: dnnl.hpp:4061
Primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4226
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4240
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4256
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:4273
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution forward propagation primitive from a C API primit...
Definition: dnnl.hpp:4267
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:4285
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:4276
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:4279
Convolution forward propagation primitive.
Definition: dnnl.hpp:4026
convolution_forward(const primitive_desc &pd)
Constructs a convolution forward propagation primitive.
Definition: dnnl.hpp:4294
convolution_forward()=default
Default constructor. Produces an empty object.
Descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5014
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution backward propagation primitive.
Definition: dnnl.hpp:5083
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5041
Primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5104
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5141
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:5162
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5165
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5159
primitive_desc(const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5121
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution backward propagation primitive from a C API pri...
Definition: dnnl.hpp:5154
Deconvolution backward propagation primitive.
Definition: dnnl.hpp:5012
deconvolution_backward_data()=default
Default constructor. Produces an empty object.
deconvolution_backward_data(const primitive_desc &pd)
Constructs a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5174
Descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5180
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution weights gradient primitive without bias.
Definition: dnnl.hpp:5341
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution weights gradient primitive with bias.
Definition: dnnl.hpp:5295
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution weights gradient primitive without bias.
Definition: dnnl.hpp:5251
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution weights gradient primitive with bias.
Definition: dnnl.hpp:5209
Primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5362
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5417
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5425
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution weights gradient primitive from a C API primiti...
Definition: dnnl.hpp:5412
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5399
primitive_desc(const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5379
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:5420
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:5428
primitive_desc()=default
Default constructor. Produces an empty object.
Deconvolution weights gradient primitive.
Definition: dnnl.hpp:5178
deconvolution_backward_weights()=default
Default constructor. Produces an empty object.
deconvolution_backward_weights(const primitive_desc &pd)
Constructs a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5439
Descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4749
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution forward propagation primitive with bias.
Definition: dnnl.hpp:4874
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution forward propagation primitive without bias.
Definition: dnnl.hpp:4826
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution forward propagation primitive with bias.
Definition: dnnl.hpp:4781
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution forward propagation primitive without bias.
Definition: dnnl.hpp:4922
Primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4943
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution forward propagation primitive from a C API prim...
Definition: dnnl.hpp:4984
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:4996
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:4990
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4973
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4957
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:4999
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:4993
Deconvolution forward propagation primitive.
Definition: dnnl.hpp:4747
deconvolution_forward(const primitive_desc &pd)
Constructs a deconvolution forward propagation primitive.
Definition: dnnl.hpp:5008
deconvolution_forward()=default
Default constructor. Produces an empty object.
Descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6008
desc(algorithm aalgorithm, const memory::desc &diff_data_desc, const memory::desc &data_desc, float alpha=0, float beta=0)
Constructs a descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6022
Primitive descriptor for eltwise backward propagation.
Definition: dnnl.hpp:6035
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6093
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6052
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6072
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6090
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6096
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an eltwise backward propagation primitive from a C API primitiv...
Definition: dnnl.hpp:6085
Elementwise unary operation backward propagation primitive.
Definition: dnnl.hpp:6006
eltwise_backward()=default
Default constructor. Produces an empty object.
eltwise_backward(const primitive_desc &pd)
Constructs an eltwise backward propagation primitive.
Definition: dnnl.hpp:6105
Descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5915
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &data_desc, float alpha=0, float beta=0)
Constructs a descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5930
Primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5943
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5973
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5993
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5990
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5957
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an eltwise forward propagation primitive from a C API primitive...
Definition: dnnl.hpp:5984
Elementwise unary operation forward propagation primitive.
Definition: dnnl.hpp:5913
eltwise_forward(const primitive_desc &pd)
Constructs an eltwise forward propagation primitive.
Definition: dnnl.hpp:6002
eltwise_forward()=default
Default constructor. Produces an empty object.
An execution engine.
Definition: dnnl.hpp:885
static engine query(const primitive_desc &pd)
Returns the engine of a primitive descriptor.
Definition: dnnl.hpp:954
kind
Kinds of engines.
Definition: dnnl.hpp:890
@ gpu
GPU engine.
@ any
An unspecified engine.
@ cpu
CPU engine.
engine(kind akind, size_t index)
Constructs an engine.
Definition: dnnl.hpp:918
engine()=default
Constructs an empty engine.
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: dnnl.hpp:909
engine(const handle< dnnl_primitive_desc_t > &pd)
Constructs an engine based on a primitive from the primitive descriptor pd by querying its engine.
Definition: dnnl.hpp:930
kind get_kind() const
Returns the kind of the engine.
Definition: dnnl.hpp:941
oneDNN exception class.
Definition: dnnl.hpp:84
error(dnnl_status_t status, const char *message)
Constructs an instance of an exception class.
Definition: dnnl.hpp:92
static void wrap_c_api(dnnl_status_t status, const char *message)
A convenience function for wrapping calls to C API functions.
Definition: dnnl.hpp:103
const char * what() const noexcept override
Returns the explanatory string.
Definition: dnnl.hpp:96
Descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9066
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9113
Primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9147
primitive_desc(const desc &adesc, const engine &aengine, const gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9163
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:9249
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9221
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9208
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9205
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:9254
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9213
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9218
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:9264
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:9259
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a GRU backward propagation primitive from a C API primitive des...
Definition: dnnl.hpp:9195
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9200
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9229
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9182
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:9234
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:9239
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:9244
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9226
GRU backward propagation primitive.
Definition: dnnl.hpp:9064
gru_backward()=default
Default constructor. Produces an empty object.
gru_backward(const primitive_desc &pd)
Constructs a GRU backward propagation primitive.
Definition: dnnl.hpp:9275
Descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8917
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8952
Primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8975
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8988
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9033
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9020
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9041
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9028
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9038
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9046
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9049
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9025
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a GRU forward propagation primitive from a C API primitive desc...
Definition: dnnl.hpp:9014
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:9003
GRU forward propagation primitive.
Definition: dnnl.hpp:8915
gru_forward(const primitive_desc &pd)
Constructs a GRU forward propagation primitive.
Definition: dnnl.hpp:9060
gru_forward()=default
Default constructor. Produces an empty object.
A class that provides the destructor for a oneDNN C API handle.
Definition: dnnl.hpp:120
oneDNN C API handle wrapper class.
Definition: dnnl.hpp:136
handle(const handle< T, traits > &)=default
Copy constructor.
bool operator==(const handle< T, traits > &other) const
Equality operator.
Definition: dnnl.hpp:210
bool operator!=(const handle &other) const
Inequality operator.
Definition: dnnl.hpp:220
T get(bool allow_empty=false) const
Returns the underlying C API handle.
Definition: dnnl.hpp:185
handle< T, traits > & operator=(const handle< T, traits > &)=default
Assignment operator.
handle()=default
Constructs an empty handle object.
void reset(T t, bool weak=false)
Resets the handle wrapper object to wrap a new C API handle.
Definition: dnnl.hpp:176
handle(T t, bool weak=false)
Constructs a handle wrapper object from a C API handle.
Definition: dnnl.hpp:169
handle(handle< T, traits > &&)=default
Move constructor.
handle< T, traits > & operator=(handle< T, traits > &&)=default
Move assignment operator.
Descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7257
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7270
Primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7283
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7344
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7300
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7341
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product backward propagation primitive from a C API pr...
Definition: dnnl.hpp:7333
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7320
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:7338
primitive_desc()=default
Default constructor. Produces an empty object.
Inner product backward propagation primitive.
Definition: dnnl.hpp:7255
inner_product_backward_data(const primitive_desc &pd)
Constructs an inner product backward propagation primitive.
Definition: dnnl.hpp:7353
inner_product_backward_data()=default
Default constructor. Produces an empty object.
Descriptor for an inner product weights gradient primitive.
Definition: dnnl.hpp:7359
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product weights update primitive with bias.
Definition: dnnl.hpp:7373
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product weights update primitive without bias.
Definition: dnnl.hpp:7395
Primitive descriptor for an inner product weights gradient primitive.
Definition: dnnl.hpp:7408
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7463
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:7466
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7471
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product weights update primitive from a C API primitiv...
Definition: dnnl.hpp:7458
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product weights update primitive.
Definition: dnnl.hpp:7445
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:7474
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product weights update primitive.
Definition: dnnl.hpp:7425
Inner product weights gradient primitive.
Definition: dnnl.hpp:7357
inner_product_backward_weights(const primitive_desc &pd)
Constructs an inner product weights gradient primitive.
Definition: dnnl.hpp:7485
inner_product_backward_weights()=default
Default constructor. Produces an empty object.
Descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7132
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for an inner product forward propagation primitive without bias.
Definition: dnnl.hpp:7173
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for an inner product forward propagation primitive with bias.
Definition: dnnl.hpp:7149
Primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7186
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product forward propagation primitive from a C API pri...
Definition: dnnl.hpp:7227
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:7239
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7200
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7236
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:7242
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7233
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7216
Inner product forward propagation primitive.
Definition: dnnl.hpp:7130
inner_product_forward(const primitive_desc &pd)
Constructs an inner product forward propagation primitive.
Definition: dnnl.hpp:7251
inner_product_forward()=default
Default constructor. Produces an empty object.
Descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6968
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7008
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6984
Primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7022
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:7088
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:7099
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:7107
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a layer normalization backward propagation primitive from a C A...
Definition: dnnl.hpp:7072
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7091
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:7085
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:7102
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7082
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7059
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:7094
primitive_desc(const desc &adesc, const engine &aengine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7039
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7079
Layer normalization backward propagation primitive.
Definition: dnnl.hpp:6966
layer_normalization_backward(const primitive_desc &pd)
Constructs a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7116
layer_normalization_backward()=default
Default constructor. Produces an empty object.
Descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6823
desc(prop_kind aprop_kind, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6837
desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6858
Primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6871
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6885
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6922
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6919
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6928
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a layer normalization forward propagation primitive from a C AP...
Definition: dnnl.hpp:6912
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6934
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6901
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6925
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6931
Layer normalization forward propagation primitive.
Definition: dnnl.hpp:6821
layer_normalization_forward()=default
Default constructor. Produces an empty object.
layer_normalization_forward(const primitive_desc &pd)
Constructs a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6962
Descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9433
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9481
Primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9515
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9578
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:9614
primitive_desc(const desc &adesc, const engine &aengine, const lbr_gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9532
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:9634
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:9624
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lbr_gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9552
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9596
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9583
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9575
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:9609
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:9565
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9599
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9588
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9591
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9570
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:9619
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:9629
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:9604
LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9431
lbr_gru_backward(const primitive_desc &pd)
Constructs an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9645
lbr_gru_backward()=default
Default constructor. Produces an empty object.
Descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9281
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9317
Primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9340
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9413
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9392
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9370
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9408
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9416
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:9381
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9354
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9405
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9387
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9400
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9395
LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9279
lbr_gru_forward()=default
Default constructor. Produces an empty object.
lbr_gru_forward(const primitive_desc &pd)
Constructs an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9427
Descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6415
desc()=default
Default constructor. Produces an empty object.
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, int logsoftmax_axis)
Constructs a descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6428
Primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6439
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6498
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6504
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6501
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive from a C API primit...
Definition: dnnl.hpp:6489
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const logsoftmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6476
primitive_desc(const desc &adesc, const engine &aengine, const logsoftmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6456
Logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6413
logsoftmax_backward(const primitive_desc &pd)
Constructs a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6513
logsoftmax_backward()=default
Default constructor. Produces an empty object.
Descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6321
desc(prop_kind aprop_kind, const memory::desc &data_desc, int logsoftmax_axis)
Constructs a descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6335
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6346
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6400
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6397
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6376
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive from a C API primiti...
Definition: dnnl.hpp:6387
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6360
primitive_desc()=default
Default constructor. Produces an empty object.
Logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6319
logsoftmax_forward()=default
Default constructor. Produces an empty object.
logsoftmax_forward(const primitive_desc &pd)
Constructs a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6409
Descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5551
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Constructs a descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5566
Primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5579
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5614
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LRN backward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:5627
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5635
primitive_desc(const desc &adesc, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5595
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5638
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5632
primitive_desc()=default
Default constructor. Produces an empty object.
Local response normalization (LRN) backward propagation primitive.
Definition: dnnl.hpp:5549
lrn_backward(const primitive_desc &pd)
Constructs an LRN backward propagation primitive.
Definition: dnnl.hpp:5647
lrn_backward()=default
Default constructor. Produces an empty object.
Descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5456
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Constructs a descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5472
Primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5485
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5530
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5533
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5513
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5536
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5498
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LRN forward propagation primitive from a C API primitive des...
Definition: dnnl.hpp:5524
Local response normalization (LRN) forward propagation primitive.
Definition: dnnl.hpp:5454
lrn_forward()=default
Default constructor. Produces an empty object.
lrn_forward(const primitive_desc &pd)
Constructs an LRN forward propagation primitive.
Definition: dnnl.hpp:5545
Descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8413
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &weights_projection_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_weights_peephole_desc, const memory::desc &diff_weights_projection_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM (with or without peephole and with or without projection) descriptor for backward ...
Definition: dnnl.hpp:8491
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM descriptor for backward propagation using prop_kind, direction,...
Definition: dnnl.hpp:8702
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_weights_peephole_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM (with or without peephole) descriptor for backward propagation using prop_kind,...
Definition: dnnl.hpp:8603
Primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8743
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8814
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:8895
memory::desc diff_weights_projection_desc() const
Returns diff weights projection memory descriptor.
Definition: dnnl.hpp:8880
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:8819
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LSTM backward propagation primitive from a C API primitive d...
Definition: dnnl.hpp:8791
memory::desc diff_weights_peephole_desc() const
Returns diff weights peephole memory descriptor.
Definition: dnnl.hpp:8875
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8840
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8796
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8837
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const lstm_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8759
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:8850
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8801
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:8870
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:8824
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:8885
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lstm_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8778
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8829
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8804
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8832
memory::desc diff_dst_iter_c_desc() const
Returns diff destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8900
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:8855
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:8890
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8809
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:8865
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8845
memory::desc diff_src_iter_c_desc() const
Returns diff source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8860
LSTM backward propagation primitive.
Definition: dnnl.hpp:8411
lstm_backward()=default
Default constructor. Produces an empty object.
lstm_backward(const primitive_desc &pd)
Constructs an LSTM backward propagation primitive.
Definition: dnnl.hpp:8911
Descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8096
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8276
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &weights_projection_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole and with or without projection) forward...
Definition: dnnl.hpp:8147
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole) forward propagation primitive.
Definition: dnnl.hpp:8215
Primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8302
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8388
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:8370
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8365
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8383
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8396
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8315
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LSTM forward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:8341
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8391
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8360
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:8375
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8355
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8330
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8352
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8380
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8347
LSTM forward propagation primitive.
Definition: dnnl.hpp:8094
lstm_forward(const primitive_desc &pd)
Constructs an LSTM forward propagation primitive.
Definition: dnnl.hpp:8407
lstm_forward()=default
Default constructor. Produces an empty object.
Descriptor for a matmul primitive.
Definition: dnnl.hpp:9921
desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive.
Definition: dnnl.hpp:9929
desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive.
Definition: dnnl.hpp:9943
Primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9953
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9979
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:9995
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a matmul primitive from a C API primitive descriptor that must ...
Definition: dnnl.hpp:9988
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:10000
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:9992
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10005
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9965
Matrix multiplication (matmul) primitive.
Definition: dnnl.hpp:9919
matmul(const primitive_desc &pd)
Constructs a matmul primitive.
Definition: dnnl.hpp:10013
matmul()=default
Default constructor. Produces an empty object.
A memory descriptor.
Definition: dnnl.hpp:2049
desc(const dims &adims, data_type adata_type, format_tag aformat_tag, bool allow_empty=false)
Constructs a memory descriptor.
Definition: dnnl.hpp:2073
desc()
Constructs a zero (empty) memory descriptor.
Definition: dnnl.hpp:2056
bool operator!=(const desc &other) const
An inequality operator.
Definition: dnnl.hpp:2284
desc permute_axes(const std::vector< int > &permutation, bool allow_empty=false) const
Constructs a memory descriptor by permuting axes in an existing one.
Definition: dnnl.hpp:2235
desc submemory_desc(const dims &adims, const dims &offsets, bool allow_empty=false) const
Constructs a memory descriptor for a region inside an area described by this memory descriptor.
Definition: dnnl.hpp:2131
bool operator==(const desc &other) const
An equality operator.
Definition: dnnl.hpp:2276
bool is_zero() const
Checks whether the memory descriptor is zero (empty).
Definition: dnnl.hpp:2270
memory::dims dims() const
Returns dimensions of the memory descriptor.
Definition: dnnl.hpp:2257
memory::data_type data_type() const
Returns the data type of the memory descriptor.
Definition: dnnl.hpp:2249
desc reshape(const dims &adims, bool allow_empty=false) const
Constructs a memory descriptor by reshaping an existing one.
Definition: dnnl.hpp:2187
desc(const dims &adims, data_type adata_type, const dims &strides, bool allow_empty=false)
Constructs a memory descriptor by strides.
Definition: dnnl.hpp:2101
size_t get_size() const
Returns size of the memory descriptor in bytes.
Definition: dnnl.hpp:2265
desc(const dnnl_memory_desc_t &data)
Constructs a memory descriptor from a C API data structure.
Definition: dnnl.hpp:2118
dnnl_memory_desc_t data
The underlying C API data structure.
Definition: dnnl.hpp:2052
Memory object.
Definition: dnnl.hpp:1124
void unmap_data(void *mapped_ptr) const
Unmaps a memory object and writes back any changes made to the previously mapped memory buffer.
Definition: dnnl.hpp:2450
T * map_data() const
Maps a memory object and returns a host-side pointer to a memory buffer with a copy of its contents.
Definition: dnnl.hpp:2433
static void validate_dims(const std::vector< T > &v, int min_size=0)
Helper function that validates that an std::vector of dimensions can be safely converted to the C API...
Definition: dnnl.hpp:1140
memory()=default
Default constructor.
dnnl_dim_t dim
Integer type for representing dimension sizes and indices.
Definition: dnnl.hpp:1128
memory(const desc &md, const engine &aengine, void *handle)
Constructs a memory object.
Definition: dnnl.hpp:2317
void set_data_handle(void *handle, const stream &astream) const
Sets the underlying memory buffer.
Definition: dnnl.hpp:2389
void * get_data_handle() const
Returns the underlying memory buffer.
Definition: dnnl.hpp:2354
format_tag
Memory format tag specification.
Definition: dnnl.hpp:1227
data_type
Data type specification.
Definition: dnnl.hpp:1146
@ undef
Undefined data type (used for empty memory descriptors).
engine get_engine() const
Returns the associated engine.
Definition: dnnl.hpp:2343
format_kind
Memory format kind.
Definition: dnnl.hpp:1171
memory(const desc &md, const engine &aengine)
Constructs a memory object.
Definition: dnnl.hpp:2331
void set_data_handle(void *handle) const
Sets the underlying memory buffer.
Definition: dnnl.hpp:2405
static size_t data_type_size(data_type adata_type)
Returns size of data type in bytes.
Definition: dnnl.hpp:1166
desc get_desc() const
Returns the associated memory descriptor.
Definition: dnnl.hpp:2335
std::vector< dim > dims
Vector of dimensions.
Definition: dnnl.hpp:1131
Descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5775
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for pooling backward propagation primitive.
Definition: dnnl.hpp:5799
Primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5818
primitive_desc(const desc &adesc, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5834
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5874
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5877
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5853
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5871
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:5866
Pooling backward propagation primitive.
Definition: dnnl.hpp:5773
pooling_backward()=default
Default constructor. Produces an empty object.
pooling_backward(const primitive_desc &pd)
Constructs a pooling backward propagation primitive.
Definition: dnnl.hpp:5886
Descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5663
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for pooling forward propagation primitive.
Definition: dnnl.hpp:5690
Primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5709
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5737
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5757
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5754
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:5748
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5760
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5722
Pooling forward propagation primitive.
Definition: dnnl.hpp:5661
pooling_forward(const primitive_desc &pd)
Constructs a pooling forward propagation primitive.
Definition: dnnl.hpp:5769
pooling_forward()=default
Default constructor. Produces an empty object.
Descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10420
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &dilation, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10446
Primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10467
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10526
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10523
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive f...
Definition: dnnl.hpp:10518
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const pooling_v2_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10484
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const pooling_v2_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10504
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:10529
Pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10418
pooling_v2_backward(const primitive_desc &pd)
Constructs a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10539
pooling_v2_backward()=default
Default constructor. Produces an empty object.
Descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10299
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &dilation, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10328
Primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10350
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:10404
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10401
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10398
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10380
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive fr...
Definition: dnnl.hpp:10392
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10364
Pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10297
pooling_v2_forward()=default
Default constructor. Produces an empty object.
pooling_v2_forward(const primitive_desc &pd)
Constructs a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10414
Post-ops.
Definition: dnnl.hpp:2515
void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type, memory::data_type &bias_data_type, memory::data_type &dst_data_type, int &mask, std::vector< float > &scales) const
Returns the parameters of a depthwise post-op with stride 1.
Definition: dnnl.hpp:2691
void get_params_binary(int index, algorithm &aalgorithm, memory::desc &src1_desc) const
Returns the parameters of a binary post-op.
Definition: dnnl.hpp:2827
void get_params_sum(int index, float &scale, memory::data_type &data_type) const
Returns the parameters of an accumulation (sum) post-op.
Definition: dnnl.hpp:2592
void append_eltwise(float scale, algorithm aalgorithm, float alpha, float beta)
Appends an elementwise post-op.
Definition: dnnl.hpp:2614
void append_binary(algorithm aalgorithm, const memory::desc &src1_desc)
Appends a binary post-op.
Definition: dnnl.hpp:2816
void append_dw_k3s1p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector< float > &scales)
Appends a depthwise post-op convolution with stride 1.
Definition: dnnl.hpp:2665
primitive::kind kind(int index) const
Returns the primitive kind of post-op at entry with a certain index.
Definition: dnnl.hpp:2532
int len() const
Returns the number of post-ops entries.
Definition: dnnl.hpp:2527
void append_dw_k3s2p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector< float > &scales)
Appends a depthwise post-op convolution with stride 2.
Definition: dnnl.hpp:2750
post_ops()
Constructs an empty sequence of post-ops.
Definition: dnnl.hpp:2519
void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type, memory::data_type &bias_data_type, memory::data_type &dst_data_type, int &mask, std::vector< float > &scales) const
Returns the parameters of a depthwise post-op with stride 2.
Definition: dnnl.hpp:2776
void get_params_eltwise(int index, float &scale, algorithm &aalgorithm, float &alpha, float &beta) const
Returns parameters of an elementwise post-op.
Definition: dnnl.hpp:2628
void get_params_sum(int index, float &scale) const
Returns the parameters of an accumulation (sum) post-op.
Definition: dnnl.hpp:2582
void append_sum(float scale=1.f, memory::data_type data_type=memory::data_type::undef)
Appends an accumulation (sum) post-op.
Definition: dnnl.hpp:2567
Descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10643
desc(const memory::desc &data_desc, const memory::desc &weight_desc, const memory::desc &diff_data_desc, const memory::desc &diff_weights_desc)
Constructs a descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10654
Primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10667
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10722
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10725
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const prelu_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10704
primitive_desc(const desc &adesc, const engine &aengine, const prelu_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10684
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10728
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a prelu backward propagation primitive from a C API primitive d...
Definition: dnnl.hpp:10717
primitive_desc()=default
Default constructor. Produces an empty object.
PReLU backward propagation primitive.
Definition: dnnl.hpp:10641
prelu_backward()=default
Default constructor. Produces an empty object.
prelu_backward(const primitive_desc &pd)
Constructs a PReLU backward propagation primitive.
Definition: dnnl.hpp:10737
Descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10556
desc(prop_kind aprop_kind, const memory::desc &data_desc, const memory::desc &weight_desc)
Constructs a descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10567
Primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10578
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10628
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10625
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10592
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10608
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a prelu forward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:10619
PReLU forward propagation primitive.
Definition: dnnl.hpp:10554
prelu_forward(const primitive_desc &pd)
Constructs a PReLU forward propagation primitive.
Definition: dnnl.hpp:10637
prelu_forward()=default
Default constructor. Produces an empty object.
Primitive attributes.
Definition: dnnl.hpp:2851
void get_zero_points(int arg, int &mask, std::vector< int32_t > &zero_points) const
Returns zero points correspondence mask and values.
Definition: dnnl.hpp:3018
const post_ops get_post_ops() const
Returns post-ops previously set via set_post_ops().
Definition: dnnl.hpp:3064
void set_rnn_data_qparams(float scale, float shift)
Sets quantization scale and shift parameters for RNN data tensors.
Definition: dnnl.hpp:3119
void get_rnn_weights_qparams(int &mask, std::vector< float > &scales)
Returns the quantization scaling factors for RNN weights tensors.
Definition: dnnl.hpp:3197
void get_rnn_data_qparams(float &scale, float &shift)
Returns the quantization scale and shift parameters for RNN data tensors.
Definition: dnnl.hpp:3135
void set_output_scales(int mask, const std::vector< float > &scales)
Sets output scaling factors correspondence mask and values.
Definition: dnnl.hpp:2953
void get_rnn_weights_projection_qparams(int &mask, std::vector< float > &scales)
Returns the quantization scaling factors for RNN projection weights tensors.
Definition: dnnl.hpp:3266
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Sets quantization scaling factors for RNN weights tensors.
Definition: dnnl.hpp:3171
void set_rnn_weights_projection_qparams(int mask, const std::vector< float > &scales)
Sets quantization scaling factors for RNN projection weights tensors.
Definition: dnnl.hpp:3238
void set_scratchpad_mode(scratchpad_mode mode)
Sets scratchpad mode.
Definition: dnnl.hpp:2882
void set_scales(int arg, int mask, const std::vector< float > &scales)
Sets scaling factors for primitive operations for a given memory argument.
Definition: dnnl.hpp:3001
void get_scales(int arg, int &mask, std::vector< float > &scales) const
Returns scaling factors correspondence mask and values for a given memory argument.
Definition: dnnl.hpp:2971
void get_output_scales(int &mask, std::vector< float > &scales) const
Returns output scaling factors correspondence mask and values.
Definition: dnnl.hpp:2897
primitive_attr(dnnl_primitive_attr_t attr)
Creates primitive attributes from a C API dnnl_primitive_attr_t handle.
Definition: dnnl.hpp:2867
void set_post_ops(const post_ops ops)
Sets post-ops.
Definition: dnnl.hpp:3081
primitive_attr()
Constructs default (empty) primitive attributes.
Definition: dnnl.hpp:2855
void set_zero_points(int arg, int mask, const std::vector< int32_t > &zero_points)
Sets zero points for primitive operations for a given memory argument.
Definition: dnnl.hpp:3053
scratchpad_mode get_scratchpad_mode() const
Returns the scratchpad mode.
Definition: dnnl.hpp:2871
Base class for all primitive descriptors.
Definition: dnnl.hpp:3290
primitive_attr get_primitive_attr() const
Returns the primitive attributes.
Definition: dnnl.hpp:3474
memory::desc diff_weights_desc(int idx) const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:3400
primitive_desc_base()=default
Default constructor. Produces an empty object.
engine get_engine() const
Returns the engine of the primitive descriptor.
Definition: dnnl.hpp:3298
memory::desc query_md(query what, int idx=0) const
Returns a memory descriptor.
Definition: dnnl.hpp:3335
memory::desc dst_desc(int idx) const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3364
memory::desc diff_dst_desc(int idx) const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:3391
memory::desc scratchpad_desc() const
Returns the scratchpad memory descriptor.
Definition: dnnl.hpp:3456
void reset_with_clone(const_dnnl_primitive_desc_t pd)
Resets the value of the handle to a clone of a C API primitive descriptor.
Definition: dnnl.hpp:3498
dnnl::primitive::kind get_kind() const
Returns the kind of the primitive descriptor.
Definition: dnnl.hpp:3486
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:3435
memory::desc diff_src_desc(int idx) const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:3382
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:3423
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3550
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3518
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:3429
memory::desc weights_desc(int idx) const
Returns a weights memory descriptor.
Definition: dnnl.hpp:3373
memory::dim query_s64(query what) const
Returns a memory::dim value (same as int64_t).
Definition: dnnl.hpp:3314
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:3447
engine scratchpad_engine() const
Returns the engine on which the scratchpad memory is located.
Definition: dnnl.hpp:3462
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3417
const char * impl_info_str() const
Returns implementation name.
Definition: dnnl.hpp:3302
memory::desc src_desc(int idx) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3355
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:3411
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3533
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:3441
A base class for descriptors of all primitives that have an operation descriptor and that support ite...
Definition: dnnl.hpp:3944
primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr, const engine &aengine, const_dnnl_primitive_desc_t hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor.
Definition: dnnl.hpp:3971
bool next_impl()
Advances the primitive iterator to the next implementation.
Definition: dnnl.hpp:3989
Base class for all computational primitives.
Definition: dnnl.hpp:269
void execute(const stream &astream, const std::unordered_map< int, memory > &args) const
Executes computations specified by the primitive in a specified stream.
primitive()=default
Default constructor. Constructs an empty object.
primitive(const primitive_desc &pd)
Constructs a primitive from a primitive descriptor.
kind
Kinds of primitives supported by the library.
Definition: dnnl.hpp:271
@ deconvolution
A deconvolution primitive.
@ pooling_v2
A pooling version 2 primitive.
@ inner_product
An inner product primitive.
@ logsoftmax
A logsoftmax primitive.
@ layer_normalization
A layer normalization primitive.
@ pooling
A pooling primitive.
@ resampling
A resampling primitive.
@ shuffle
A shuffle primitive.
@ rnn
An RNN primitive.
@ batch_normalization
A batch normalization primitive.
@ lrn
An LRN primitive.
@ prelu
A PReLU primitive.
@ eltwise
An element-wise primitive.
@ convolution
A convolution primitive.
@ softmax
A softmax primitive.
@ undef
Undefined primitive.
primitive(const_dnnl_primitive_desc_t c_pd)
Constructs a primitive from a C API primitive descriptor.
Descriptor for a reduction primitive.
Definition: dnnl.hpp:10754
desc()=default
Default constructor. Produces an empty object.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, float p, float eps)
Constructs a descriptor for a reduction primitive using algorithm specific parameters,...
Definition: dnnl.hpp:10777
Primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10787
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10826
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10829
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a reduction primitive from a C API primitive descriptor that mu...
Definition: dnnl.hpp:10822
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10813
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10799
Reduction.
Definition: dnnl.hpp:10752
reduction(const primitive_desc &pd)
Constructs a reduction primitive.
Definition: dnnl.hpp:10837
reduction()=default
Default constructor. Produces an empty object.
Primitive descriptor for a reorder primitive.
Definition: dnnl.hpp:3614
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:3699
primitive_desc(const engine &src_engine, const memory::desc &src_md, const engine &dst_engine, const memory::desc &dst_md, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for reorder primitive.
Definition: dnnl.hpp:3637
primitive_desc(const memory &src, const memory &dst, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for reorder primitive.
Definition: dnnl.hpp:3663
engine get_src_engine() const
Returns the engine on which the source memory is allocated.
Definition: dnnl.hpp:3688
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for reorder primitive from a C API primitive descriptor which must ...
Definition: dnnl.hpp:3683
engine get_dst_engine() const
Returns the engine on which the destination memory is allocated.
Definition: dnnl.hpp:3694
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3702
Reorder primitive.
Definition: dnnl.hpp:3612
reorder(const primitive_desc &pd)
Constructs a reorder primitive.
Definition: dnnl.hpp:3710
void execute(const stream &astream, memory &src, memory &dst) const
Executes the reorder primitive.
Definition: dnnl.hpp:3731
reorder()=default
Default constructor. Produces an empty object.
reorder(const memory &src, const memory &dst, const primitive_attr &attr=primitive_attr())
Constructs a reorder primitive that would reorder data between memory objects having the same memory ...
Definition: dnnl.hpp:3719
Descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10175
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for a resampling backward propagation primitive using source and destination ...
Definition: dnnl.hpp:10186
desc(algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for resampling backward propagation primitive.
Definition: dnnl.hpp:10203
Primitive descriptor for resampling backward propagation primitive.
Definition: dnnl.hpp:10216
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a resampling backward propagation primitive from a C API primit...
Definition: dnnl.hpp:10266
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10271
primitive_desc(const desc &adesc, const engine &aengine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10233
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10274
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10253
Resampling backward propagation primitive.
Definition: dnnl.hpp:10173
resampling_backward(const primitive_desc &pd)
Constructs a resampling backward propagation primitive.
Definition: dnnl.hpp:10283
resampling_backward()=default
Default constructor. Produces an empty object.
Descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10031
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc)
Constructs a descriptor for a resampling forward propagation primitive using source and destination m...
Definition: dnnl.hpp:10049
desc(prop_kind aprop_kind, algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &src_desc)
Constructs a descriptor for a resampling forward propagation primitive using source memory descriptor...
Definition: dnnl.hpp:10069
desc(prop_kind aprop_kind, algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &src_desc, const memory::desc &dst_desc)
Constructs a descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10096
Primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10110
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10124
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10160
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10157
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a resampling forward propagation primitive from a C API primiti...
Definition: dnnl.hpp:10151
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10140
primitive_desc()=default
Default constructor. Produces an empty object.
Resampling forward propagation.
Definition: dnnl.hpp:10029
resampling_forward()=default
Default constructor. Produces an empty object.
resampling_forward(const primitive_desc &pd)
Constructs a resampling forward propagation primitive.
Definition: dnnl.hpp:10169
Base class for primitive descriptors for RNN primitives.
Definition: dnnl.hpp:7499
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:7584
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:7550
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:7610
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:7538
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:7544
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:7598
memory::desc diff_dst_iter_c_desc() const
Returns diff destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:7658
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:7616
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:7652
rnn_primitive_desc_base()=default
Default constructor. Produces an empty object.
memory::desc diff_src_iter_c_desc() const
Returns diff source recurrent cell state memory descriptor.
Definition: dnnl.hpp:7604
rnn_primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
Constructs an RNN primitive descriptor base from a C API primitive descriptor while checking that it ...
Definition: dnnl.hpp:7512
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:7638
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:7570
memory::desc diff_weights_projection_desc() const
Returns diff weights projection memory descriptor.
Definition: dnnl.hpp:7629
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:7532
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:7526
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:7564
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:7556
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:7518
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:7644
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:7578
memory::desc diff_weights_peephole_desc() const
Returns diff weights peephole memory descriptor.
Definition: dnnl.hpp:7622
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:7590
Descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9736
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Constructs a descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9746
Primitive descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9755
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a shuffle backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:9786
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:9791
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const shuffle_forward::primitive_desc &hint_fwd_pd, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9773
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:9794
Shuffle backward propagation primitive.
Definition: dnnl.hpp:9733
shuffle_backward()=default
Default constructor. Produces an empty object.
shuffle_backward(const primitive_desc &pd)
Constructs a shuffle backward propagation primitive.
Definition: dnnl.hpp:9803
Descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9661
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Constructs a descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9673
Primitive descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9684
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:9720
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:9717
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a shuffle forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:9711
primitive_desc(const desc &adesc, const engine &aengine, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9699
primitive_desc()=default
Default constructor. Produces an empty object.
Shuffle forward propagation primitive.
Definition: dnnl.hpp:9659
shuffle_forward()=default
Default constructor. Produces an empty object.
shuffle_forward(const primitive_desc &pd)
Constructs a shuffle forward propagation primitive.
Definition: dnnl.hpp:9729
Descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6211
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, int softmax_axis)
Constructs a descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6224
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6235
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a softmax backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:6285
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const softmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6272
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6296
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6293
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const softmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6252
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6290
Softmax backward propagation primitive.
Definition: dnnl.hpp:6209
softmax_backward()=default
Default constructor. Produces an empty object.
softmax_backward(const primitive_desc &pd)
Constructs a softmax backward propagation primitive.
Definition: dnnl.hpp:6305
Descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6121
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Constructs a descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6135
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6146
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6193
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6160
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6196
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a softmax forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:6187
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6176
primitive_desc()=default
Default constructor. Produces an empty object.
Softmax forward propagation primitive.
Definition: dnnl.hpp:6119
softmax_forward()=default
Default constructor. Produces an empty object.
softmax_forward(const primitive_desc &pd)
Constructs a softmax forward propagation primitive.
Definition: dnnl.hpp:6205
An execution stream.
Definition: dnnl.hpp:1001
engine get_engine() const
Returns the associated engine.
Definition: dnnl.hpp:1032
stream & wait()
Waits for all primitives executing in the stream to finish.
Definition: dnnl.hpp:1041
stream(const engine &aengine, flags aflags=flags::default_flags)
Constructs a stream for the specified engine and with behavior controlled by the specified flags.
Definition: dnnl.hpp:1023
flags
Stream flags. Can be combined using the bitwise OR operator.
Definition: dnnl.hpp:1005
@ out_of_order
Out-of-order execution.
@ default_flags
Default stream configuration.
@ in_order
In-order execution.
stream()=default
Constructs an empty stream.
Primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3853
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3926
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3923
primitive_desc(const memory::desc &dst, const std::vector< float > &scales, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3867
primitive_desc(const std::vector< float > &scales, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3897
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for sum primitive from a C API primitive descriptor which must have...
Definition: dnnl.hpp:3919
Out-of-place summation (sum) primitive.
Definition: dnnl.hpp:3851
sum()=default
Default constructor. Produces an empty object.
sum(const primitive_desc &pd)
Constructs a sum primitive.
Definition: dnnl.hpp:3934
Descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7869
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Constructs a descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7924
Primitive descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7960
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7997
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8020
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:8074
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8036
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:8054
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:8064
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7977
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:8069
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8028
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive from a C API primi...
Definition: dnnl.hpp:8010
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8023
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8033
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8041
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:8079
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:8049
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8015
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:8059
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8044
Vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7867
vanilla_rnn_backward(const primitive_desc &pd)
Constructs a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:8090
vanilla_rnn_backward()=default
Default constructor. Produces an empty object.
Descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7708
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Constructs a descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7751
Primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7776
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7790
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive from a C API primit...
Definition: dnnl.hpp:7817
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:7823
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:7828
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:7836
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:7831
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:7852
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:7849
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7806
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:7844
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:7841
Vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7706
vanilla_rnn_forward()=default
Default constructor. Produces an empty object.
vanilla_rnn_forward(const primitive_desc &pd)
Constructs a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7863
A descriptor of a Batch Normalization operation.
Definition: dnnl_types.h:1896
A descriptor of a binary operation.
Definition: dnnl_types.h:2104
A descriptor of a convolution operation.
Definition: dnnl_types.h:1600
A descriptor of an element-wise operation.
Definition: dnnl_types.h:1675
An opaque structure to describe an engine.
A descriptor of an inner product operation.
Definition: dnnl_types.h:1966
A descriptor of a Layer Normalization operation.
Definition: dnnl_types.h:1929
A descriptor of a Local Response Normalization (LRN) operation.
Definition: dnnl_types.h:1865
A descriptor of a matrix multiplication operation.
Definition: dnnl_types.h:2130
Memory descriptor.
Definition: dnnl_types.h:1511
dnnl_data_type_t data_type
Data type of the tensor elements.
Definition: dnnl_types.h:1531
dnnl_dims_t dims
Dimensions in the following order:
Definition: dnnl_types.h:1528
int ndims
Number of dimensions.
Definition: dnnl_types.h:1513
An opaque structure to describe a memory.
A descriptor of a pooling operation.
Definition: dnnl_types.h:1765
A descriptor of a pooling version 2 operation.
Definition: dnnl_types.h:1803
An opaque structure for a chain of post operations.
An opaque structure for primitive descriptor attributes.
An opaque structure to describe a primitive descriptor iterator.
An opaque structure to describe a primitive descriptor.
An opaque structure to describe a primitive.
A descriptor of a reduction operation.
Definition: dnnl_types.h:2180
A descriptor of a resampling operation.
Definition: dnnl_types.h:2152
A descriptor for an RNN operation.
Definition: dnnl_types.h:2022
A descriptor of a shuffle operation.
Definition: dnnl_types.h:1653
A descriptor of a Softmax operation.
Definition: dnnl_types.h:1735
An opaque structure to describe an execution stream.
Structure containing version information as per Semantic Versioning.
Definition: dnnl_types.h:2703