20 #ifndef ONEAPI_DNNL_DNNL_HPP 21 #define ONEAPI_DNNL_DNNL_HPP 23 #include "oneapi/dnnl/dnnl_config.h" 32 #include <unordered_map> 42 #ifndef DNNL_ENABLE_EXCEPTIONS 43 #if __cpp_exceptions || __EXCEPTIONS \ 44 || (defined(_MSC_VER) && !defined(__clang__)) 45 #define DNNL_ENABLE_EXCEPTIONS 1 47 #define DNNL_ENABLE_EXCEPTIONS 0 51 #if defined(__GNUC__) || defined(__clang__) 52 #define DNNL_TRAP() __builtin_trap() 53 #elif defined(__INTEL_COMPILER) || defined(_MSC_VER) 54 #define DNNL_TRAP() __debugbreak() 56 #error "unknown compiler" 59 #if DNNL_ENABLE_EXCEPTIONS 60 #define DNNL_THROW_ERROR(status, msg) throw error(status, msg) 63 #define DNNL_THROW_ERROR(status, msg) \ 84 struct error :
public std::exception {
96 const char *
what() const noexcept
override {
return message; }
109 template <
typename T>
110 void validate_container_size(
const T &v,
const char *error_message,
111 int min_size = 1,
int max_size = -1) {
112 const int size = (int)v.size();
113 if (size < min_size || (max_size >= 0 && size > max_size))
119 template <
typename T>
135 template <
typename T,
typename traits = handle_traits<T>>
139 std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0};
142 bool operator==(
const T other)
const {
return other == data_.get(); }
143 bool operator!=(
const T other)
const {
return !(*
this == other); }
176 void reset(T t,
bool weak =
false) {
177 data_.reset(t, weak ? &dummy_destructor : traits::destructor);
185 T
get(
bool allow_empty =
false)
const {
186 T result = data_.get();
187 if (allow_empty ==
false && result ==
nullptr)
197 explicit operator T()
const {
return get(
true); }
202 explicit operator bool()
const {
return get(
true) !=
nullptr; }
211 return other.data_.get() == data_.get();
257 struct primitive_desc;
357 const std::unordered_map<int, memory> &args)
const;
371 "could not get a primitive descriptor from a primitive");
382 "could not get a primitive kind from a primitive descriptor");
472 undef = dnnl_alg_kind_undef,
664 #define DNNL_DEFINE_BITMASK_OPS(enum_name) \ 665 inline enum_name operator|(enum_name lhs, enum_name rhs) { \ 666 return static_cast<enum_name>( \ 667 static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \ 670 inline enum_name operator&(enum_name lhs, enum_name rhs) { \ 671 return static_cast<enum_name>( \ 672 static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \ 675 inline enum_name operator^(enum_name lhs, enum_name rhs) { \ 676 return static_cast<enum_name>( \ 677 static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \ 680 inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \ 681 lhs = static_cast<enum_name>( \ 682 static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \ 686 inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \ 687 lhs = static_cast<enum_name>( \ 688 static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \ 692 inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \ 693 lhs = static_cast<enum_name>( \ 694 static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \ 698 inline enum_name operator~(enum_name rhs) { \ 699 return static_cast<enum_name>(~static_cast<unsigned>(rhs)); \ 900 "could not create an engine");
913 "could not get an engine from a primitive_desc");
914 reset(c_engine,
true);
922 "could not get kind of an engine");
931 template <
typename primitive_desc>
941 template <
typename primitive_desc>
946 "could not get an engine from a primitive_desc");
947 return engine(c_engine,
true);
1005 "could not create a stream");
1013 "could not get an engine from a stream object");
1014 return engine(c_engine,
true);
1117 template <
typename T>
1119 validate_container_size(
1420 AB16b16a = dnnl_AB16b16a,
1421 AB16b32a = dnnl_AB16b32a,
1422 AB16b64a = dnnl_AB16b64a,
1423 AB8b16a2b = dnnl_AB8b16a2b,
1424 AB8b32a2b = dnnl_AB8b32a2b,
1425 AB8b64a2b = dnnl_AB8b64a2b,
1426 AB4b16a4b = dnnl_AB4b16a4b,
1427 AB4b32a4b = dnnl_AB4b32a4b,
1428 AB4b64a4b = dnnl_AB4b64a4b,
1429 Abc16a = dnnl_Abc16a,
1430 ABc16a16b = dnnl_ABc16a16b,
1431 ABc4a4b = dnnl_ABc4a4b,
1434 ABc16b16a = dnnl_ABc16b16a,
1435 ABc16b32a = dnnl_ABc16b32a,
1436 ABc16b64a = dnnl_ABc16b64a,
1439 ABc4b16a4b = dnnl_ABc4b16a4b,
1440 ABc4b32a4b = dnnl_ABc4b32a4b,
1441 ABc4b64a4b = dnnl_ABc4b64a4b,
1442 ABc2b8a4b = dnnl_ABc2b8a4b,
1443 ABc16b16a4b = dnnl_ABc16b16a4b,
1444 ABc16b16a2b = dnnl_ABc16b16a2b,
1445 ABc4b4a = dnnl_ABc4b4a,
1446 ABc8a16b2a = dnnl_ABc8a16b2a,
1447 ABc8a8b = dnnl_ABc8a8b,
1448 ABc8a4b = dnnl_ABc8a4b,
1450 ABc8b16a2b = dnnl_ABc8b16a2b,
1451 ABc8b32a2b = dnnl_ABc8b32a2b,
1452 ABc8b64a2b = dnnl_ABc8b64a2b,
1453 ABc8b8a = dnnl_ABc8b8a,
1454 Abcd8a = dnnl_Abcd8a,
1455 Abcd16a = dnnl_Abcd16a,
1456 Abcd32a = dnnl_Abcd32a,
1457 ABcd16a16b = dnnl_ABcd16a16b,
1460 ABcd16b16a = dnnl_ABcd16b16a,
1461 ABcd16b32a = dnnl_ABcd16b32a,
1462 ABcd16b64a = dnnl_ABcd16b64a,
1463 aBCd16b16c = dnnl_aBCd16b16c,
1464 aBCd16c16b = dnnl_aBCd16c16b,
1465 Abcd4a = dnnl_Abcd4a,
1467 ABcd4b16a4b = dnnl_ABcd4b16a4b,
1468 ABcd4b32a4b = dnnl_ABcd4b32a4b,
1469 ABcd4b64a4b = dnnl_ABcd4b64a4b,
1470 ABcd2b8a4b = dnnl_ABcd2b8a4b,
1471 ABcd4b4a = dnnl_ABcd4b4a,
1472 ABcd4a4b = dnnl_ABcd4a4b,
1473 aBCd4c16b4c = dnnl_aBCd4c16b4c,
1474 aBCd2c8b4c = dnnl_aBCd2c8b4c,
1475 ABcd16b16a4b = dnnl_ABcd16b16a4b,
1476 ABcd16b16a2b = dnnl_ABcd16b16a2b,
1477 aBCd16c16b4c = dnnl_aBCd16c16b4c,
1478 aBCd16c16b2c = dnnl_aBCd16c16b2c,
1479 aBCd4c4b = dnnl_aBCd4c4b,
1480 aBCd4b4c = dnnl_aBCd4b4c,
1481 ABcd8a16b2a = dnnl_ABcd8a16b2a,
1482 ABcd8a8b = dnnl_ABcd8a8b,
1483 ABcd8a4b = dnnl_ABcd8a4b,
1486 ABcd8b16a2b = dnnl_ABcd8b16a2b,
1487 ABcd8b32a2b = dnnl_ABcd8b32a2b,
1488 ABcd8b64a2b = dnnl_ABcd8b64a2b,
1489 aBCd8b16c2b = dnnl_aBCd8b16c2b,
1492 aBCd8b8c = dnnl_aBCd8b8c,
1493 aBCd8b4c = dnnl_aBCd8b4c,
1494 aBCd8c16b2c = dnnl_aBCd8c16b2c,
1495 aBCd8c8b = dnnl_aBCd8c8b,
1496 Abcde16a = dnnl_Abcde16a,
1497 Abcde32a = dnnl_Abcde32a,
1498 ABcde16a16b = dnnl_ABcde16a16b,
1501 ABcde16b16a = dnnl_ABcde16b16a,
1502 ABcde16b32a = dnnl_ABcde16b32a,
1503 ABcde16b64a = dnnl_ABcde16b64a,
1504 aBCde16b16c = dnnl_aBCde16b16c,
1505 aBCde16c16b = dnnl_aBCde16c16b,
1506 aBCde2c8b4c = dnnl_aBCde2c8b4c,
1507 Abcde4a = dnnl_Abcde4a,
1509 ABcde4b4a = dnnl_ABcde4b4a,
1510 ABcde4a4b = dnnl_ABcde4a4b,
1511 aBCde4b4c = dnnl_aBCde4b4c,
1512 aBCde4c16b4c = dnnl_aBCde4c16b4c,
1513 aBCde16c16b4c = dnnl_aBCde16c16b4c,
1514 aBCde16c16b2c = dnnl_aBCde16c16b2c,
1515 aBCde4c4b = dnnl_aBCde4c4b,
1516 Abcde8a = dnnl_Abcde8a,
1517 ABcde8a8b = dnnl_ABcde8a8b,
1518 ABcde8a4b = dnnl_ABcde8a4b,
1520 ABcde8b16a2b = dnnl_ABcde8b16a2b,
1521 ABcde8b32a2b = dnnl_ABcde8b32a2b,
1522 ABcde8b64a2b = dnnl_ABcde8b64a2b,
1524 ABcde4b32a4b = dnnl_ABcde4b32a4b,
1525 ABcde4b64a4b = dnnl_ABcde4b64a4b,
1527 aBCde8b16c2b = dnnl_aBCde8b16c2b,
1528 ABcde8b8a = dnnl_ABcde8b8a,
1529 aBCde8b8c = dnnl_aBCde8b8c,
1530 aBCde8b4c = dnnl_aBCde8b4c,
1531 ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
1532 ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
1533 aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
1534 aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
1535 aBCde8c16b2c = dnnl_aBCde8c16b2c,
1536 aBCde8c8b = dnnl_aBCde8c8b,
1538 aBCdef16b16c = dnnl_aBCdef16b16c,
1539 aBCdef16c16b = dnnl_aBCdef16c16b,
1542 aBCdef4c4b = dnnl_aBCdef4c4b,
1543 aBCdef4b4c = dnnl_aBCdef4b4c,
1544 aBCdef8b8c = dnnl_aBCdef8b8c,
1545 aBCdef8b4c = dnnl_aBCdef8b4c,
1546 aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
1547 aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
1548 aBCdef8c8b = dnnl_aBCdef8c8b,
1549 aBdc16b = dnnl_aBdc16b,
1550 aBdc4b = dnnl_aBdc4b,
1551 aBdc8b = dnnl_aBdc8b,
1552 aBdec16b = dnnl_aBdec16b,
1553 aBdec4b = dnnl_aBdec4b,
1554 aBdec8b = dnnl_aBdec8b,
1555 aBdefc16b = dnnl_aBdefc16b,
1556 aCBdef16c16b = dnnl_aCBdef16c16b,
1557 aCBdef16b16c = dnnl_aCBdef16b16c,
1558 aBdefc4b = dnnl_aBdefc4b,
1559 aBdefc8b = dnnl_aBdefc8b,
1560 Acb16a = dnnl_Acb16a,
1563 aCBd16b16c = dnnl_aCBd16b16c,
1564 aCBd16c16b = dnnl_aCBd16c16b,
1565 aCBde16b16c = dnnl_aCBde16b16c,
1566 aCBde16c16b = dnnl_aCBde16c16b,
1567 Acdb16a = dnnl_Acdb16a,
1568 Acdb4a = dnnl_Acdb4a,
1569 Acdb8a = dnnl_Acdb8a,
1570 Acdeb16a = dnnl_Acdeb16a,
1571 Acdeb4a = dnnl_Acdeb4a,
1572 Acdeb8a = dnnl_Acdeb8a,
1573 BAc16a16b = dnnl_BAc16a16b,
1574 BAc16b16a = dnnl_BAc16b16a,
1575 BAcd16a16b = dnnl_BAcd16a16b,
1576 BAcd16b16a = dnnl_BAcd16b16a,
1577 ABcd32a32b = dnnl_ABcd32a32b,
1578 BAcde16b16a = dnnl_BAcde16b16a,
1579 BAcde16a16b = dnnl_BAcde16a16b,
1580 aBdec32b = dnnl_aBdec32b,
1581 Abcdef16a = dnnl_Abcdef16a,
1582 Abcdef32a = dnnl_Abcdef32a,
1583 Acdb32a = dnnl_Acdb32a,
1587 aBCd2c4b2c = dnnl_aBCd2c4b2c,
1588 aBCde2c4b2c = dnnl_aBCde2c4b2c,
1589 aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
1590 aBCd4b8c2b = dnnl_aBCd4b8c2b,
1591 aBCde4b8c2b = dnnl_aBCde4b8c2b,
1592 aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
1593 aBCd4c8b2c = dnnl_aBCd4c8b2c,
1594 aBCde4c8b2c = dnnl_aBCde4c8b2c,
1595 aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
1596 AB32a32b8a4b = dnnl_AB32a32b8a4b,
1597 AB32a32b8a2b = dnnl_AB32a32b8a2b,
1598 AB8a4b = dnnl_AB8a4b,
1599 AB8a2b = dnnl_AB8a2b,
1600 abDc32d = dnnl_abDc32d,
1601 abDC32d4c = dnnl_abDC32d4c,
1602 abdEc32e = dnnl_abdEc32e,
1603 abdEC32e2c = dnnl_abdEC32e2c,
1604 abdEC32e4c = dnnl_abdEC32e4c,
1617 NCw16n16c = dnnl_NCw16n16c,
1618 NChw16n16c = dnnl_NChw16n16c,
1619 NCdhw16n16c = dnnl_NCdhw16n16c,
1620 NCdhw32n32c = dnnl_NCdhw32n32c,
1621 NChw32n32c = dnnl_NChw32n32c,
1622 IOhw16i16o = dnnl_IOhw16i16o,
1623 OI16i16o = dnnl_OI16i16o,
1624 OI16i32o = dnnl_OI16i32o,
1625 OI16i64o = dnnl_OI16i64o,
1626 OI8i16o2i = dnnl_OI8i16o2i,
1627 OI8i32o2i = dnnl_OI8i32o2i,
1628 OI8i64o2i = dnnl_OI8i64o2i,
1629 OI4i16o4i = dnnl_OI4i16o4i,
1630 OI4i32o4i = dnnl_OI4i32o4i,
1631 OI4i64o4i = dnnl_OI4i64o4i,
1632 Ohwi32o = dnnl_Ohwi32o,
1633 IOdhw16i16o = dnnl_IOdhw16i16o,
1634 gIOhw16i16o = dnnl_gIOhw16i16o,
1635 gOhwi32o = dnnl_gOhwi32o,
1636 Goidhw16g = dnnl_Goidhw16g,
1637 IOw16o16i = dnnl_IOw16o16i,
1638 OIw16i16o = dnnl_OIw16i16o,
1639 OIw16i32o = dnnl_OIw16i32o,
1640 OIw16i64o = dnnl_OIw16i64o,
1641 IOw16i16o = dnnl_IOw16i16o,
1642 gIOw16i16o = dnnl_gIOw16i16o,
1643 OIw16o16i = dnnl_OIw16o16i,
1644 Oiw16o = dnnl_Oiw16o,
1645 OIw4i16o4i = dnnl_OIw4i16o4i,
1646 OIw4i32o4i = dnnl_OIw4i32o4i,
1647 OIw4i64o4i = dnnl_OIw4i64o4i,
1648 OIw2i8o4i = dnnl_OIw2i8o4i,
1649 OIw4i4o = dnnl_OIw4i4o,
1650 OIw4o4i = dnnl_OIw4o4i,
1652 OIw8i16o2i = dnnl_OIw8i16o2i,
1653 OIw8i32o2i = dnnl_OIw8i32o2i,
1654 OIw8i64o2i = dnnl_OIw8i64o2i,
1655 OIw8i8o = dnnl_OIw8i8o,
1656 OIw8o16i2o = dnnl_OIw8o16i2o,
1657 OIw8o8i = dnnl_OIw8o8i,
1658 OIw8o4i = dnnl_OIw8o4i,
1659 Owi16o = dnnl_Owi16o,
1660 OwI16o2i = dnnl_OwI16o2i,
1663 IOhw16o16i = dnnl_IOhw16o16i,
1664 Ohwi16o = dnnl_Ohwi16o,
1665 OhwI16o2i = dnnl_OhwI16o2i,
1666 Ohwi4o = dnnl_Ohwi4o,
1667 Ohwi8o = dnnl_Ohwi8o,
1668 OIhw16i16o = dnnl_OIhw16i16o,
1669 OIhw16i32o = dnnl_OIhw16i32o,
1670 OIhw16i64o = dnnl_OIhw16i64o,
1671 OIhw16o16i = dnnl_OIhw16o16i,
1672 Oihw16o = dnnl_Oihw16o,
1673 OIhw4i16o4i = dnnl_OIhw4i16o4i,
1674 OIhw4i32o4i = dnnl_OIhw4i32o4i,
1675 OIhw4i64o4i = dnnl_OIhw4i64o4i,
1676 OIhw4i4o = dnnl_OIhw4i4o,
1677 OIhw4o4i = dnnl_OIhw4o4i,
1678 Oihw4o = dnnl_Oihw4o,
1679 OIhw8i16o2i = dnnl_OIhw8i16o2i,
1680 OIhw8i32o2i = dnnl_OIhw8i32o2i,
1681 OIhw8i64o2i = dnnl_OIhw8i64o2i,
1682 OIhw8i8o = dnnl_OIhw8i8o,
1683 OIhw8o16i2o = dnnl_OIhw8o16i2o,
1684 OIhw8o8i = dnnl_OIhw8o8i,
1685 OIhw8o4i = dnnl_OIhw8o4i,
1686 OIhw2i8o4i = dnnl_OIhw2i8o4i,
1687 IOdhw16o16i = dnnl_IOdhw16o16i,
1688 Odhwi16o = dnnl_Odhwi16o,
1689 OdhwI16o2i = dnnl_OdhwI16o2i,
1690 Odhwi4o = dnnl_Odhwi4o,
1691 Odhwi8o = dnnl_Odhwi8o,
1692 OIdhw16i16o = dnnl_OIdhw16i16o,
1693 OIdhw16i32o = dnnl_OIdhw16i32o,
1694 OIdhw16i64o = dnnl_OIdhw16i64o,
1695 OIdhw16o16i = dnnl_OIdhw16o16i,
1696 Oidhw16o = dnnl_Oidhw16o,
1697 OIdhw4i4o = dnnl_OIdhw4i4o,
1698 OIdhw4o4i = dnnl_OIdhw4o4i,
1699 Oidhw4o = dnnl_Oidhw4o,
1700 OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
1701 OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
1702 OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
1703 OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
1704 OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
1705 OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
1706 OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
1707 OIdhw8i8o = dnnl_OIdhw8i8o,
1708 OIdhw8o8i = dnnl_OIdhw8o8i,
1709 OIdhw8o4i = dnnl_OIdhw8o4i,
1710 gIOw16o16i = dnnl_gIOw16o16i,
1711 gOIw16i16o = dnnl_gOIw16i16o,
1712 gOIw16o16i = dnnl_gOIw16o16i,
1713 gOiw16o = dnnl_gOiw16o,
1714 gOIw4i16o4i = dnnl_gOIw4i16o4i,
1715 gOIw2i8o4i = dnnl_gOIw2i8o4i,
1716 gOIw4i4o = dnnl_gOIw4i4o,
1717 gOIw4o4i = dnnl_gOIw4o4i,
1718 gOiw4o = dnnl_gOiw4o,
1719 gOIw8i16o2i = dnnl_gOIw8i16o2i,
1720 gOIw8i8o = dnnl_gOIw8i8o,
1721 gOIw8o16i2o = dnnl_gOIw8o16i2o,
1722 gOIw8o8i = dnnl_gOIw8o8i,
1723 gOIw8o4i = dnnl_gOIw8o4i,
1724 gOwi16o = dnnl_gOwi16o,
1725 gOwI16o2i = dnnl_gOwI16o2i,
1726 gOwi4o = dnnl_gOwi4o,
1727 gOwi8o = dnnl_gOwi8o,
1728 Goiw8g = dnnl_Goiw8g,
1729 Goiw16g = dnnl_Goiw16g,
1730 gIOhw16o16i = dnnl_gIOhw16o16i,
1731 gOhwi16o = dnnl_gOhwi16o,
1732 gOhwI16o2i = dnnl_gOhwI16o2i,
1733 gOhwi4o = dnnl_gOhwi4o,
1734 gOhwi8o = dnnl_gOhwi8o,
1735 Goihw16g = dnnl_Goihw16g,
1736 gOIhw16i16o = dnnl_gOIhw16i16o,
1737 gOIhw16o16i = dnnl_gOIhw16o16i,
1738 gOihw16o = dnnl_gOihw16o,
1739 gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
1740 gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
1741 gOIhw4i4o = dnnl_gOIhw4i4o,
1742 gOIhw4o4i = dnnl_gOIhw4o4i,
1743 gOihw4o = dnnl_gOihw4o,
1744 Goihw8g = dnnl_Goihw8g,
1745 gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
1746 gOIhw8i8o = dnnl_gOIhw8i8o,
1747 gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
1748 OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
1749 OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
1750 OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
1751 OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
1752 gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
1753 gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
1754 gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
1755 gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
1756 OIhw16i16o4i = dnnl_OIhw16i16o4i,
1757 OIhw16i16o2i = dnnl_OIhw16i16o2i,
1758 gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
1759 gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
1760 gOIhw8o8i = dnnl_gOIhw8o8i,
1761 gOIhw8o4i = dnnl_gOIhw8o4i,
1762 gIOdhw16i16o = dnnl_gIOdhw16i16o,
1763 gIOdhw16o16i = dnnl_gIOdhw16o16i,
1764 gOdhwi16o = dnnl_gOdhwi16o,
1765 gOdhwI16o2i = dnnl_gOdhwI16o2i,
1766 gOdhwi4o = dnnl_gOdhwi4o,
1767 gOdhwi8o = dnnl_gOdhwi8o,
1768 gOIdhw16i16o = dnnl_gOIdhw16i16o,
1769 gOIdhw16o16i = dnnl_gOIdhw16o16i,
1770 gOidhw16o = dnnl_gOidhw16o,
1771 gOIdhw4i4o = dnnl_gOIdhw4i4o,
1772 gOIdhw4o4i = dnnl_gOIdhw4o4i,
1773 gOidhw4o = dnnl_gOidhw4o,
1774 gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
1775 gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
1776 gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
1777 gOIdhw8i8o = dnnl_gOIdhw8i8o,
1778 gOIdhw8o8i = dnnl_gOIdhw8o8i,
1779 gOIdhw8o4i = dnnl_gOIdhw8o4i,
1780 gOIw2i4o2i = dnnl_gOIw2i4o2i,
1781 gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
1782 gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
1783 gOIw2o4i2o = dnnl_gOIw2o4i2o,
1784 gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
1785 gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
1786 gOIw4i8o2i = dnnl_gOIw4i8o2i,
1787 gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
1788 gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
1789 gOIw4o8i2o = dnnl_gOIw4o8i2o,
1790 gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
1791 gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
1793 ldOI32o4i = abDC32d4c,
1794 ldgOi32o = abdEc32e,
1795 ldgOI32o2i = abdEC32e2c,
1796 ldgOI32o4i = abdEC32e4c,
1825 bool allow_empty =
false)
1827 validate_dims(adims);
1829 (
int)adims.size(), adims.data(),
convert_to_c(adata_type),
1833 "could not construct a memory descriptor using a " 1853 bool allow_empty =
false)
1855 validate_dims(adims);
1856 if (!strides.empty()) validate_dims(strides, (
int)adims.size());
1858 (
int)adims.size(), adims.data(),
convert_to_c(adata_type),
1859 strides.empty() ? nullptr : &strides[0]);
1862 "could not construct a memory descriptor using " 1883 bool allow_empty =
false)
const {
1884 validate_dims(adims, data.
ndims);
1885 validate_dims(offsets, data.
ndims);
1888 &sub_md, &data, adims.data(), offsets.data());
1891 return desc(sub_md);
1939 if (data.
ndims) validate_dims(adims, 1);
1942 &out_md, &data, (
int)adims.size(), adims.data());
1945 status,
"could not reshape a memory descriptor");
1946 return desc(out_md);
1987 bool allow_empty =
false)
const {
1988 validate_dims(permutation, data.
ndims);
1991 &out_md, &data, permutation.data());
1994 "could not permute axes of a memory descriptor");
1995 return desc(out_md);
2040 explicit operator bool()
const {
return data.
ndims != 0; }
2072 "could not create a memory object");
2089 "could not get a memory descriptor from a memory object");
2090 return desc(*cdesc);
2097 "could not get an engine from a memory object");
2098 return engine(c_engine,
true);
2108 "could not get a native handle from a memory object");
2143 "could not set native handle of a memory object");
2159 "could not set native handle of a memory object");
2183 template <
typename T =
void>
2187 "could not map memory object data");
2188 return static_cast<T *
>(mapped_ptr);
2203 "could not unmap memory object data");
2285 "post-ops index is out of range");
2322 "could not append a sum post-op");
2325 memory::convert_to_c(data_type)),
2326 "could not append a sum post-op");
2335 "could not get parameters of a sum post-op");
2347 get(), index, &scale, &c_data_type),
2348 "could not get parameters of a sum post-op");
2366 float scale,
algorithm aalgorithm,
float alpha,
float beta) {
2369 "could not append an elementwise post-op");
2380 float &alpha,
float &beta)
const {
2383 get(), index, &scale, &c_alg, &alpha, &beta),
2384 "could not get parameters of an elementwise post-op");
2418 int mask,
const std::vector<float> &scales) {
2421 memory::convert_to_c(weights_data_type),
2422 memory::convert_to_c(bias_data_type),
2423 memory::convert_to_c(dst_data_type),
2424 scales.size(), mask, &scales[0]),
2425 "could not append depthwise post-op");
2444 int &mask, std::vector<float> &scales)
const {
2451 const float *c_scales;
2453 &c_weights_data_type, &c_bias_data_type,
2454 &c_dst_data_type, &count, &c_mask, &c_scales),
2455 "could not get parameters of depthwise post-op");
2460 scales.resize(count);
2464 scales[c] = c_scales[c];
2503 int mask,
const std::vector<float> &scales) {
2506 memory::convert_to_c(weights_data_type),
2507 memory::convert_to_c(bias_data_type),
2508 memory::convert_to_c(dst_data_type),
2509 scales.size(), mask, &scales[0]),
2510 "could not append depthwise post-op");
2529 int &mask, std::vector<float> &scales)
const {
2536 const float *c_scales;
2538 &c_weights_data_type, &c_bias_data_type,
2539 &c_dst_data_type, &count, &c_mask, &c_scales),
2540 "could not get parameters of depthwise post-op");
2545 scales.resize(count);
2549 scales[c] = c_scales[c];
2570 "could not append a binary post-op");
2584 "could not get parameters of a binary post-op");
2586 src1_desc.
data = *data;
2609 "could not create primitive attribute");
2626 "could not get scratchpad mode primitive attribute");
2636 "could not set scratchpad mode primitive attribute");
2651 const float *c_scales;
2653 get(), &count, &c_mask, &c_scales),
2654 "could not get output scales primitive attribute");
2655 scales.resize(count);
2659 scales[c] = c_scales[c];
2707 get(), (
dnnl_dim_t)scales.size(), mask, scales.data()),
2708 "could not set output scales primitive attribute");
2722 void get_scales(
int arg,
int &mask, std::vector<float> &scales)
const {
2725 const float *c_scales;
2727 get(), arg, &count, &c_mask, &c_scales),
2728 "could not get scales primitive attributes");
2729 scales.resize(count);
2733 scales[c] = c_scales[c];
2752 void set_scales(
int arg,
int mask,
const std::vector<float> &scales) {
2755 (
dnnl_dim_t)scales.size(), mask, scales.data()),
2756 "could not set scales primitive attribute");
2770 int arg,
int &mask, std::vector<int32_t> &zero_points)
const {
2773 const int32_t *c_zero_points;
2775 get(), arg, &count, &c_mask, &c_zero_points),
2776 "could not get zero points primitive attribute");
2777 zero_points.resize(count);
2781 zero_points[c] = c_zero_points[c];
2805 int arg,
int mask,
const std::vector<int32_t> &zero_points) {
2808 zero_points.data()),
2809 "could not set zero points primitive attribute");
2819 "could not get post-ops primitive attribute");
2820 result.
reset(const_cast<dnnl_post_ops_t>(c_result),
true);
2834 "could not set post-ops primitive attribute");
2873 "could not set RNN data quantization parameters primitive " 2887 float c_scale, c_shift;
2889 get(), &c_scale, &c_shift),
2890 "could not set RNN data quantization parameters primitive " 2924 (
int)scales.size(), mask, scales.data()),
2925 "could not set RNN weights quantization parameters primitive " 2951 const float *c_scales;
2953 get(), &count, &c_mask, &c_scales),
2954 "could not get primitive RNN weights quantization " 2955 "parameters attributes");
2956 scales.resize(count);
2960 scales[c] = c_scales[c];
2990 int mask,
const std::vector<float> &scales) {
2993 get(), (
int)scales.size(), mask, scales.data()),
2994 "could not set primitive RNN weights projection quantization " 2995 "parameters attributes");
3018 int &mask, std::vector<float> &scales) {
3021 const float *c_scales;
3024 get(), &count, &c_mask, &c_scales),
3025 "could not get primitive RNN weights projection quantization " 3026 "parameters attributes");
3027 scales.resize(count);
3031 scales[c] = c_scales[c];
3057 "could not retrieve implementation info string from a " 3058 "primitive descriptor");
3091 if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
3092 [=](
query q) {
return what == q; }))
3094 "memory descriptor query is invalid");
3218 "could not retrieve scratchpad engine from a primitive " 3220 return engine(c_engine,
true);
3228 "could not get attributes from a primitive descriptor");
3231 "could not clone primitive attributes");
3241 "could not get primitive kind from a primitive descriptor");
3252 "could not clone a primitive descriptor");
3305 if (pd ==
nullptr)
return;
3318 rc,
"could not get primitive kind from a primitive descriptor");
3319 if (pd_kind != c_prim_kind)
3321 "primitive descriptor operation kind mismatch");
3331 "could not get propagation kind from the primitive " 3337 && (pd_prop_kind == c_prop_kind1
3338 || pd_prop_kind == c_prop_kind2))) {
3345 "primitive descriptor propagation kind mismatch");
3391 bool allow_empty =
false) {
3395 dst_engine.
get(), attr.get());
3398 "could not create a primitive descriptor for a reorder " 3416 bool allow_empty =
false) {
3425 "could not create a primitive descriptor for a reorder " 3499 const std::vector<memory::desc> &mems) {
3500 std::vector<dnnl_memory_desc_t> c_mems;
3501 c_mems.reserve(mems.size());
3502 for (
const auto &s : mems)
3503 c_mems.push_back(s.data);
3528 const std::vector<memory::desc> &srcs,
const engine &aengine,
3535 (
int)c_srcs.size(), concat_dimension, c_srcs.data(),
3536 attr.get(), aengine.
get()),
3537 "could not create a primitive descriptor for a concat " 3555 const std::vector<memory::desc> &srcs,
const engine &aengine,
3562 (
int)c_api_srcs.size(), concat_dimension,
3563 c_api_srcs.data(), attr.get(), aengine.
get()),
3564 "could not create a primitive descriptor for a concat " 3619 const std::vector<float> &scales,
3620 const std::vector<memory::desc> &srcs,
const engine &aengine,
3622 validate_container_size(scales,
3623 "counts of scales and sources are not equal",
3624 (
int)srcs.size(), (int)srcs.size());
3631 (
int)c_api_srcs.size(), scales.data(),
3632 c_api_srcs.data(), attr.get(), aengine.
get()),
3633 "could not create a primitive descriptor for a sum " 3649 const std::vector<memory::desc> &srcs,
const engine &aengine,
3651 validate_container_size(scales,
3652 "counts of scales and sources are not equal",
3653 (
int)srcs.size(), (int)srcs.size());
3659 (
int)c_api_srcs.size(), scales.data(),
3660 c_api_srcs.data(), attr.get(), aengine.
get()),
3661 "could not create a primitive descriptor for a sum " 3724 bool allow_empty =
false)
3725 : allow_empty_(allow_empty) {
3728 desc, attr ? attr->
get() :
nullptr, aengine.
get(), hint_fwd_pd);
3731 status,
"could not create a primitive descriptor iterator");
3732 pd_iterator.reset(iterator);
3745 status,
"could not advance a primitive descriptor iterator");
3751 bool allow_empty_ =
false;
3755 pd_iterator.
get(allow_empty_));
3758 "could not fetch a primitive descriptor from a primitive " 3759 "descriptor iterator");
3825 &strides[0], &padding_l[0], &padding_r[0]),
3826 "could not create a descriptor for a convolution forward " 3827 "propagation primitive");
3869 &weights_desc.
data,
nullptr, &dst_desc.
data,
3870 &strides[0], &padding_l[0], &padding_r[0]),
3871 "could not create a descriptor for a convolution forward " 3872 "propagation primitive");
3919 &weights_desc.
data, &bias_desc.
data,
3920 &dst_desc.
data, &strides[0], &dilates[0],
3921 &padding_l[0], &padding_r[0]),
3922 "could not create a descriptor for a dilated convolution " 3923 "forward propagation primitive");
3968 &weights_desc.
data,
nullptr,
3969 &dst_desc.
data, &strides[0], &dilates[0],
3970 &padding_l[0], &padding_r[0]),
3971 "could not create a descriptor for a dilated convolution " 3972 "forward propagation primitive");
3992 bool allow_empty =
false)
3994 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
4008 const engine &aengine,
bool allow_empty =
false)
4010 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
4090 &weights_desc.
data, &diff_dst_desc.
data,
4091 &strides[0], &padding_l[0], &padding_r[0]),
4092 "could not create a descriptor for a convolution backward " 4093 "propagation primitive");
4135 &weights_desc.
data, &diff_dst_desc.
data,
4136 &strides[0], &dilates[0], &padding_l[0],
4138 "could not create a descriptor for a dilated convolution " 4139 "backward propagation primitive");
4163 bool allow_empty =
false)
4165 hint_fwd_pd.
get(), allow_empty) {}
4184 bool allow_empty =
false)
4186 hint_fwd_pd.
get(), allow_empty) {}
4261 &diff_weights_desc.
data, &diff_bias_desc.
data,
4262 &diff_dst_desc.
data, &strides[0], &padding_l[0],
4264 "could not create a descriptor for a convolution weights " 4265 "update primitive");
4302 &diff_weights_desc.
data,
nullptr,
4303 &diff_dst_desc.
data, &strides[0],
4304 &padding_l[0], &padding_r[0]),
4305 "could not create a descriptor for a convolution weights " 4306 "update primitive");
4351 &diff_weights_desc.
data, &diff_bias_desc.
data,
4352 &diff_dst_desc.
data, &strides[0], &dilates[0],
4353 &padding_l[0], &padding_r[0]),
4354 "could not create a descriptor for a dilated convolution " 4355 "weights gradient primitive");
4397 &diff_weights_desc.
data,
nullptr,
4398 &diff_dst_desc.
data, &strides[0], &dilates[0],
4399 &padding_l[0], &padding_r[0]),
4400 "could not create a descriptor for a dilated convolution " 4401 "weights gradient primitive");
4424 bool allow_empty =
false)
4426 hint_fwd_pd.
get(), allow_empty) {}
4444 bool allow_empty =
false)
4446 hint_fwd_pd.
get(), allow_empty) {}
4545 &strides[0], &padding_l[0], &padding_r[0]),
4546 "could not create a descriptor for a deconvolution forward " 4547 "propagation primitive");
4588 &weights_desc.
data,
nullptr, &dst_desc.
data,
4589 &strides[0], &padding_l[0], &padding_r[0]),
4590 "could not create a descriptor for a deconvolution forward " 4591 "propagation primitive");
4637 &weights_desc.
data, &bias_desc.
data,
4638 &dst_desc.
data, &strides[0], &dilates[0],
4639 &padding_l[0], &padding_r[0]),
4640 "could not create a descriptor for a dilated deconvolution " 4641 "forward propagation primitive");
4685 &weights_desc.
data,
nullptr,
4686 &dst_desc.
data, &strides[0], &dilates[0],
4687 &padding_l[0], &padding_r[0]),
4688 "could not create a descriptor for a dilated deconvolution " 4689 "forward propagation primitive");
4709 bool allow_empty =
false)
4711 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
4725 const engine &aengine,
bool allow_empty =
false)
4727 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
4802 &weights_desc.
data, &diff_dst_desc.
data,
4803 &strides[0], &padding_l[0], &padding_r[0]),
4804 "could not create a descriptor for a deconvolution " 4805 "backward propagation primitive");
4846 &weights_desc.
data, &diff_dst_desc.
data,
4847 &strides[0], &dilates[0], &padding_l[0],
4849 "could not create a descriptor for a dilated deconvolution " 4850 "backward propagation primitive");
4874 bool allow_empty =
false)
4876 hint_fwd_pd.
get(), allow_empty) {}
4895 bool allow_empty =
false)
4897 hint_fwd_pd.
get(), allow_empty) {}
4971 &diff_weights_desc.
data, &diff_bias_desc.
data,
4972 &diff_dst_desc.
data, &strides[0], &padding_l[0],
4974 "could not create a descriptor for a deconvolution weights " 4975 "update primitive");
5011 &src_desc.
data, &diff_weights_desc.
data,
5012 nullptr, &diff_dst_desc.
data, &strides[0],
5013 &padding_l[0], &padding_r[0]),
5014 "could not create a descriptor for a deconvolution weights " 5015 "update primitive");
5059 &diff_weights_desc.
data, &diff_bias_desc.
data,
5060 &diff_dst_desc.
data, &strides[0], &dilates[0],
5061 &padding_l[0], &padding_r[0]),
5062 "could not create a descriptor for a dilated deconvolution " 5063 "weights gradient primitive");
5104 &diff_weights_desc.
data,
nullptr,
5105 &diff_dst_desc.
data, &strides[0], &dilates[0],
5106 &padding_l[0], &padding_r[0]),
5107 "could not create a descriptor for a dilated deconvolution " 5108 "weights gradient primitive");
5132 bool allow_empty =
false)
5134 hint_fwd_pd.
get(), allow_empty) {}
5153 bool allow_empty =
false)
5155 hint_fwd_pd.
get(), allow_empty) {}
5225 float alpha,
float beta,
float k = 1.f) {
5229 local_size, alpha, beta, k),
5230 "could not create a descriptor for a lrn forward " 5231 "propagation primitive");
5250 bool allow_empty =
false)
5252 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5265 const engine &aengine,
bool allow_empty =
false)
5267 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5319 float alpha,
float beta,
float k = 1.f) {
5322 &diff_data_desc.
data, &data_desc.
data, local_size,
5324 "could not create a descriptor for a lrn backward " 5325 "propagation primitive");
5348 bool allow_empty =
false)
5350 hint_fwd_pd.
get(), allow_empty) {}
5368 bool allow_empty =
false)
5370 hint_fwd_pd.
get(), allow_empty) {}
5452 &dst_desc.
data, &strides[0], &kernel[0],
5453 &padding_l[0], &padding_r[0]),
5454 "could not create a descriptor for a pooling forward " 5455 "propagation primitive");
5474 bool allow_empty =
false)
5476 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5489 const engine &aengine,
bool allow_empty =
false)
5491 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5561 &diff_dst_desc.
data, &strides[0], &kernel[0],
5562 &padding_l[0], &padding_r[0]),
5563 "could not create a descriptor for a pooling backward " 5564 "propagation primitive");
5587 bool allow_empty =
false)
5589 hint_fwd_pd.
get(), allow_empty) {}
5607 bool allow_empty =
false)
5609 hint_fwd_pd.
get(), allow_empty) {}
5687 &data_desc.
data, alpha, beta),
5688 "could not create a descriptor for an eltwise forward " 5689 "propagation primitive");
5709 bool allow_empty =
false)
5711 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5725 const engine &aengine,
bool allow_empty =
false)
5727 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5779 &diff_data_desc.
data, &data_desc.
data, alpha, beta),
5780 "could not create a descriptor for an eltwise backward " 5781 "propagation primitive");
5805 bool allow_empty =
false)
5807 hint_fwd_pd.
get(), allow_empty) {}
5826 bool allow_empty =
false)
5828 hint_fwd_pd.
get(), allow_empty) {}
5890 &data_desc.
data, softmax_axis),
5891 "could not create a descriptor for a softmax forward " 5892 "propagation primitive");
5912 bool allow_empty =
false)
5914 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5928 const engine &aengine,
bool allow_empty =
false)
5930 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5979 &data_desc.
data, softmax_axis),
5980 "could not create a descriptor for a softmax backward " 5981 "propagation primitive");
6005 bool allow_empty =
false)
6007 hint_fwd_pd.
get(), allow_empty) {}
6026 bool allow_empty =
false)
6028 hint_fwd_pd.
get(), allow_empty) {}
6087 int logsoftmax_axis) {
6090 &data_desc.
data, logsoftmax_axis),
6091 "could not create a descriptor for a logsoftmax forward " 6092 "propagation primitive");
6112 bool allow_empty =
false)
6114 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6128 const engine &aengine,
bool allow_empty =
false)
6130 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6180 int logsoftmax_axis) {
6182 &diff_data_desc.
data, &data_desc.
data,
6184 "could not create a descriptor for a logsoftmax backward " 6185 "propagation primitive");
6209 bool allow_empty =
false)
6211 hint_fwd_pd.
get(), allow_empty) {}
6230 bool allow_empty =
false)
6232 hint_fwd_pd.
get(), allow_empty) {}
6315 "could not create a descriptor for a batch normalization " 6316 "forward propagation primitive");
6337 bool allow_empty =
false)
6339 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6353 const engine &aengine,
bool allow_empty =
false)
6355 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6400 "could not retrieve a descriptor from a primitive " 6401 "descriptor for batch normalization forward propagation " 6441 &diff_data_desc.
data, &data_desc.
data,
6443 "could not create a descriptor for a batch normalization " 6444 "backward propagation primitive");
6469 bool allow_empty =
false)
6471 hint_fwd_pd.
get(), allow_empty) {}
6490 bool allow_empty =
false)
6492 hint_fwd_pd.
get(), allow_empty) {}
6595 "could not create a descriptor for a layer normalization " 6596 "forward propagation primitive");
6615 "could not create a descriptor for a layer normalization " 6616 "forward propagation primitive");
6637 bool allow_empty =
false)
6639 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6653 const engine &aengine,
bool allow_empty =
false)
6655 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6698 "could not retrieve a descriptor from a primitive " 6699 "descriptor for layer normalization forward propagation " 6741 &diff_data_desc.
data, &data_desc.
data,
6743 "could not create a descriptor for a batch normalization " 6744 "backward propagation primitive");
6764 &diff_data_desc.
data, &data_desc.
data,
6766 "could not create a descriptor for a batch normalization " 6767 "backward propagation primitive");
6792 bool allow_empty =
false)
6794 hint_fwd_pd.
get(), allow_empty) {}
6813 bool allow_empty =
false)
6815 hint_fwd_pd.
get(), allow_empty) {}
6905 &src_desc.
data, &weights_desc.
data,
6907 "could not create a descriptor for an inner product " 6908 "forward propagation primitive");
6930 &weights_desc.
data,
nullptr, &dst_desc.
data),
6931 "could not create a descriptor for an inner product " 6932 "forward propagation primitive");
6952 bool allow_empty =
false)
6954 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6968 const engine &aengine,
bool allow_empty =
false)
6970 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
7025 &diff_src_desc.
data, &weights_desc.
data,
7026 &diff_dst_desc.
data),
7027 "could not create a descriptor for an inner product " 7028 "backward propagation primitive");
7053 bool allow_empty =
false)
7055 hint_fwd_pd.
get(), allow_empty) {}
7074 bool allow_empty =
false)
7076 hint_fwd_pd.
get(), allow_empty) {}
7130 &src_desc.
data, &diff_weights_desc.
data,
7131 &diff_bias_desc.
data, &diff_dst_desc.
data),
7132 "could not create a descriptor for an inner product " 7133 "weights gradient primitive");
7151 &src_desc.
data, &diff_weights_desc.
data,
nullptr,
7152 &diff_dst_desc.
data),
7153 "could not create a descriptor for an inner product " 7154 "weights gradient primitive");
7178 bool allow_empty =
false)
7180 hint_fwd_pd.
get(), allow_empty) {}
7199 bool allow_empty =
false)
7201 hint_fwd_pd.
get(), allow_empty) {}
7251 using primitive_desc::primitive_desc;
7435 "could not retrieve a descriptor from a primitive descriptor " 7436 "for an RNN primitive");
7443 && (
rnn_d->prop_kind == c_prop_kind1
7444 ||
rnn_d->prop_kind == c_prop_kind2)
7445 &&
rnn_d->cell_kind == c_cell_kind;
7449 "mismatch between expected and provided descriptors for an " 7511 float beta = 0.0f) {
7517 &src_iter_desc.
data, &weights_layer_desc.
data,
7518 &weights_iter_desc.
data, &bias_desc.
data,
7519 &dst_layer_desc.
data, &dst_iter_desc.
data,
7521 "could not create a descriptor for a vanilla RNN forward " 7522 "propagation primitive");
7542 bool allow_empty =
false)
7544 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
7558 const engine &aengine,
bool allow_empty =
false)
7560 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
7691 float beta = 0.0f) {
7697 &src_iter_desc.
data, &weights_layer_desc.
data,
7698 &weights_iter_desc.
data, &bias_desc.
data,
7699 &dst_layer_desc.
data, &dst_iter_desc.
data,
7700 &diff_src_layer_desc.
data, &diff_src_iter_desc.
data,
7701 &diff_weights_layer_desc.
data,
7702 &diff_weights_iter_desc.
data, &diff_bias_desc.
data,
7703 &diff_dst_layer_desc.
data, &diff_dst_iter_desc.
data,
7705 "could not create a descriptor for a vanilla RNN backward " 7706 "propagation primitive");
7730 bool allow_empty =
false)
7732 hint_fwd_pd.
get(), allow_empty) {}
7751 bool allow_empty =
false)
7753 hint_fwd_pd.
get(), allow_empty) {}
7915 &src_iter_desc.
data, &src_iter_c_desc.
data,
7916 &weights_layer_desc.
data, &weights_iter_desc.
data,
7917 &weights_peephole_desc.
data,
7918 &weights_projection_desc.
data, &bias_desc.
data,
7919 &dst_layer_desc.
data, &dst_iter_desc.
data,
7921 "could not create a descriptor for an LSTM forward " 7922 "propagation primitive");
7982 &src_iter_desc.
data, &src_iter_c_desc.
data,
7983 &weights_layer_desc.
data, &weights_iter_desc.
data,
7984 &weights_peephole_desc.
data, &bias_desc.
data,
7985 &dst_layer_desc.
data, &dst_iter_desc.
data,
7987 "could not create a descriptor for an LSTM forward " 7988 "propagation primitive");
8042 &src_iter_desc.
data, &src_iter_c_desc.
data,
8043 &weights_layer_desc.
data, &weights_iter_desc.
data,
8044 &bias_desc.
data, &dst_layer_desc.
data,
8045 &dst_iter_desc.
data, &dst_iter_c_desc.
data,
8047 "could not create a descriptor for an LSTM forward " 8048 "propagation primitive");
8067 bool allow_empty =
false)
8069 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
8082 const engine &aengine,
bool allow_empty =
false)
8084 &adesc.data, &attr, aengine, nullptr, allow_empty) {}
8270 &src_iter_desc.
data, &src_iter_c_desc.
data,
8271 &weights_layer_desc.
data, &weights_iter_desc.
data,
8272 &weights_peephole_desc.
data,
8273 &weights_projection_desc.
data, &bias_desc.
data,
8274 &dst_layer_desc.
data, &dst_iter_desc.
data,
8275 &dst_iter_c_desc.
data, &diff_src_layer_desc.
data,
8276 &diff_src_iter_desc.
data,
8277 &diff_src_iter_c_desc.
data,
8278 &diff_weights_layer_desc.
data,
8279 &diff_weights_iter_desc.
data,
8280 &diff_weights_peephole_desc.
data,
8281 &diff_weights_projection_desc.
data,
8282 &diff_bias_desc.
data, &diff_dst_layer_desc.
data,
8283 &diff_dst_iter_desc.
data,
8284 &diff_dst_iter_c_desc.
data,
8286 "could not create a descriptor for an LSTM backward " 8287 "propagation primitive");
8380 &src_iter_desc.
data, &src_iter_c_desc.
data,
8381 &weights_layer_desc.
data, &weights_iter_desc.
data,
8382 &weights_peephole_desc.
data, &bias_desc.
data,
8383 &dst_layer_desc.
data, &dst_iter_desc.
data,
8384 &dst_iter_c_desc.
data, &diff_src_layer_desc.
data,
8385 &diff_src_iter_desc.
data,
8386 &diff_src_iter_c_desc.
data,
8387 &diff_weights_layer_desc.
data,
8388 &diff_weights_iter_desc.
data,
8389 &diff_weights_peephole_desc.
data,
8390 &diff_bias_desc.
data, &diff_dst_layer_desc.
data,
8391 &diff_dst_iter_desc.
data,
8392 &diff_dst_iter_c_desc.
data,
8394 "could not create a descriptor for an LSTM backward " 8395 "propagation primitive");
8477 &src_iter_desc.
data, &src_iter_c_desc.
data,
8478 &weights_layer_desc.
data, &weights_iter_desc.
data,
8479 &bias_desc.
data, &dst_layer_desc.
data,
8480 &dst_iter_desc.
data, &dst_iter_c_desc.
data,
8481 &diff_src_layer_desc.
data, &diff_src_iter_desc.
data,
8482 &diff_src_iter_c_desc.
data,
8483 &diff_weights_layer_desc.
data,
8484 &diff_weights_iter_desc.
data, &diff_bias_desc.
data,
8485 &diff_dst_layer_desc.
data, &diff_dst_iter_desc.
data,
8486 &diff_dst_iter_c_desc.
data,
8488 "could not create a descriptor for an LSTM backward " 8489 "propagation primitive");
8512 bool allow_empty =
false)
8514 hint_fwd_pd.
get(), allow_empty) {}
8532 bool allow_empty =
false)
8534 hint_fwd_pd.
get(), allow_empty) {}
8716 &src_iter_desc.
data, &weights_layer_desc.
data,
8717 &weights_iter_desc.
data, &bias_desc.
data,
8718 &dst_layer_desc.
data, &dst_iter_desc.
data,
8720 "could not create a descriptor for a GRU forward " 8721 "propagation primitive");
8740 bool allow_empty =
false)
8742 &adesc.data, nullptr, aengine, nullptr, allow_empty) {}