Links


Notation

expₑ(x), logₑ(x) — the natural (base e) exponent and logarithm.
iif(condition, a, b) — inline if: evaluates to a when the condition is true, otherwise b.
0..<n — a half-open range from 0 up to but not including n.
Uppercase names (X, A, B) refer to whole tensors; lowercase names (x, y) refer to individual elements.
In the framework columns below: a named identifier means the feature exists, x means it is not available, NA means not applicable, and ? means not yet investigated.

Data Types

Categories Name Details WebNN ONNX DML XNNPACK StableHLO TOSA NumPy TensorFlow TensorFlow Lite PyTorch CoreML BNNS MPS MLX NCNN CNTK OpenVINO OneDNN ANN (for future additions)
Integer Unsigned uint1 x x x x i1 (bool) x x x x x x x x x x ? ov::element::Type_t::u1 x ? ?
Integer Unsigned uint4 x x x x ui4 x x x x x x x x x x ? ov::element::Type_t::u4 x ? ?
Integer Unsigned uint8 uint8 ONNX.TensorProto.DataType.UINT8 DML_TENSOR_DATA_TYPE_UINT8 x ui8 tosa.uint8_t numpy.uint8 / ubyte tensorflow.uint8 tensorflow.uint8 torch.uint8 x BNNSDataTypeUInt8 MPSDataType.uInt8 mx.uint8 x CNTK.DataType.UChar ov::element::Type_t::u8 dnnl_data_type_t::dnnl_u8 ? ?
Integer Unsigned uint16 x ONNX.TensorProto.DataType.UINT16 DML_TENSOR_DATA_TYPE_UINT16 x ui16 tosa.uint16_t numpy.uint16 / ushort tensorflow.uint16 tensorflow.uint16 x x BNNSDataTypeUInt16 MPSDataType.uInt16 mx.uint16 x x ov::element::Type_t::u16 x ? ?
Integer Unsigned uint32 uint32 ONNX.TensorProto.DataType.UINT32 DML_TENSOR_DATA_TYPE_UINT32 x ui32 x numpy.uint32 / uintc tensorflow.uint32 tensorflow.uint32 x x BNNSDataTypeUInt32 MPSDataType.uInt32 mx.uint32 x x ov::element::Type_t::u32 x ANEURALNETWORKS_UINT32 ?
Integer Unsigned uint48 x x x x x x x x x x x x x x x x x x ? ?
Integer Unsigned uint64 uint64 ONNX.TensorProto.DataType.UINT64 DML_TENSOR_DATA_TYPE_UINT64 x ui64 x numpy.uint64 / uint tensorflow.uint64 tensorflow.uint64 x x BNNSDataTypeUInt64 MPSDataType.uInt64 mx.uint64 x x ov::element::Type_t::u64 x ? ?
Integer Signed int4 int=3, sign=1, twos complement x x x x si4 tosa.int4_t x x x x x x x x x x ov::element::Type_t::i4 x ? ?
Integer Signed int8 int=7, sign=1, twos complement int8 ONNX.TensorProto.DataType.INT8 DML_TENSOR_DATA_TYPE_INT8 x si8 tosa.int8_t numpy.int8 / byte tensorflow.int8 tensorflow.int8 torch.int8 x BNNSDataTypeInt8 MPSDataType.int8 mx.int8 int8 CNTK.DataType.Int8 ov::element::Type_t::i8 dnnl_data_type_t::dnnl_s8 ? ?
Integer Signed int16 int=15, sign=1, twos complement x ONNX.TensorProto.DataType.INT16 DML_TENSOR_DATA_TYPE_INT16 x si16 tosa.int16_t numpy.int16 / short tensorflow.int16 tensorflow.int16 torch.int16 / short x BNNSDataTypeInt16 MPSDataType.int16 mx.int16 x CNTK.DataType.Int16 ov::element::Type_t::i16 x ? ?
Integer Signed int32 int=31, sign=1, twos complement int32 ONNX.TensorProto.DataType.INT32 DML_TENSOR_DATA_TYPE_INT32 x si32 tosa.int32_t numpy.int32 / intc tensorflow.int32 tensorflow.int32 torch.int32 / int ArrayFeatureType.ArrayDataType.INT32 BNNSDataTypeInt32 MPSDataType.int32 mx.int32 x x ov::element::Type_t::i32 dnnl_data_type_t::dnnl_s32 ANEURALNETWORKS_INT32 ?
Integer Signed int48 int=47, sign=1, twos complement x x x x x tosa.int48_t x x x x x x x x x x x x ? ?
Integer Signed int64 int=63, sign=1, twos complement int64 ONNX.TensorProto.DataType.INT64 DML_TENSOR_DATA_TYPE_INT64 x si64 x numpy.int64 / int_ tensorflow.int64 tensorflow.int64 torch.int64 / long x BNNSDataTypeInt64 MPSDataType.int64 mx.int64 x x ov::element::Type_t::i64 x ? ?
Float Signed float8f3e4s1Fn frac=3 exp=4 sign=1 (1) (2) (3), has nan, no inf x ONNX.TensorProto.DataType.FLOAT8E4M3FN x x f8E4M3FN x x x x x x x x x x x x dnnl_data_type_t::dnnl_f8_e4m3 ? ?
Float Signed float8f3e4s1FnUz frac=3 exp=4 sign=1 (1) (2) (3), has nan, no inf, no negative zero x ONNX.TensorProto.DataType.FLOAT8E4M3FNUZ x x f8E4M3FNUZ x x x x x x x x x x x x x ? ?
Float Signed float8f2e5s1 frac=2 exp=5 sign=1 (1) (2) (3), has nan, has inf, has negative zero x ONNX.TensorProto.DataType.FLOAT8E5M2 x x f8E5M2 x x x x x x x x x x x x dnnl_data_type_t::dnnl_f8_e5m2 ? ?
Float Signed float8f2e5s1FnUz frac=2 exp=5 sign=1 (1) (2) (3), has nan, no inf, no negative zero x ONNX.TensorProto.DataType.FLOAT8E5M2FNUZ x x f8E5M2FNUZ x x x x x x x x x x x x x ? ?
Float Signed float8f3e4bias11s1FnUz frac=3 exp=4 bias=11 sign=1 (1) TODO: Figure out details x x x x f8E4M3B11FNUZ x x x x x x x x x x x x x ? ?
Float Signed float16f10e5s1 IEEE frac=10 exp=5 sign=1 float16 ONNX.TensorProto.DataType.FLOAT16 DML_TENSOR_DATA_TYPE_FLOAT16 xnn_datatype_fp16 f16 tosa.fp16_t numpy.float16 / half tensorflow.float16 / half tensorflow.float16 / half torch.float16 / half ArrayFeatureType.ArrayDataType.FLOAT16 BNNSDataTypeFloat16 MPSDataType.float16 mx.float16 float16 CNTK.DataType.Float16 ov::element::Type_t::f16 dnnl_data_type_t::dnnl_f16 ANEURALNETWORKS_FLOAT16 ?
Float Signed float16f7e8s1 Brain frac=7 exp=8 sign=1 x ONNX.TensorProto.DataType.BFLOAT16 x x bf16 tosa.bf16_t x tensorflow.bfloat16 tensorflow.bfloat16 torch.bfloat16 x BNNSDataTypeBFloat16 MPSDataType.bFloat16 x bfloat16 x ov::element::Type_t::bf16 dnnl_data_type_t::dnnl_bf16 ? ?
Float Signed float32f23e8s1 IEEE frac=23 exp=8 sign=1 float32 ONNX.TensorProto.DataType.FLOAT DML_TENSOR_DATA_TYPE_FLOAT32 xnn_datatype_fp32 f32 tosa.fp32_t numpy.float32 / single tensorflow.float32 / float tensorflow.float32 / float torch.float32 / float ArrayFeatureType.ArrayDataType.FLOAT32 BNNSDataTypeFloat32 MPSDataType.float32 mx.float32 float32 CNTK.DataType.Float ov::element::Type_t::f32 dnnl_data_type_t::dnnl_f32 ANEURALNETWORKS_FLOAT32 ?
Float Signed float64f52e11s1 IEEE frac=52 exp=11 sign=1 float64 ONNX.TensorProto.DataType.DOUBLE DML_TENSOR_DATA_TYPE_FLOAT64 x f64 tosa.fp64_t numpy.float64 / double / float_ tensorflow.float64 / double tensorflow.float64 / double torch.float64 / double ArrayFeatureType.ArrayDataType.DOUBLE x x x x CNTK.DataType.Double ov::element::Type_t::f64 dnnl_data_type_t::dnnl_f64 ? ?
Float Signed float16 x 2 NA x x x x x x x x x x x MPSDataType.complexFloat16 x x x x x ? ?
Float Signed float32 x 2 NA ONNX.TensorProto.DataType.COMPLEX64 x x x x numpy.complex64 / singlecomplex tensorflow.complex64 tensorflow.complex64 torch.complex64 / cfloat x x MPSDataType.complexFloat32 x x x x x ? ?
Float Signed float64 x 2 NA ONNX.TensorProto.DataType.COMPLEX128 x x x x numpy.complex128 / doublecomplex tensorflow.complex128 tensorflow.complex128 torch.complex128 / cdouble x x x x x x x x ? ?
Boolean bool8 typically just lowest bit is used, but anything != 0 is true uint8 ONNX.TensorProto.DataType.BOOL DML_TENSOR_DATA_TYPE_UINT8 x x tosa.bool_t numpy.bool tensorflow.bool tensorflow.bool torch.bool x BNNSDataTypeBoolean (bit size?) MPSDataType.bool (bit size?) bool_ x x ov::element::Type_t::boolean dnnl_data_type_t::dnnl_boolean ANEURALNETWORKS_BOOL8 ?
String string8 array of char8's (typically UTF-8) NA ONNX.TensorProto.DataType.STRING x x x x x tensorflow.string tensorflow.string x x x x x x x x x ? ?
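
A quick way to see the frac/exp/sign layouts above is to view a float's raw bits with NumPy. This is just an illustrative snippet, not tied to any particular framework column:

import numpy as np

def bit_pattern(value, float_type, uint_type, width):
    # Reinterpret the float's bytes as an unsigned integer and format them as binary.
    return format(int(np.array(value, dtype=float_type).view(uint_type)), "0%db" % width)

# float16 (frac=10 exp=5 sign=1): -1.5 -> 1 01111 1000000000
print(bit_pattern(-1.5, np.float16, np.uint16, 16))
# float32 (frac=23 exp=8 sign=1): -1.5 -> 1 01111111 10000000000000000000000
print(bit_pattern(-1.5, np.float32, np.uint32, 32))
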

Operators

Categories Name Details/Formula WebNN ONNX DML XNNPACK StableHLO TOSA NumPy TensorFlow TensorFlowLite PyTorch CoreML BNNS MPS MLX NCNN CNTK OpenVINO OneDNN ANN (for future additions) Precision
Elementwise Elementwise Generic
function elementwiseNullary(functor, dataType, dimensions) output = new Tensor(dataType, dimensions) for each coordinate in output.coordinates output[coordinate] = functor() endfor return output endfunction
function elementwiseUnary(input, functor) output = new Tensor(input.dataType, input.dimensions) for each coordinate in input.coordinates output[coordinate] = functor(input[coordinate]) endfor return output endfunction
function elementwiseBinary(a, b, functor) outputDimensions = broadcastDimensions(a.dimensions, b.dimensions) output = new Tensor(a.dataType, outputDimensions) aBroadcast = broadcast(a, outputDimensions) // Some implementations can directly use strides to avoid intermediates. bBroadcast = broadcast(b, outputDimensions) for each coordinate in output.coordinates output[coordinate] = functor(aBroadcast[coordinate], bBroadcast[coordinate]) endfor return output endfunction
function elementwiseTrinary(a, b, c, functor) outputDimensions = broadcastDimensions(a.dimensions, b.dimensions, c.dimensions) output = new Tensor(a.dataType, outputDimensions) aBroadcast = broadcast(a, outputDimensions) // Some implementations can directly use strides to avoid intermediates. bBroadcast = broadcast(b, outputDimensions) cBroadcast = broadcast(c, outputDimensions) for each coordinate in output.coordinates output[coordinate] = functor(aBroadcast[coordinate], bBroadcast[coordinate], cBroadcast[coordinate]) endfor return output endfunction
NA NA NA NA NA NA NA NA NA NA NA ? ? ? UnaryOp ? ? ? ? Exact
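
As a rough NumPy sketch of the elementwiseBinary pseudocode above (NumPy's own broadcasting plays the role of the broadcast helpers; a production implementation would walk strides rather than loop in Python):

import numpy as np

def elementwise_binary(a, b, functor):
    output_shape = np.broadcast_shapes(a.shape, b.shape)
    a_broadcast = np.broadcast_to(a, output_shape)   # views, no copies
    b_broadcast = np.broadcast_to(b, output_shape)
    output = np.empty(output_shape, dtype=a.dtype)
    for coordinate in np.ndindex(output_shape):
        output[coordinate] = functor(a_broadcast[coordinate], b_broadcast[coordinate])
    return output

a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([10.0, 20.0])                            # broadcast against the last axis
print(elementwise_binary(a, b, lambda x, y: x + y))   # [[11. 22.] [13. 24.]]
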
Elementwise Identity function identity(input) = elementwiseUnary(input, (x) => x) identity Identity DML_OPERATOR_ELEMENT_WISE_IDENTITY or DML_ACTIVATION_IDENTITY ? stablehlo.optimization_barrier tosa.identity numpy.identity tf.identity tf.identity torch.nn.Identity CopyLayerParams BNNSActivationFunctionIdentity ? mlx.identity ? activation_identity ? dnnl::reorder ? Exact
Input Constant function constant() = value constant Constant NA just provide the tensor data ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Input Constant Of Shape function constantOfShape(scalarValue, newShape) = broadcast(scalarValue, newShape.dimensions) constant+expand ConstantOfShape DML_OPERATOR_ELEMENT_WISE_IDENTITY with zero strides to broadcast ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Add function add(a, b) = elementwiseBinary(a, b, (x,y) => x + y) add Add DML_OPERATOR_ELEMENT_WISE_ADD ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 1 ULP
Elementwise Math Subtract function subtract(a, b) = elementwiseBinary(a, b, (x,y) => x - y) sub Sub DML_OPERATOR_ELEMENT_WISE_SUBTRACT ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 1 ULP
Elementwise Math Multiply function multiply(a, b) = elementwiseBinary(a, b, (x,y) => x * y) mul Mul DML_OPERATOR_ELEMENT_WISE_MULTIPLY ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 1 ULP
Elementwise Math Divide function divide(a, b) = elementwiseBinary(a, b, (x,y) => x / y) div Div DML_OPERATOR_ELEMENT_WISE_DIVIDE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 1 ULP
Elementwise Math Reciprocal function reciprocal(input) = elementwiseUnary(input, (x) => 1 / x) reciprocal Reciprocal DML_OPERATOR_ELEMENT_WISE_RECIP ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 1 ULP
Elementwise Math Modulus Truncate function modulusTruncate(a, b) = elementwiseBinary(a, b, (x,y) => x - (y * trunc(x / y))) // Result sign follows dividend sign. ? Mod fmod=1 DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 2 ULP
Elementwise Math Modulus Floor function modulusFloor(a, b) = elementwiseBinary(a, b, (x,y) => x - (y * floor(x / y))) // Result sign follows divisor sign. ? Mod fmod=0 DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 2 ULP
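
The sign conventions of the two modulus flavors can be checked directly in Python, where % is the floor modulus and math.fmod is the truncating modulus:

import math

print(math.fmod(-7, 3), math.fmod(7, -3))   # -1.0 1.0  (sign follows the dividend)
print(-7 % 3, 7 % -3)                       #  2 -2     (sign follows the divisor)
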
Elementwise Math Power function pow(x, exponent) = elementwiseBinary(x, exponent, (x, exponent) => powScalar(x, exponent)) pow Pow DML_OPERATOR_ELEMENT_WISE_POW ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Root function root(x, exponent) = elementwiseBinary(x, exponent, (x, exponent) => root(x, exponent)) // OR pow(x, reciprocal(exponent)) reciprocal and pow ONNX NA DML_OPERATOR_ELEMENT_WISE_RECIP & POW ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Square root function sqrt(input) = elementwiseUnary(input, (x) => sqrt(x)) // OR pow(x, 1/2) sqrt Sqrt DML_OPERATOR_ELEMENT_WISE_SQRT ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Exponent function exp(input) = elementwiseUnary(input, (x) => exp(x)) // OR pow(2.71828, x) exp Exp DML_OPERATOR_ELEMENT_WISE_EXP ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Logarithm Natural function logarithm(x, base) = elementwiseUnary(x, (x) => log(x, base)) function logarithmNatural(x) = logarithm(x, 2.71828) function logarithmBase10(x) = logarithm(x, 10) // OR elementwiseUnary(x, (x) => logNatural(x) / logNatural(base)) // Notes: recall log(exp(x)) == x, and log(x, b) == ln(x) / ln(b) log Log DML_OPERATOR_ELEMENT_WISE_LOG ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Absolute function absolute(input) = elementwiseUnary(input, (x) => abs(x)) abs Abs DML_OPERATOR_ELEMENT_WISE_ABS ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Negate function negate(input) = elementwiseUnary(input, (x) => -x) neg Neg DML_OPERATOR_ELEMENT_WISE_NEGATE
DML_OPERATOR_ELEMENT_WISE_IDENTITY with scale = -1
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Ceiling function ceiling(input) = elementwiseUnary(input, (x) => ceil(x)) ceil Ceil DML_OPERATOR_ELEMENT_WISE_CEIL ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Floor function floor(input) = elementwiseUnary(input, (x) => floor(x)) floor Floor DML_OPERATOR_ELEMENT_WISE_FLOOR ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Clamp function clamp(input, minValue, maxValue) = elementwiseUnary(input, (x) => min(max(x, minValue), maxValue)) clamp Clip DML_OPERATOR_ELEMENT_WISE_CLIP ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Gauss Error Function
function erf(input) = elementwiseUnary(input, (x) => 1/sqrt(pi) * integrate(i = -x to x, e ^ -(i^2))) // OR 2/sqrt(pi) * integrate(i = 0 to x, e ^ -(i^2)) double f(double x) { // Polynomial approximation constants. double a1 = 0.254829592; double a2 = -0.284496736; double a3 = 1.421413741; double a4 = -1.453152027; double a5 = 1.061405429; double p = 0.3275911; // Save the sign of x. int sign = 1; if (x < 0) sign = -1; x = fabs(x); // Approximate the formula A&S 7.1.26: // 2/sqrt(pi) * integrate(i = 0 to x, e ^ -(i^2)) double t = 1.0/(1.0 + p*x); double y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1) * t*exp(-x*x); return sign * y; }
erf Erf DML_OPERATOR_ELEMENT_WISE_ERF ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
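
A minimal Python version of the A&S 7.1.26 polynomial approximation above, checked against math.erf (maximum absolute error is about 1.5e-7):

import math

def erf_approx(x):
    a1, a2, a3, a4, a5 = 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429
    p = 0.3275911
    sign = -1.0 if x < 0 else 1.0
    x = abs(x)
    t = 1.0 / (1.0 + p * x)
    y = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) * math.exp(-x * x)
    return sign * y

print(erf_approx(0.5), math.erf(0.5))   # both approximately 0.5205
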
Elementwise Math Is Not a Number function isNan(input) = elementwiseUnary(input, (x) => isNan(x)) isNan(float32 x) = (x.reinterpretAs(uint32) & 0x7FFFFFFF) > 0x7F800000 Any float32 value with all 1's for exponent and a nonzero mantissa is NaN. The sign is ignored. e.g. s 11111111 00000000000000000000001 : float32 NaN e.g. s 11111 0000000001 : float16 NaN ? IsNan DML_OPERATOR_ELEMENT_WISE_IS_NAN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Is Infinity function isInfinity(input) = elementwiseUnary(input, (x) => isinf(x) && iif(x > 0, detectPositive, detectNegative)) isinf(x) = (x.reinterpretAs(uint32) & 0x7FFFFFFF) == 0x7F800000 Check for positive or negative infinity. For infinity, test that all exponent bits are one, and all mantissa bits are 0: ? IsInf DML_OPERATOR_ELEMENT_WISE_IS_INFINITY ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
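
Both bit-level tests above translate directly to NumPy by viewing float32 data as uint32 (an illustrative sketch; real implementations typically do this in registers):

import numpy as np

def is_nan_bits(x):
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & 0x7FFFFFFF) > 0x7F800000    # exponent all ones, mantissa nonzero

def is_inf_bits(x):
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & 0x7FFFFFFF) == 0x7F800000   # exponent all ones, mantissa zero

x = np.array([1.0, np.nan, np.inf, -np.inf], dtype=np.float32)
print(is_nan_bits(x))   # [False  True False False]
print(is_inf_bits(x))   # [False False  True  True]
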
Elementwise Math Sign function sign(input) = elementwiseUnary(input, (x) => if x > 0 then 1 elif x < 0 then -1 elif x == 0 then 0 else NaN) ? Sign DML_OPERATOR_ELEMENT_WISE_SIGN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Comparison Equal function equal(a, b) = elementwiseBinary(a, b, (x, y) => (x == y)) equal Equal DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Comparison Unequal function unequal(a, b) = elementwiseBinary(a, b, (x, y) => (x != y)) not(equal()) Not and Equal DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT and EQUALS ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Comparison Greater function greater(a, b) = elementwiseBinary(a, b, (x, y) => (x > y)) greater Greater DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Comparison Lesser function lesser(a, b) = elementwiseBinary(a, b, (x, y) => (x < y)) lesser Less DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Comparison Greater or Equal function greaterOrEqual(a, b) = elementwiseBinary(a, b, (x, y) => (x >= y)) greaterOrEqual GreaterOrEqual DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Comparison Lesser or Equal function lesserOrEqual(a, b) = elementwiseBinary(a, b, (x, y) => (x <= y)) lesserOrEqual LessOrEqual DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Bitwise Bitwise Not function bitwiseNot(input) = elementwiseUnary(input, (x) => ~x) NA BitwiseNot DML_OPERATOR_ELEMENT_WISE_BIT_NOT ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Bitwise Bitwise And function bitwiseAnd(a, b) = elementwiseBinary(a, b, (x, y) => x & y) NA BitwiseAnd DML_OPERATOR_ELEMENT_WISE_BIT_AND ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Bitwise Bitwise Or function bitwiseOr(a, b) = elementwiseBinary(a, b, (x, y) => x | y) NA BitwiseOr DML_OPERATOR_ELEMENT_WISE_BIT_OR ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Bitwise Bitwise Xor function bitwiseXor(a, b) = elementwiseBinary(a, b, (x, y) => x ^ y) NA BitwiseXor DML_OPERATOR_ELEMENT_WISE_BIT_XOR ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Bitwise Bitwise Left Shift function bitwiseLeftShift(a, b) = elementwiseBinary(a, b, (x, y) => x << y) NA BitShift direction = LEFT DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Bitwise Bitwise Right Shift function bitwiseRightShift(a, b) = elementwiseBinary(a, b, (x, y) => x >> y) NA BitShift direction = RIGHT DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Bitwise Bitwise Count function bitwiseCount(input) = elementwiseUnary(input, (x) => popCount(x)) popCount(x) = (x & 1) + iif(x != 0, popCount(x >> 1), 0) // Add one to the count for each set bit in x (logical right shift). NA NA DML_OPERATOR_ELEMENT_WISE_BIT_COUNT ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Logical Logical Not function logicalNot(input) = elementwiseUnary(input, (x) => !x) ? Not DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT ? ? ? ? ? ? ? LogicalNotLayerParams ? ? ? ? ? ? ? ? Exact
Elementwise, Logical Logical And function logicalAnd(a, b) = elementwiseBinary(a, b, (x, y) => x && y) NA And DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Logical Logical Or function logicalOr(a, b) = elementwiseBinary(a, b, (x, y) => x || y) NA Or DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise, Logical Logical Xor function logicalXor(a, b) = elementwiseBinary(a, b, (x, y) => !!x xor !!y) NA Xor DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Trigonometric Sine function sine(input) = elementwiseUnary(input, (x) => sin(x)) sin Sin DML_OPERATOR_ELEMENT_WISE_SIN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Cosine function cosine(input) = elementwiseUnary(input, (x) => cos(x)) cos Cos DML_OPERATOR_ELEMENT_WISE_COS ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Tangent function tangent(input) = elementwiseUnary(input, (x) => tan(x)) tan Tan DML_OPERATOR_ELEMENT_WISE_TAN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Arcsine function arcsine(input) = elementwiseUnary(input, (x) => asin(x)) ? Asin DML_OPERATOR_ELEMENT_WISE_ASIN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Arccosine function arccosine(input) = elementwiseUnary(input, (x) => acos(x)) ? Acos DML_OPERATOR_ELEMENT_WISE_ACOS ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Arctangent function arctangent(input) = elementwiseUnary(input, (x) => atan(x)) ? Atan DML_OPERATOR_ELEMENT_WISE_ATAN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Hyperbolic Sine function hyperbolicSine(input) = elementwiseUnary(input, (x) => sinh(x)) ? Sinh DML_OPERATOR_ELEMENT_WISE_SINH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Hyperbolic Cosine function hyperbolicCosine(input) = elementwiseUnary(input, (x) => cosh(x)) ? Cosh DML_OPERATOR_ELEMENT_WISE_COSH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Hyperbolic Tangent function hyperbolicTangent(input) = elementwiseUnary(input, (x) => tanh(x)) function hyperbolicTangent(X) = div(sub(1, exp(mul(X, -2))), add(1, exp(mul(X, -2)))) // OR elementwise (1 - expₑ(-2 * x)) / (1 + expₑ(-2 * x)) // OR elementwise 2 / (1 + expₑ(-2 * x)) - 1 // OR elementwise (expₑ(x) - expₑ(-x)) / (expₑ(x) + expₑ(-x)) tanh Tanh DML_OPERATOR_ELEMENT_WISE_TANH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Scaled Hyperbolic Tangent function scaledHyperbolicTangent(input, alpha, beta) = mul(tanh(mul(input, beta)), alpha) ? ScaledTanh DML_OPERATOR_ACTIVATION_SCALED_TANH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Hyperbolic Arccosine function hyperbolicArccosine(input) = elementwiseUnary(input, (x) => arccosh(x)) // OR elementwise logₑ(x + sqrt(x * x - 1)) ? Acosh DML_OPERATOR_ELEMENT_WISE_ACOSH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Hyperbolic Arcsine function hyperbolicArcsine(input) = elementwiseUnary(input, (x) => arcsinh(x)) // OR elementwise logₑ(x + sqrt(x * x + 1)) ? Asinh DML_OPERATOR_ELEMENT_WISE_ASINH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric Hyperbolic Arctangent function hyperbolicArctangent(input) = elementwiseUnary(input, (x) => arctanh(x)) // OR elementwise logₑ((1 + x) / (1 - x)) / 2 ? Atanh DML_OPERATOR_ELEMENT_WISE_ATANH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Trigonometric CosineGrad function cosineGrad(dx, x) = elementwiseBinary(dx, x, (dx, x) => mul(negate(sin(x)), dx)) ? Composed DML_OPERATOR_ELEMENT_WISE_SIN &
DML_OPERATOR_ELEMENT_WISE_MULTIPLY
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Reduction Sum function sum(a, b, ...) = elementwiseNary((a, b, ...), f(x, y, ...) = x + y + …) add repeated Sum DML_OPERATOR_ELEMENT_WISE_ADD via repeated inputs ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? < N-1 ULP
Elementwise Math Reduction Mean function mean(a, b, ...) = elementwiseNary((a, b, ...), f(x, y, ...) = (x + y + …) / n) add repeated and div Mean DML_OPERATOR_ELEMENT_WISE_MEAN via repeated inputs ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Reduction Maximum function maximum(a, b, ...) = elementwiseNary((a, b, ...), f(x, y, ...) = max(x, y, …)) max repeated Max DML_OPERATOR_ELEMENT_WISE_MAX via repeated inputs ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Threshold function threshold(input, minValue) = elementwiseUnary(input, (x) => max(x, minValue)) notes: Not equivalent to ThresholdedRelu. ? Max DML_OPERATOR_ELEMENT_WISE_THRESHOLD ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Reduction Minimum function minimum(a, b, ...) = elementwiseNary((a, b, ...), f(x, y, ...) = min(x, y, …)) ? Min DML_OPERATOR_ELEMENT_WISE_MIN via repeated inputs ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Elementwise Math Quantization Quantize Linear function quantizeLinear(input:float32, scale:float32, zeroPoint:int32) return clamp(add(round(input / scale), zeroPoint), 0, 255).cast(uint8) endfunction ? QuantizeLinear
com.microsoft QuantizeLinear
DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Elementwise Math Quantization Dequantize Linear function dequantizeLinear(input:uint8, scale:float32, zeroPoint:uint8) return mul(sub(input, zeroPoint).cast(float32), scale) endfunction ? DequantizeLinear
com.microsoft DequantizeLinear
DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR ? ? ? ? tf.quantization.dequantize ? ? constexpr_affine_dequantize ? ? ? ? ? ? ? ? Precision TBD
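
A small NumPy round trip of the two pseudocode functions above for uint8 quantization (the scale and zeroPoint values here are arbitrary examples):

import numpy as np

def quantize_linear(x, scale, zero_point):
    # q = clamp(round(x / scale) + zeroPoint, 0, 255) cast to uint8.
    return np.clip(np.rint(x / scale) + zero_point, 0, 255).astype(np.uint8)

def dequantize_linear(q, scale, zero_point):
    # x ≈ (q - zeroPoint) * scale.
    return (q.astype(np.float32) - np.float32(zero_point)) * np.float32(scale)

x = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
q = quantize_linear(x, scale=np.float32(1.0 / 127.0), zero_point=128)
print(q)                                        # [  1 128 192 255]
print(dequantize_linear(q, 1.0 / 127.0, 128))   # approximately [-1. 0. 0.504 1.]
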
Activation, Elementwise Sigmoid function sigmoid(input) = reciprocal(add(1, exp(negate(input)))) // OR elementwiseUnary(input, x => 1 / (1 + expₑ(-x))) // OR elementwiseUnary(input, x => expₑ(x) / (1 + expₑ(x))) sigmoid Sigmoid DML_OPERATOR_ACTIVATION_SIGMOID ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Hard Sigmoid function hardSigmoid(input, scale = 0.2, offset = 0.5) = clamp(affine(input, scale, offset), 0, 1) // OR elementwise max(0, min(x * scale + offset, 1)) hardSigmoid HardSigmoid DML_OPERATOR_ACTIVATION_HARD_SIGMOID ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Hard Swish function hardSwish(input, scale = 1/6, offset = 0.5) = mul(input, hardSigmoid(input, scale, offset)) // OR using limit=6 instead, elementwise x * max(0, min(limit, (x + (limit/2)))) / limit function hardSwishAlways6(input) = hardSwish(input, 1/6, 0.5) mul and hardSigmoid HardSwish DML_OPERATOR_ACTIVATION_HARD_SWISH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Clamp Positive (Rectified Linear Unit) function clampPositive(input) = max(input, 0) // OR elementwise if x >= 0 then x else 0 relu Relu DML_OPERATOR_ACTIVATION_RELU ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Activation, Elementwise Leaky Rectified Linear Unit function leakyRectifiedLinearUnit(input, alpha) = select(lesser(input, 0), mul(input, alpha), input) // OR elementwise if x >= 0 then x else alpha * x leakyRelu LeakyRelu DML_OPERATOR_ACTIVATION_LEAKY_RELU ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? <= Mul precision
Activation, Elementwise Parameterized Rectified Linear Unit function parameterizedRectifiedLinearUnit(input, slope) = select(greaterOrEqual(input, 0), input, mul(input, slope)) // OR elementwise if x >= 0 then x else slope * x PRelu and LeakyRelu are identical, except one slope is an input tensor and one slope is a constant. prelu PRelu DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Thresholded Rectified Linear Unit function thresholdedRectifiedLinearUnit(input, alpha = 1) = select(greater(input, alpha), input, 0) // OR elementwise if x > alpha then x else 0 ? ThresholdedRelu DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Activation, Elementwise Exponential Linear Unit function exponentialLinearUnit(input, alpha = 1) return select(greaterOrEqual(input, 0), input, mul(sub(exp(input), 1), alpha)) // OR add(clamp(input, 0, inf), clamp(mul(sub(exp(input), 1), alpha), -inf, 0)) // OR elementwiseUnary(input, (x) => if x >= 0 then x else alpha * (expₑ(x) - 1)) endfunction elu Elu DML_OPERATOR_ACTIVATION_ELU ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Scaled Exponential Linear Unit function scaledExponentialLinearUnit(input, alpha = 1.6732, gamma = 1.0507) return mul(elu(input, alpha), gamma) // OR elementwise gamma * iif(x > 0, x, alpha * (expₑ(x) - 1)) endfunction ? Selu DML_OPERATOR_ACTIVATION_SCALED_ELU ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Gaussian Error Linear Unit (ref) function gaussianErrorLinearUnit(x) = mul(mul(x, 0.5), add(1.0, erf(div(x, sqrt(2))))) // OR elementwise x * 0.5 * (1.0 + erf(x / sqrt(2))) ? Gelu DML_OPERATOR_ACTIVATION_GELU ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Normalization Soft Maximum Raise e to the power of each element, and divide all the elements in each batch by that batch's sum. function softMax(input, axes) // Applies: // - DML_OPERATOR_ACTIVATION_SOFTMAX1 expInput = exp(input) reducedExpInput = reduceSum(expInput, axis=axes, keepDimensions=1) return div(expInput, reducedExpInput) // Or for more numerical stability: // maxInput = reduceMax(input, axes, keepDimensions=1) // expInput = exp(input - maxInput) // reducedExpInput = reduceSum(expInput, axis=axes, keepDimensions=1) // return div(expInput, reducedExpInput) endfunction function softMax1D(input, axis) // Only handle a single axis, explicitly given. // Applies: // - ONNX Softmax-13 // - torch.nn.Softmax // - tf.nn.softmax // - NCNN Softmax // - MPS softMax // - mil.ops.defs.iOS15.activation.softmax return softMax(input, axes=[axis]) endfunction function softMax1DRightmostAxis(input) // Only handle a single axis, implicitly rightmost dimension. // Applies: // - DML_OPERATOR_ACTIVATION_SOFTMAX(0) return softMax(input, axes=[input.rank - 1]) endfunction function softmaxRightmostRange(input, firstAxis) // Only handles rightmost dimensions. // ONNX Softmax-11 flattenedInput = flattenTo2D(input, firstAxis) // Flatten to 2D normalizedInput = softMax(flattenedInput, axes=[1]) return reshape(normalizedInput, input.dimensions) endfunction or per batch: f(x) = expₑ(x) / sum(expₑ(X)) or per batch for more numerical stability: expₑ(x - max(X)) / sum(expₑ(x - max(X))) softmax Softmax DML_OPERATOR_ACTIVATION_SOFTMAX ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
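
The numerically stable variant above is a one-liner in NumPy (illustrative only; axes follows the reduceSum/reduceMax convention used throughout this table):

import numpy as np

def softmax(x, axes):
    x_max = np.max(x, axis=axes, keepdims=True)   # subtract the max for stability
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axes, keepdims=True)

x = np.array([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
print(softmax(x, axes=-1))
# [[0.09003057 0.24472847 0.66524096]
#  [0.33333333 0.33333333 0.33333333]]
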
Activation Log Soft Maximum function logSoftMaximum(input, axes) return log(softMax(input, axes)) // OR logₑ(expₑ(x - max(input)) / sum(expₑ(input - max(input)))) // OR (x - max(input)) - logₑ(sum(expₑ(input - max(input)))) endfunction function logSoftMaximum1D(input, axis) // Only handle a single axis, explicitly given. return logSoftMaximum(input, [axis]) endfunction ? LogSoftmax DML_OPERATOR_ACTIVATION_LOG_SOFTMAX ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation Hard Maximum function hardMaximum(x) = if x[i] == max(X) then 1 else 0 *but only for first element along that axis ? Hardmax DML_OPERATOR_ACTIVATION_HARDMAX ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Soft Sign function softSign(input) = div(input, add(abs(input), 1)) softsign Softsign DML_OPERATOR_ACTIVATION_SOFTSIGN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Softplus function softPlus(input) = log(add(exp(input), 1)) softplus Softplus DML_OPERATOR_ACTIVATION_SOFTPLUS ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Parametric Softplus function parametricSoftPlus(input, alpha, beta) = mul(softplus(mul(input, beta)), alpha) // OR elementwise logₑ(expₑ(x * beta) + 1) * alpha ? ParametricSoftplus DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Affine function affine(input, alpha, beta) = add(mul(input, alpha), beta) linear Affine DML_OPERATOR_ACTIVATION_LINEAR
DML_OPERATOR_ELEMENT_WISE_IDENTITY
with scale and bias
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Activation, Elementwise Symmetric signal shift function symmetricSignalShift(input, threshold /*lambda*/, bias) return mul( greater(abs(input), threshold), sub(input, mul(sign(input), bias)) ) // OR return select( greater(abs(input), threshold), sub(input, mul(sign(input), bias)), 0 ) // OR elementwise if x < -threshold then y = x + bias // elif x > threshold then y = x - bias // else y = 0 endfunction ? Shrink DML_OPERATOR_ACTIVATION_SHRINK ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Generation, Random Random Normal function randomNormal(scale, mean) = MarsagliaPolarTransform(random(), random()) * scale + mean ? RandomNormal --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Unpredictable
Generation, Random Random Normal Like function randomNormalLike(scale, mean) = MarsagliaPolarTransform(random(), random()) * scale + mean ? RandomNormalLike --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Unpredictable
Generation, Random Random Uniform function randomUniform(low, high, dataType, dimensions) // see MT19937 function f() range = high - low // note inclusive end if dataType is integer then return rand() % (range + 1) + low elseif dataType is float then return (rand() / randMax) * range + low endif endfunction return elementwiseNullary(f, dataType, dimensions) endfunction ? RandomUniform
RandomUniformLike
--- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Unpredictable
Generation, Random Random Multinomial TODO: function randomMultinomial() = ... ? Multinomial --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Unpredictable
Generation, Matrix Multiplication Diagonal Matrix Notes: set 1's all along diagonal. In other words, all output[i, i+k] = 1, and every other element = 0. function eyeLike(input, outputDimensions, outputDataType, diagonalShift=0) assert(outputDimensions == input.dimensions) assert(outputDataType == input.dataType) output = new Tensor(outputDataType, outputDimensions) if not input exists input = zeros(outputDataType, outputDimensions) endif for each coordinate in output tensor coordinates output[coordinate] = if coordinate.h + diagonalShift == coordinate.w then 1 else 0 endfor endfunction ? EyeLike DML_OPERATOR_DIAGONAL_MATRIX ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Generation, Matrix Multiplication Diagonal Matrix TODO: triangular Trilu DML_OPERATOR_DIAGONAL_MATRIX1 ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Matrix Multiplication Generic Matrix Multiplication function matrixMultiplication(A, B, C, alpha, beta, transA, transB; Y) A2 = if(transA, transpose(A), A) B2 = if(transB, transpose(B), B) Y = add(mul(alpha, matMul(A2, B2)), mul(beta, C)) endfunction gemm Gemm DML_OPERATOR_MATRIX_GEMM ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Matrix Multiplication Matrix Multiplication for i=0..<m do for j=0..<n do y[i,j] = dot(A[i, 0..<k], B[0..<k, j]) // A is [m,k], B is [k,n], Y is [m,n]. TODO: Demonstrate via reduceSum of higher dimensional space. It's essentially a ReduceSum(Mul(A.row, B.column)) per output element. notes: A and B can be 1D vectors, which are treated as [1,W] and [H,1] matrices. A and B can have batch count dimensions, where each 2D matrix is multiplied separately. The batch count can be broadcast too. matmul MatMul DML_OPERATOR_MATRIX_GEMM ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
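
The ReduceSum(Mul(A.row, B.column)) view of matrix multiplication can be demonstrated with NumPy broadcasting (a sketch for the plain 2D case, no batching):

import numpy as np

def matmul_via_reduce(a, b):
    # a is [m,k], b is [k,n]. Expand to [m,1,k] and [1,n,k], multiply, then reduce over k.
    products = a[:, np.newaxis, :] * b.T[np.newaxis, :, :]   # [m, n, k]
    return products.sum(axis=-1)                             # [m, n]

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.arange(12, dtype=np.float32).reshape(3, 4)
print(np.allclose(matmul_via_reduce(a, b), a @ b))   # True
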
Matrix Multiplication Convolve Every output element is the convolution of the filter with the corresponding input elements. out[j] = (x[i]*w[0]) + (x[i+1]*w[1]) + (x[i+2]*w[2]) + ... + (x[i+k-1]*w[k-1]) + b notes: 'steps' affects the size of steps over the input. 'dilations' affects the step of the filter, as if the filter had been resized with interceding zeroes between elements. A dilation of 1 means no change (1:1 with filter), whereas dilation of 2 inserts lines of zeros between every filter line. 'pads' are not actually added to the input, just virtually treated as if zeros. vdumoulin convolution diagrams function convolve(input, filterWeights, windowDimensions, padding, dilations, strides) startPads = pads[0..pads.size/2] endPads = pads[pads.size/2..pads.size] // TODO: compute output size // output.dimensions = (input.dimensions + startPads + endPads) // todo: consider strides and kernel size for each outputCoordinate in output coordinates output[outputCoordinate] = convolveKernel(input, filterWeights, outputCoordinate * strides - startPads, dilations) endfor endfunction function convolveKernel(input, filterWeights, firstInputCoordinate, dilations) // 2D example only // TODO:Figure out what 'group' does and what 'M' is? result = 0 // todo: How do 'M' and 'C' factor into this? for y=0..<filterWeights.dimensions[2] for x=0..<filterWeights.dimensions[3] inputCoordinate = firstInputCoordinate + ([y,x] * dilations) if (input.contains(inputCoordinate)) // check coordinates within tensor result += filterWeights[y,x] * input[inputCoordinate] endif endfor // x endfor // y return result endfunction conv Conv DML_OPERATOR_CONVOLUTION with DML_CONVOLUTION_MODE_CROSS_CORRELATION and DML_CONVOLUTION_DIRECTION_FORWARD ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Matrix Multiplication Convolve Transposed TODO: Here be dragons. questions: What is the difference between CONVOLUTION vs CORRELATION enum, and FORWARD vs BACKWARD? ? ConvTranspose DML_OPERATOR_CONVOLUTION with DML_CONVOLUTION_MODE_CROSS_CORRELATION and DML_CONVOLUTION_DIRECTION_BACKWARD ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Data Conversion, Elementwise Cast function cast(input) = elementwiseUnary(input, (x) => cast(x)) cast Cast DML_OPERATOR_CAST ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? < 1 ULP
Data reorganization Transpose Reorder axes, such as X Y -> Y X, or X Y Z -> Z X Y. function transpose(input, permutationAxes /*gather semantics*/) assert(permutationAxes.size == input.rank) rank = input.rank for i=0..<rank do output.dimensions[i] = input.dimensions[permutationAxes[i]] outputCoordinate = repeat(rank, 0) for each inputCoordinate in input coordinates for i=0..<rank do outputCoordinate[i] = inputCoordinate[permutationAxes[i]] output[outputCoordinate] = input[inputCoordinate] endfor endfunction transpose Transpose DML_OPERATOR_ELEMENT_WISE_IDENTITY with TENSOR_DESC that flips via permuted strides ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Broadcast Broadcast any single size dimensions up to the output dimension counts. Similar to NumPy broadcast_to. function broadcast(input, targetDimensions) output.dimensions = broadcastDimensions(input.dimensions, targetDimensions) inputShape = padLeadingValues(input.dimensions, output.rank, 1) for each outputCoordinate in output coordinates for i=0..<output.rank do inputCoordinate[i] = iif(inputShape[i] > 1, outputCoordinate[i], 0) output[outputCoordinate] = inputData[inputCoordinate] endfor endfunction expand Expand DML_OPERATOR_ELEMENT_WISE_IDENTITY with TENSOR_DESC using zero strides along broadcast dimension ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Broadcast Dimensions Helper to compute the broadcasted output dimensions (1D tensor) from a list of multiple input dimensions. Any single size dimensions are stretched to the output dimension size. function broadcastDimensions(dimensionsList...) outputRank = 0 for each dimensions in dimensionsList do outputRank = max(outputRank, dimensions.size) // Determine the largest rank broadcastedDimensions = ones([outputRank]) // [1,1,...] for each dimension in dimensions // Take the max of all dimensions. paddedDimensions = padLeadingValues(dimensions, outputRank, 1) for i=0..<outputRank do assert(paddedDimensions[i] == broadcastedDimensions[i] || paddedDimensions[i] == 1 || broadcastedDimensions[i] == 1)) broadcastedDimensions[i] = max(broadcastedDimensions[i], paddedDimensions[i]) endfor endfor return broadcastedDimensions endfunction // Resize right-aligned by padding with leading values. // e.g. paddingSize=4 with [H,W] -> [1,1,H,W] function padLeadingValues(values, paddedSize, padValue) // Right align. e.g. original dimensions=[H,W], paddedSize=4, padValue=1 -> [1,1,H,W] paddingCount = max(paddedSize, values.size) - values.size paddedValues = values paddedValues.prepend(paddingCount, padValue) return paddedValues endfunction NA NA NA NA NA NA NA NA NA NA NA NA NA NA NA ? ? ? ? Exact
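
The same shape rule in plain Python (a sketch of broadcastDimensions and padLeadingValues above):

def broadcast_dimensions(*dimensions_list):
    output_rank = max(len(dims) for dims in dimensions_list)
    broadcasted = [1] * output_rank
    for dims in dimensions_list:
        padded = [1] * (output_rank - len(dims)) + list(dims)   # right-align, pad with 1s
        for i in range(output_rank):
            # Sizes must match, or one of them must be 1.
            assert padded[i] == broadcasted[i] or padded[i] == 1 or broadcasted[i] == 1
            broadcasted[i] = max(broadcasted[i], padded[i])
    return broadcasted

print(broadcast_dimensions([2, 3, 4], [3, 1], [4]))   # [2, 3, 4]
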
Data reorganization Reshape Return tensor with a different view of the data, like a reinterpret cast using new dimensions that are element-count compatible. function reshape(input, newDimensions) output = input output.dimensions = newDimensions return output endfunction reshape Reshape NA, no actual data change, just update the TENSOR_DESC ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Reshape To 2D Reinterpret the view of the tensor, reducing the dimensions from N to 2. e.g. [1,2,3,4,5] with a split at axis 3 yields [1*2*3,4*5] -> [6,20]) function flattenTo2D(input, axis) output = input oldDimensions = input.dimensions output.dimensions = join(reduceProduct(oldDimensions[0..axis]), reduceProduct(oldDimensions[axis..oldDimensions.size])) return output endfunction ONNX: function flattenTo2D(input, axis) inputShape = Shape(input) shapeFrontHalf = Slice(inputShape; ends=axis) shapeBackHalf = Slice(inputShape; starts=axis) newShape = Concat(axis=0, ReduceProd(shapeFrontHalf), ReduceProd(shapeBackHalf)) output = Reshape(input, newShape) endfunction reshape (plus caller logic) Flatten NA, no actual data change, just update the TENSOR_DESC ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Reshape Removing Ones Reinterpret the view of the tensor, removing 1's for deletable axes. function reshapeDeletingOnes(input, axes) output = input output.dimensions = deleteOnesInDimensions(input.dimensions, axes) return output endfunction function deleteOnesInDimensions(dimensions, axes) if axes is undefined axes = increasingSequence(0, dimensions.size) // Remove all 1's. else assert(allOf(dimensions, (d) => (d == 1))) axes = removeDuplicates(sortDescending(axes)) endif newDimensions = dimensions for i in axes // work from back to front if newDimensions[i] == 1 newDimensions.deleteAt(i) endif endfor return newDimensions endfunction reshape (plus caller logic) Squeeze NA, just rearrange the TENSOR_DESC ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Reshape Inserting Ones Reinterpret the view of the tensor, filling in 1's for newly inserted axes. function reshapeInsertingOnes(input, axes) output = input output.dimensions = insertOnesInDimensions(input.dimensions, axes) return output endfunction function insertOnesInDimensions(dimensions, axes) // Note the axes are relative to their *final* index. // So dimensions = [3,4] with axes = [0,2] yields new dimensions = [1,3,1,4]. newDimensions = dimensions axes = removeDuplicates(sort(axes)) for i in axes newDimensions.insertAt(i, 1) endfor return newDimensions endfunction reshape (plus caller logic) Unsqueeze NA, just rearrange the TENSOR_DESC ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Reshape From Axes Reinterpret the view of the tensor, gathering dimensions from axes, filling in 1's for filler dimensions. function reshapeFromAxes(input, newRank, axes) assert(input.rank == axes.size) assert(containsUniqueValues(axes)) output = input output.dimensions = gatherValues(input.dimensions, axes, newRank, 1) return output endfunction function gatherValues(values, indices, newValuesSize, fillerValue) newValues = repeat(newValuesSize, fillerValue) for (i, gatherIndex) in indices newValues[i] = values[gatherIndex] endfor endfunction reshape (plus caller logic) Reshape plus caller logic NA, just rearrange the TENSOR_DESC ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Reshape To Axes Reinterpret the view of the tensor, scattering dimensions to axes, filling in 1's for filler dimensions. function reshapeToAxes(input, newRank, axes) assert(input.rank == axes.size) assert(containsUniqueValues(axes)) output = input output.dimensions = scatterValues(input.dimensions, axes, newRank, 1) return output endfunction function scatterValues(values, indices, newValuesSize, fillerValue) newValues = repeat(newValuesSize, fillerValue) for (i, scatterIndex) in indices newValues[scatterIndex] = values[i] endfor endfunction reshape (plus caller logic) Reshape plus caller logic NA, just rearrange the TENSOR_DESC ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Tile Repeat entire tensor along each axis by repeat counts. function tile(input, repeats) assert(repeats.size == input.rank) for i=0..<input.rank do assert(repeats[i] > 0) outputDimensions = input.dimensions * repeats // elementwise multiply per axis output = new Tensor(input.dataType, outputDimensions) for each outputCoordinate in output for i=0..<input.rank do inputCoordinate[i] = outputCoordinate[i] % input.dimensions[i] output[outputCoordinate] = inputData[inputCoordinate] endfor endfunction ? Tile DML_OPERATOR_TILE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Split Split input into multiple output tensors. split Split DML_OPERATOR_SPLIT ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Slice Crop the tensor to the given ranges for each axis. function slice(input, starts, ends, axes, steps) N = input.rank if axes.empty then axes = arange(0, N-1) // [0,1,2,...] if starts.empty then starts = zeroes(N) // [0,0,0,...] if ends.empty then ends = input.dimensions if steps.empty then steps = ones(N) // [1,1,1,...] assert(axes.size == input.rank || axes.size == 0) assert(starts.size == axes.size) assert(ends.size == axes.size) assert(steps.size == axes.size) starts = max(starts, zeroes(N)) ends = min(ends, input.dimensions) ends = max(ends, starts) for i=0..<N do output.dimensions[i] = ceil((ends[i] - starts[i]) / steps[i]) // negative steps unhandled! for each outputCoordinate in output for i=0..<N do inputCoordinate[i] = outputCoordinate[i] * steps[i] + starts[i] output[outputCoordinate] = inputData[inputCoordinate] endfor endfunction slice Slice DML_OPERATOR_SLICE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Concatenate Combine multiple tensors into large output tensor. e.g. {1,2,3} with {4,5} -> {1,2,3,4,5} function concatenate(inputs, axis) sizesAlongAxis = [] for each input in inputs sizesAlongAxis.append(input.dimensions[axis]) endfor outputOffsets = exclusiveCumulativeSum(sizesAlongAxis) // [0, sizes[0], sizes[0]+sizes[1], ...] for each inputIndex from 0 up to inputs.count input = inputs[inputIndex] outputOffset = outputOffsets[inputIndex] for each index from 0 up to sizesAlongAxis[inputIndex] output[..., outputOffset + index, ...] = input[..., index, ...] endfor endfor endfunction concat Concat DML_OPERATOR_JOIN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Gather TODO: gather Gather DML_OPERATOR_GATHER ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Gather Elements Return output tensor the same size as indices, filling with values from input indexed along the axis by indices. function gatherElements(input, indices, axis) output = new Tensor(input.dataType, indices.dimensions) for each coordinate in indices tensor inputCoordinate = coordinate inputCoordinate[axis] = indices[coordinate] output[coordinate] = input[inputCoordinate] endfor endfunction output[i][j][k] = input[ index[i][j][k] ][j][k] # if dim == 0 output[i][j][k] = input[i][ index[i][j][k] ][k] # if dim == 1 output[i][j][k] = input[i][j][ index[i][j][k] ] # if dim == 2 e.g. input = [1,2,3,4,5,6] indices = [0,0,1,5] axis = 0 output = [1,1,2,6] e.g. input = [[1,2],[3,4],[5,6]] indices = [[0,0],[1,0],[1,1]] axis = 1 output = [[1,1], [4,3], [6,6]] ? GatherElements DML_OPERATOR_GATHER_ELEMENTS ? ? ? ? ? ? torch.gather ? ? ? ? ? ? ? ? ? Exact
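
NumPy's take_along_axis has the same semantics as the gatherElements pseudocode above, reproducing the second example:

import numpy as np

data = np.array([[1, 2], [3, 4], [5, 6]])
indices = np.array([[0, 0], [1, 0], [1, 1]])
print(np.take_along_axis(data, indices, axis=1))
# [[1 1]
#  [4 3]
#  [6 6]]
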
Data reorganization Gather Multidimensional TODO: ? GatherND DML_OPERATOR_GATHER_ND ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Scatter Elements Opposite of gather elements. Overwrites input values with updates along the axis at the given indices.
If two output element indices overlap, the last write wins in practice. function scatterElements(input, indices, updates, axis) output = input for each coordinate in indices tensor outputCoordinate = coordinate outputCoordinate[axis] = indices[coordinate] output[outputCoordinate] = updates[coordinate] endfor endfunction output[ index[i][j][k] ][j][k] = updates[i][j][k] # if dim == 0 output[i][ index[i][j][k] ][k] = updates[i][j][k] # if dim == 1 output[i][j][ index[i][j][k] ] = updates[i][j][k] # if dim == 2 e.g. data = [[1, 2, 3, 4, 5]] // data == input indices = [[1, 3]] updates = [[11, 21]] axis = 1 output = [[1, 11, 3, 21, 5]]
? ScatterElements
torch.tensor.scatter_
DML_OPERATOR_SCATTER_ELEMENTS ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
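
NumPy's put_along_axis mirrors the scatterElements pseudocode (last write wins when indices collide), reproducing the example above:

import numpy as np

data = np.array([[1, 2, 3, 4, 5]])
indices = np.array([[1, 3]])
updates = np.array([[11, 21]])
output = data.copy()
np.put_along_axis(output, indices, updates, axis=1)
print(output)   # [[ 1 11  3 21  5]]
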
Data reorganization Scatter ND TODO: ? ScatterND DML_OPERATOR_SCATTER_ND ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Pad Inflate the input with zeroes on the edges pad Pad DML_OPERATOR_PADDING ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Space To Depth Rearrange blocks of elements. channelCountDivBlockCount = channelCount / (blockSize * blockSize); inputIndices = [ outputIndices.batch, outputIndices.channel % channelCountDivBlockCount, (outputIndices.channel / channelCountDivBlockCount) / blockSize + (outputIndices.height * blockSize), (outputIndices.channel / channelCountDivBlockCount) % blockSize + (outputIndices.width * blockSize) ] output[outputIndices] = input[inputIndices]; ? SpaceToDepth DML_OPERATOR_SPACE_TO_DEPTH ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Depth To Space Rearrange blocks of elements. channelCountDivBlockCount = channelCount / (blockSize * blockSize); outputIndices = [ inputIndices.batch, inputIndices.channel % channelCountDivBlockCount, (inputIndices.channel / channelCountDivBlockCount) / blockSize + (inputIndices.height * blockSize), (inputIndices.channel / channelCountDivBlockCount) % blockSize + (inputIndices.width * blockSize) ] output[outputIndices] = input[inputIndices]; NumPy: # Using DCR mode (depth/column/row) b, c, h, w = x.shape tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w]) tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2]) y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize]) # Using CRD mode (column/row/depth) b, c, h, w = x.shape tmp = np.reshape(x, [b, c // (blocksize ** 2), blocksize, blocksize, h, w]) tmp = np.transpose(tmp, [0, 1, 4, 2, 5, 3]) y = np.reshape(tmp, [b, c // (blocksize ** 2), h * blocksize, w * blocksize]) reshape and transpose DepthToSpace DML_OPERATOR_DEPTH_TO_SPACE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Dimensions (aka Shape, Sizes) Return the dimensions of the tensor as a 1D tensor. function dimensions(input) = input.dimensions MLOperandDescriptor::dimensions Shape NA, just read the TENSOR_DESC dimensions ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Data reorganization Element Count function elementCount(input) = reduceProduct(input.dimensions, keepDimensions=false). note: Size is unfortunately named, inconsistently so with Resize-10 which accepts separate dimensions. 🙃 If you want the sizes of the tensor (N C H W) rather than just the total element count, call Shape instead. product(MLOperandDescriptor::dimensions) Size NA, just compute the number of TENSOR_DESC elements ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Element Count Along Axes function elementCountAlongAxes(inputDimensions, axes) return reduceProduct(gatherElements(inputDimensions, axes), keepDimensions=false) endfunction product(MLOperandDescriptor::dimensions) Size NA, just compute the number of TENSOR_DESC elements ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization mapping One Hot Along Axis Set all elements to 'off' values, then set one element to 'on' value along specified axis using index offset. function oneHot(indices, axis, axisLength, values) // Indices and output are broadcast compatible. // Indices has a dimension of size 1 at axis. // Output has a dimension of size axisLength at axis (opposite of reduction). // 1D values[2] contains {offValue, oneValue}. assert(indices.dimensions[axis] == 1) outputDimensions = indices.dimensions outputDimensions[axis] = axisLength defaultValues = broadcast(values[0], outputDimensions) return scatterElements(defaultValues, indices, values[1], axis) endfunction function oneHotExpandedOutput(indices, axis, axisLength, values) // Output is 1 dimension bigger than input, inserted at the axis. broadcastCompatibleDimensions = reshapeInsertingOnes(indices, [axis]) broadcastCompatibleIndices = reshape(indices, broadcastCompatibleDimensions) return oneHot(broadcastCompatibleIndices, axis, axisLength, values) endfunction ? OneHot DML_OPERATOR_ELEMENT_WISE_ONE_HOT ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Top K Sorted Selection function topK(input, axis, axisLength) // Order the entries along an axis, keeping a length of the top K. return slice(sortDecreasing(input, axis), starts=[0], ends=[axisLength], axes=[axis]) endfunction ? TopK DML_OPERATOR_TOP_K ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization, Elementwise, Selection Select elementwise
("where" is a bad name)
function select(condition, trueValue, falseValue) return elementwiseTrinary(condition, trueValue, falseValue, (c, t, f) => if c then t else f) endfunction notes: A conditional per-element if statement. Can be used to implement composites that use logical operators (e.g. PRelu). where Where DML_OPERATOR_ELEMENT_WISE_IF ? stablehlo.select tosa.select numpy.where ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization, Selection Join Selected Slices A conditional slice/join along a specific axis. Has utterly nothing to do with data compression, despite the confusing name. ? Compress --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Data reorganization Reverse axes Reverse all the elements along the given axes. function reverseAxes(input, axes) output = newTensor(input.dimensions, input.dataType) for each inputCoordinate in input tensor // Flip the coordinate along all applicable axes. outputCoordinates = inputCoordinates for each axis in axes outputCoordinates[axis] = output.dimensions[axis] - outputCoordinates[axis] - 1 endfor output[outputCoordinate] = input[inputCoordinate] endfor endfunction ? Reverse (Reverted) DML_OPERATOR_SLICE1
with negative strides
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Pooling Pool Generic function poolGeneric(input, axes, windowDimensions, padding, strides, dilations, reductionFunction, initialValue) // Massage all axes-relative parameters to be directly compatible with input/output rank. expandedWindowDimensions = gatherValues(windowDimensions, axes, input.rank, 1) expandedPadding = gatherValues(padding, axes, input.rank, 0) expandedStrides = gatherValues(strides, axes, input.rank, 1) expandedDilations = gatherValues(dilations, axes, input.rank, 1) // Compute the output tensor size based on window size/padding/strides/dilations. filterExtents = ((expandedWindowDimensions - 1) * expandedDilations) + 1 paddedDimensions = input.dimensions + expandedPadding.leading + expandedPadding.trailing outputDimensions = (paddedDimensions - filterExtents + 1) / expandedStrides output = new Tensor(input.type, outputDimensions) // Reduce input along active axes. for each outputCoordinate in output coordinates // For each input in the window, apply the reduction function outputValue = initialValue for each (inputCoordinate, inputValue) in local input window outputValue = reductionFunction(outputValue, input[inputCoordinate]) endfor output[outputCoordinate] = outputValue endfor return output endfunction function poolGenericWithIndices(input, axes, windowDimensions, padding, strides, dilations, indicesDataType, reductionFunction, initialValue) // TODO: Complete... // output = new Tensor(input.type, outputDimensions) // indices = new Tensor(indicesDataType, outputDimensions) return (output, indices) endfunction TODO: Express pooling as higher dimension strided reduction. ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Pooling Pool Sum function poolSum(input, axes, windowDimensions, padding, strides, dilations) return poolGeneric(input, axes, windowDimensions, padding, strides, dilations, add, 0) // OR convolve(input, filter = ones(windowDimensions), axes, windowDimensions, padding, strides, dilations) endfunction ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Pooling Pool Average function poolAverage(input, axes, windowDimensions, padding, strides, dilations) windowElementCount = reduceProduct(windowDimensions) return div(poolSum(input, axes, windowDimensions, padding, strides, dilations), windowElementCount) endfunction averagePool2d AveragePool DML_OPERATOR_AVERAGE_POOLING ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
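
A compact NumPy sketch of average pooling over the spatial axes of an NCHW tensor (no padding or dilation; sliding_window_view requires NumPy 1.20+):

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def pool_average_2d(x, window, strides):
    windows = sliding_window_view(x, window, axis=(2, 3))        # [N, C, H', W', kh, kw]
    windows = windows[:, :, ::strides[0], ::strides[1], :, :]    # apply strides
    return windows.mean(axis=(-2, -1))                           # average each window

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
print(pool_average_2d(x, window=(2, 2), strides=(2, 2)))
# [[[[ 2.5  4.5]
#    [10.5 12.5]]]]
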
Pooling Pool Average Spatial Dimensions Global function poolAverageSpatialDimensionsGlobal(input) axes = increasingSequence(2, input.rank) // Skip N and C dimensions. return reduceAverage(input, axes, keepDimensions=true) // Alternately poolAverage with windowDimensions equal to the input sizes after N,C. endfunction Average all spatial elements in each batch & channel. So X[N C H W] reduces to Y[N C 1 1] ONNX GlobalAveragePool: InputShape = Shape(X) SpatialDimensions = Slice(InputShape; starts=2) // skip leading N and C dimensions. Output = AveragePool(X, kernel_shape=SpatialDimensions)) averagePool2d / reduceAverage GlobalAveragePool DML_OPERATOR_AVERAGE_POOLING ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Pooling Pool Maximum function poolMaximum(input, axes, windowDimensions, padding, strides, dilations) return poolGeneric(input, axes, windowDimensions, padding, strides, dilations, max, -infinity) endfunction maxPool2d MaxPool DML_OPERATOR_MAX_POOLING ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Pooling Pool Maximum Spatial Dimensions Global function poolMaximumSpatialDimensionsGlobal(input) axes = increasingSequence(2, input.rank) // Skip N and C dimensions. return reduceMax(input, axes, keepDimensions=true) // Alternately poolMaximum with windowDimensions equal to the input sizes after N,C. endfunction maxPool2d / reduceMax GlobalMaxPool DML_OPERATOR_MAX_POOLING with output being 1 element ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Pooling Unpool Maximum Opposite of MaxPool. Fill the output tensor of the given shape (either explicit or the input shape plus padding) with zeros. Then write each value from the input tensor into the output tensor at the element offset from the corresponding indices array. ? MaxUnpool --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Pooling Pool Lebesgue function poolLebesgue(input, axes, windowDimensions, padding, strides, dilations, exponent) return root(poolSum(pow(input, exponent), axes, windowDimensions, padding, strides, dilations), exponent) // y = (x1^p + x2^p + ... + xn^p) ^ (1/p) endfunction l2Pool2d LpPool DML_OPERATOR_LP_POOLING ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Pooling Pool Lebesgue Spatial Dimensions Global function poolLebesgueSpatialDimensionsGlobal(input, exponent) axes = increasingSequence(2, input.rank) // Skip N and C dimensions. return reduceLebesgue(input, axes, exponent, keepDimensions=true) // TODO: Add the above reduction function endfunction So X[N C H W] reduces to Y[N C 1 1] e.g. (3^2 + 4^2) ^ (1/2) = 5 l2Pool2d GlobalLpPool DML_OPERATOR_LP_POOLING with output being 1 element ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Pooling Pool Maximum Region of Interest Apply MaxPool to given input within each numbered region: [batch_index, w_offset_start, h_offset_start, w_offset_last_inclusive, h_offset_last_inclusive]. Then write the maximal value back to the output. questions: Are x2 and y2 really supposed to be end-inclusive? If so, how can that possibly work correctly with the spatial_scale attribute? What's the point of the pooled_shape when each region has a specific size anyway? ? MaxRoiPool DML_OPERATOR_ROI_POOLING (only POOLING_MAX is supported) ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Reduction Reduce Generic
function reduceGeneric(input, axes, keepDimensions, reductionFunction, initialValue) // Determine output tensor dimensions, collapsing each reduced axis to size 1. outputDimensions = input.dimensions outputCoordinateMask = repeat(outputDimensions.size, 0xFFFFFFFF) for each axis in axes outputDimensions[axis] = 1 outputCoordinateMask[axis] = 0 endfor output = new Tensor(input.type, outputDimensions, initialValue) // Reduce input along the active axes, folding each element into its output cell. for each (inputCoordinate, value) in input outputCoordinate = inputCoordinate & outputCoordinateMask // Elementwise mask zeroes the reduced axes. previousValue = output[outputCoordinate] output[outputCoordinate] = reductionFunction(previousValue, value) endfor // Remove reduced dimensions (size 1) from the output tensor if desired. if keepDimensions == false output.dimensions = deleteOnesInDimensions(output.dimensions, axes) endif return output endfunction
NA NA DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
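A minimal NumPy sketch of the generic reduction pseudocode above (reduce_generic is a hypothetical helper; a real implementation would vectorize instead of looping per element):

    import numpy as np

    def reduce_generic(x, axes, keep_dims, fn, initial):
        # Fold fn over every element, writing into an output whose reduced axes are size 1.
        out_shape = [1 if a in axes else d for a, d in enumerate(x.shape)]
        out = np.full(out_shape, initial, dtype=x.dtype)
        for idx in np.ndindex(*x.shape):
            out_idx = tuple(0 if a in axes else i for a, i in enumerate(idx))
            out[out_idx] = fn(out[out_idx], x[idx])
        return out if keep_dims else np.squeeze(out, axis=tuple(axes))

    x = np.arange(6.0).reshape(2, 3)
    assert np.allclose(reduce_generic(x, [1], True, lambda a, b: a + b, 0.0),
                       np.sum(x, axis=1, keepdims=True))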
Reduction Reduce to Sum function reduceSum(input, axes, keepDimensions) return reduceGeneric(input, axes, keepDimensions, add, 0) // x[0] + x[1] + ... + x[n-1] endfunction reduceSum ReduceSum DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_SUM ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Reduction Reduce to Mean function reduceAverage(input, axes, keepDimensions) reducedElementCount = elementCountAlongAxes(input.dimensions, axes) return div(reduceSum(input, axes, keepDimensions), reducedElementCount) endfunction reduceMean ReduceMean DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_AVERAGE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Reduction Reduce to Product function reduceProduct(input, axes, keepDimensions) return reduceGeneric(input, axes, keepDimensions, mul, 1) // x[0] * x[1] * ... * x[n-1] endfunction reduceProduct ReduceProd DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_MULTIPLY ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Reduction Reduce to Logarithm of Sum function reduceSumLog(input, axes, keepDimensions) return log(reduceSum(input, axes, keepDimensions)) // logₑ(x[0] + x[1] + ... + x[n-1]) endfunction reduceLogSum ReduceLogSum DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_LOG_SUM ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Reduction Reduce to Logarithm of Sum of Exponents function reduceExpSumLog(input, axes, keepDimensions) return log(reduceSum(exp(input), axes, keepDimensions)) // logₑ(expₑ(x[0]) + expₑ(x[1]) + ... + expₑ(x[n-1])) endfunction reduceLogSumExp ReduceLogSumExp DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_LOG_SUM_EXP ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Reduction Reduce to Sum of Squares function reduceSumSquares(input, axes, keepDimensions) return reduceSum(pow(input, 2), axes, keepDimensions) // x[0]^2 + x[1]^2 + ... + x[n-1]^2 endfunction reduceSumSquare ReduceSumSquare DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_SUM_SQUARE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Reduction Reduce to Sum of Absolute Values function reduceL1(input, axes, keepDimensions) return reduceSum(abs(input), axes, keepDimensions) // abs(x[0]) + abs(x[1]) + ... + abs(x[n-1]) endfunction reduceL1 ReduceL1 DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_L1 ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Reduction Reduce to L2 Distance function reduceL2(input, axes, keepDimensions) return sqrt(reduceSum(pow(input, 2), axes, keepDimensions)) // sqrt(x[0]^2 + x[1]^2 + ... + x[n-1]^2) endfunction reduceL2 ReduceL2 DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_L2 ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
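A short NumPy sketch relating the L1, sum-of-squares, and L2 reductions above:

    import numpy as np

    x = np.array([[3.0, 4.0], [1.0, 2.0]])
    l1 = np.sum(np.abs(x), axis=1)            # reduceL1 -> [7, 3]
    ssq = np.sum(np.square(x), axis=1)        # reduceSumSquares -> [25, 5]
    l2 = np.sqrt(ssq)                         # reduceL2 -> [5, sqrt(5)]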
Reduction Reduce to Maximum function reduceMaximum(input, axes, keepDimensions) return reduceGeneric(input, axes, keepDimensions, max, -inf) // max(max(max(x[0], x[1]), x[2]), ..., x[n-1]) endfunction reduceMax ReduceMax DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_MAX ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Reduction Reduce to Minimum function reduceMinimum(input, axes, keepDimensions) return reduceGeneric(input, axes, keepDimensions, min, inf) // min(min(min(x[0], x[1]), x[2]), ..., x[n-1]) endfunction reduceMin ReduceMin DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_MIN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Reduction Find Maximum Index int32 {i j k ..} = maxindex(X Y Z …) argMax ArgMax DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_ARGMAX ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Reduction Find Minimum Index int32 {i j k ..} = minindex(X Y Z …) argMin ArgMin DML_OPERATOR_REDUCE with DML_REDUCE_FUNCTION_ARGMIN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Imaging Operators Resample function resample(input, scales) outputDimensions = floor(input.dimensions * scales) output = new Tensor(input.dataType, outputDimensions) for each coordinate in output output[coordinate] = input[floor(coordinate / scales)] // Nearest-neighbor (floor) back-mapping. endfor return output endfunction resample Resize DML_OPERATOR_RESAMPLE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
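A minimal NumPy sketch of the nearest-neighbor resampling pseudocode above (resample_nearest is a hypothetical helper; real Resize/Resample operators also offer linear and cubic filtering and several coordinate-transformation modes):

    import numpy as np

    def resample_nearest(x, scales):
        # Back-map each output coordinate to the input by dividing by the scale and flooring.
        out_shape = tuple(int(np.floor(d * s)) for d, s in zip(x.shape, scales))
        out = np.empty(out_shape, dtype=x.dtype)
        for idx in np.ndindex(*out_shape):
            src = tuple(min(int(np.floor(i / s)), d - 1) for i, s, d in zip(idx, scales, x.shape))
            out[idx] = x[src]
        return out

    x = np.arange(4.0).reshape(2, 2)
    print(resample_nearest(x, (2.0, 2.0)))    # each value repeated into a 2x2 block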
Imaging Operators Resample Up resample(input, scales) resample Upsample DML_OPERATOR_UPSAMPLE_2D ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Control Flow If f(cond, then_graph, else_graph, outputs...): subgraph = cond ? then_graph : else_graph outputs = subgraph(implicitly_named_inputs_from_outer_graph) ? If --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Control Flow Loop TODO: ? Loop --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Control Flow Scan TODO: ? Scan --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
Normalization Mean Variance Normalization For each output element, subtract the mean, and divide by the standard deviation. function meanVarianceNormalization(input, axes) // = (input - mean) / standardDeviation // = (input - mean) / sqrt(variance + epsilon) // = (input - mean(input)) / sqrt(mean((input - mean(input))^2)) // = (input - mean(input)) / sqrt(mean(input^2) - mean(input)^2) mean = reduceAverage(input, axes, keepDimensions=true) centeredInput = sub(input, mean) meanSquared = pow(mean, 2) squareMeaned = reduceAverage(pow(input, 2), axes, keepDimensions=true) variance = sub(squareMeaned, meanSquared) standardDeviation = sqrt(add(variance, epsilon)) return div(centeredInput, standardDeviation) endfunction
ONNX and NumPy ONNX: exponent = Const(2.0) epsilon = Const(1e-9) inputMean = ReduceMean(input) inputMeanSquared = Pow(inputMean, exponent) inputSquared = Pow(input, exponent) inputSquareMeaned = ReduceMean(inputSquared, keepdims=1) variance = Sub(inputSquareMeaned, inputMeanSquared) standardDeviation = Sqrt(variance) inputCentered = Sub(input, inputMean) standardDeviationWithEpsilon = Add(standardDeviation, epsilon) X_MVN = Div(inputCentered, standardDeviationWithEpsilon) NumPy: inputMean = np.mean(input, axes, keepdims=1) inputMeanSquared = np.power(inputMean, 2) inputSquared = np.power(input, 2) inputSquareMeaned = np.mean(inputSquared, axes, keepdims=1) standardDeviation = np.sqrt(inputSquareMeaned - inputMeanSquared) output = (input - inputMean) / (standardDeviation + 1e-9)
? MeanVarianceNormalization DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1 ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Normalization Spatial Normalization
(independent batch&channel)
function instanceNormalization(input, scale, bias, axes) // Generic version // Applies: DirectML assert(isBroadcastCompatible(scale.dimensions, input.dimensions)) assert(isBroadcastCompatible(bias.dimensions, input.dimensions)) // scale * (input - mean) / sqrt(variance + epsilon) + bias return add(mul(scale, meanVarianceNormalization(input, axes)), bias) endfunction function instanceNormalizationSpatialDimensions(input, scale1D, bias1D) // Applies: ONNX, WebNN spatialAxes = [2 ... input.rank-1] // Exclude axes {0,1} for N and C. channelAxes = [1] // 1D scale is coerced to [1, C axis size, 1, ..., 1] so it broadcasts against the input. // 1D bias is coerced likewise. reshapedScale = reshapeToAxes(scale1D, input.rank, channelAxes) reshapedBias = reshapeToAxes(bias1D, input.rank, channelAxes) return instanceNormalization(input, reshapedScale, reshapedBias, spatialAxes) endfunction Mean and variance are computed across spatial dimensions DHW, independently per batch & channel (NC): axes = [2,3, ..., inputRank-1] // Exclude axes {0,1} mean = (x0 + x1 + … + xn-1) / n; variance = ((x0 - xmean)^2 + (x1 - xmean)^2 + … + (xn-1 - xmean)^2) / n ONNX: mean = ReduceAverage(X, axes, keepdims=true) variance = ReduceAverage(Pow(Sub(X, mean), 2), axes, keepdims=true) NumPy: mean = np.mean(x, axis=axes, keepdims=True) variance = np.var(x, axis=axes, keepdims=True) instanceNormalization InstanceNormalization DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1 with axes = [2,3,4,...] excluding (N,C) DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION with acrossChannels=false normalizeVariance=true scale and bias provided ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
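A minimal NumPy sketch of the spatial (instance) normalization above, assuming NCHW layout and per-channel 1D scale/bias (instance_norm is a hypothetical helper):

    import numpy as np

    def instance_norm(x, scale, bias, epsilon=1e-5):
        # Statistics per batch & channel over the spatial axes; scale/bias broadcast as [1, C, 1, 1].
        axes = tuple(range(2, x.ndim))
        mean = np.mean(x, axis=axes, keepdims=True)
        var = np.var(x, axis=axes, keepdims=True)
        s = scale.reshape((1, -1) + (1,) * (x.ndim - 2))
        b = bias.reshape((1, -1) + (1,) * (x.ndim - 2))
        return s * (x - mean) / np.sqrt(var + epsilon) + b

    x = np.random.rand(2, 3, 4, 4)
    y = instance_norm(x, np.ones(3), np.zeros(3))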
Normalization Channel&Spatial Normalization
(independent leading batches) ref
function layerNormalization(input, scale, bias, firstAxis) axes = [firstAxis ... input.rank - 1] // Scale and bias are expected to already be broadcast-compatible with input. return add(mul(scale, meanVarianceNormalization(input, axes)), bias) // scale * (input - mean) / sqrt(variance + epsilon) + bias endfunction Mean and variance are computed across all dimensions from the given axis onward, independently per leading batch dimensions: axes = [axis, axis+1, ..., inputRank-1] // Exclude axes {0, ..., axis-1} mean = (x0 + x1 + … + xn-1) / n; variance = ((x0 - xmean)^2 + (x1 - xmean)^2 + … + (xn-1 - xmean)^2) / n ONNX: mean = ReduceAverage(X, axes, keepdims=true) variance = ReduceAverage(Pow(Sub(X, mean), 2), axes, keepdims=true) NumPy: mean = np.mean(x, axis=axes, keepdims=True) variance = np.var(x, axis=axes, keepdims=True) layerNormalization LayerNormalization DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1 with axes = [firstAxis, firstAxis+1, ..., input.rank - 1]. DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION with acrossChannels=false normalizeVariance=true scale and bias provided ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
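A minimal NumPy sketch of the layer normalization above, normalizing over all axes from first_axis onward (layer_norm is a hypothetical helper; scale and bias are assumed already broadcast-compatible):

    import numpy as np

    def layer_norm(x, scale, bias, first_axis, epsilon=1e-5):
        # Normalize over every axis at or after first_axis, independently per leading batch dims.
        axes = tuple(range(first_axis, x.ndim))
        mean = np.mean(x, axis=axes, keepdims=True)
        var = np.var(x, axis=axes, keepdims=True)
        return scale * (x - mean) / np.sqrt(var + epsilon) + bias

    x = np.random.rand(2, 5, 8)
    y = layer_norm(x, np.ones(8), np.zeros(8), first_axis=2)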
Normalization Batch&Spatial Normalization
(independent channels)
function batchAndSpatialNormalization(input, scale, bias, mean, variance, epsilon, momentum) return scale * (input - mean) / sqrt(variance + epsilon) + bias endfunction function batchAndSpatialNormalization1DBias(input, scale, bias, mean, variance, epsilon, momentum) // Assuming C == axis 1. return batchAndSpatialNormalization(input, scale, reshapeToAxes(bias, input.rank, [1]), mean, variance, epsilon, momentum) endfunction Typically the statistics are precomputed across the batch and spatial dimensions NDHW (not just the batch dimension, as the name might misleadingly suggest), so each channel C is independent. They are then reshaped to be broadcast-compatible with the input. axes = [0,2, ..., inputRank-1] // Exclude axis {1}, i.e. everything except the channel axis mean = (x0 + x1 + … + xn-1) / n; variance = ((x0 - xmean)^2 + (x1 - xmean)^2 + … + (xn-1 - xmean)^2) / n ONNX: mean = ReduceAverage(X, axes, keepdims=true) variance = ReduceAverage(Pow(Sub(X, mean), 2), axes, keepdims=true) NumPy: axes = allAxesExceptChannelAxes mean = np.mean(x, axis=axes, keepdims=True) variance = np.var(x, axis=axes, keepdims=True) ? BatchNormalization DML_OPERATOR_BATCH_NORMALIZATION ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
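A minimal NumPy sketch of the batch normalization above, using precomputed per-channel statistics (batch_norm is a hypothetical helper):

    import numpy as np

    def batch_norm(x, scale, bias, mean, var, epsilon=1e-5):
        # Statistics are per channel (axis 1); reshape them to [1, C, 1, 1] for broadcasting.
        shape = (1, -1) + (1,) * (x.ndim - 2)
        return (scale.reshape(shape) * (x - mean.reshape(shape))
                / np.sqrt(var.reshape(shape) + epsilon) + bias.reshape(shape))

    x = np.random.rand(4, 3, 8, 8)
    mean = np.mean(x, axis=(0, 2, 3))          # computed across batch and spatial dims
    var = np.var(x, axis=(0, 2, 3))
    y = batch_norm(x, np.ones(3), np.zeros(3), mean, var)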
Normalization Local Response Normalization function localResponseNormalization(input, axes, windowDimensions, padding, scale, bias, exponent) regionAverages = averagePoolND(pow(input, 2), axes, windowDimensions, padding) return input / pow((regionAverages * scale + bias), exponent) endfunction function localResponseNormalizationSquare(input, axes, windowLength, scale, bias, exponent) // Only handles square reduction windows. windowDimensions = repeat(axes.size, [windowLength]) leadingPadding = floor((windowLength - 1) / 2) // Center halfway around sliding window trailingPadding = ceil((windowLength - 1) / 2) // Center halfway around sliding window padding = repeat(axes.size * 2, [leadingPadding, trailingPadding]) return localResponseNormalization(input, axes, windowDimensions, padding, scale, bias, exponent) endfunction For each output element, sum all the corresponding inputs in a local window, considering scale, power, and bias. Implementations support either 1D or 2D (some only support 1D). LRN(x, localSize, scaleAlpha, powerBeta, bias = 1) = x / (bias + (scaleAlpha / localSize) * sum(xi^2 for every xi in the local region)) ^ powerBeta ? LRN DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
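A minimal NumPy sketch of 1D local response normalization across the channel axis, following the LRN formula above (lrn_1d is a hypothetical helper):

    import numpy as np

    def lrn_1d(x, local_size, alpha, beta, bias):
        # 1D LRN across the channel axis of an NCHW tensor, window centered per the pseudocode.
        squares = np.square(x)
        out = np.empty_like(x)
        half_lo = (local_size - 1) // 2
        half_hi = local_size // 2
        for c in range(x.shape[1]):
            lo, hi = max(0, c - half_lo), min(x.shape[1], c + half_hi + 1)
            region_sum = np.sum(squares[:, lo:hi], axis=1)
            out[:, c] = x[:, c] / (bias + (alpha / local_size) * region_sum) ** beta
        return out

    x = np.random.rand(1, 5, 4, 4)
    y = lrn_1d(x, local_size=3, alpha=1e-4, beta=0.75, bias=1.0)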
Normalization Lebesgue Length Normalization function lebesgueLengthNormalization(input, axes, exponent) reduced = reduceLp(input, axes, exponent, keepDimensions=1) // reduceL1 or reduceL2, which already includes the root. epsilon = 1e-9 return div(input, add(reduced, epsilon)) // reduced is implicitly broadcast to input's shape endfunction ? LpNormalization (1D) DML_OPERATOR_LP_NORMALIZATION (1D) ? ? ? ? ? ? ? L2NormalizeLayerParams (3D rightmost) ? ? ? ? ? ? ? ? Precision TBD
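A short NumPy sketch of Lp length normalization along one axis, per the formula above (lp_normalize is a hypothetical helper):

    import numpy as np

    def lp_normalize(x, axis, p=2, epsilon=1e-9):
        # Divide each element by the Lp norm computed along `axis` (p = 1 or 2).
        norm = np.sum(np.abs(x) ** p, axis=axis, keepdims=True) ** (1.0 / p)
        return x / (norm + epsilon)

    v = np.array([3.0, 4.0])
    print(lp_normalize(v, axis=0))             # -> [0.6, 0.8]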
Sparse tensor collation Nonzero Coordinates List Append the coordinates of every nonzero input value to a list, with the coordinates grouped per dimension (transposed) rather than interleaved. So, not [[X1,Y1,Z1], [X2,Y2,Z2], …] but rather [[X1,X2,…], [Y1,Y2,…], [Z1,Z2,…]]. function nonzeroCoordinatesList(input) coordinates = [] for each (coordinate, value) in input: if value != 0 then: coordinates.append(coordinate) endif endfor return transpose(coordinates) endfunction ? Nonzero DML_OPERATOR_NONZERO_COORDINATES ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
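A short NumPy illustration of the per-dimension coordinate layout described above (np.nonzero already returns one index array per dimension):

    import numpy as np

    x = np.array([[1, 0], [0, 2]])
    coords = np.stack(np.nonzero(x))           # shape [rank, count]: one row per dimension
    # coords == [[0, 1],   row indices of the nonzero values
    #            [0, 1]]   column indices of the nonzero values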
NGram Term Frequency Inverse
Document Frequency Vectorizer
Read input front to back, incrementing the output histogram for each occurrence found of the desired patterns. It's basically a word-count algorithm with the output histogram size equalling the number of words in the dictionary to find. ? TfIdfVectorizer --- ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact
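A rough Python sketch of the word-count core described above (the vocabulary and tokens are made-up data; a real TfIdfVectorizer also handles n-grams, skips, and optional IDF weighting):

    from collections import Counter

    vocabulary = ["cat", "dog", "fish"]                    # hypothetical dictionary of patterns to count
    tokens = ["dog", "cat", "dog", "bird"]
    counts = Counter(t for t in tokens if t in vocabulary)
    histogram = [counts[word] for word in vocabulary]      # -> [1, 2, 0]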
Aggregate Einstein Summation
reduceSum(multiply(A.reshape(aToProductShape)[aToProductStep], B.reshape(bToProductShape)[bToProductStep]), axes = reductionAxes) Links: - https://github.com/onnx/onnx/pull/2504 - https://ajcr.net/Basic-guide-to-einsum/ - https://numpy.org/devdocs/reference/generated/numpy.einsum.html - https://www.tensorflow.org/api_docs/python/tf/einsum - https://pytorch.org/docs/1.2.0/_modules/torch/functional.html - https://numpy.org/doc/stable/reference/generated/numpy.dot.html - https://numpy.org/doc/stable/reference/generated/numpy.matmul.html Einsum combines transposition/multiplication/sum reduction into a single operator using a concise string notation of comma-separated input axes and an output (e.g. "i,j->ij" with two 1D inputs and a 2D output), which can represent the following operators: identity, diag, trace, transpose, sum, dot product, matmul, elementwise multiplication, inner product, outer product. For example: - "i,j->ij" means product[i][j] = input0[i] * input1[j]; output[i][j] = product[i][j], two 1D vectors perpendicular to each other, broadcasted, and multiplied to yield a 2D tensor. - "ij,jk->ik" means product[i][k][j] = input0[i][j] * input1[j][k]; output[i][k] = sum(product[i][k]) which is the classic matrix "multiplication" (np.matmul). - "ij,i->" means product[i][j] = input0[i][j] * input1[i]; output = sum(product, axis=(0,1)) where every element in a 2D input0 is multiplied by the corresponding row element in 1D input1 and then all are summed into a 0D output scalar. These can all be expressed generically via... ONNX general form: ReduceSum( VariadicMul( Reproject(A, aAxesToProductAxes), Reproject(B, bAxesToProductAxes), ), reductionAxes, keepdims=false ) NumPy general form: np.sum( np.multiply( A.reshape(aToProductShape)[aToProductStep], B.reshape(bToProductShape)[bToProductStep], ), axis = reductionAxes ) VariadicMul and Reproject are helper operators (not actual ONNX). VariadicMul is a variant of Mul that accepts arbitrarily many input tensors, like ONNX Sum is to Add. Reproject reprojects the input axes to be compatible with the axes order of the intermediate product, using a combination of Transpose, Unsqueeze, and Range+GatherElements along the diagonal (numpy.diag) depending on the input axes and notation. Any of these steps may be a nop, which is how the generic form can represent any one of 10 operators. For example, if all output axes are expected (reductionAxes is empty), then the ReduceSum step is identity. If there is only one input tensor, then the VariadicMul is identity. If the input axes are homogeneous and listed in the same order as the output tensor, then the Reproject step is identity. Parsing notation The difficult part isn't the execution (these fundamentals are already supported in DirectML) but rather parsing the input/output axes from the string and mapping them into actionable operations: - Concatenate all the unique input axes in alphabetic order (ignore the output on the right of the arrow) to yield the intermediate product tensor's axes. "j,i" yields "ij". "ik,kj" yields "ijk". "cx,by,az" yields "abcxyz". Note the axes in the final output may only include axes found in the original inputs (else error), an axis name may not be reused more than once in the output (axes can be repeated in an input, assuming both uses have the same dimension size), and the output rank will be equal to or less than the product tensor rank. Note case matters, where aZ is reordered Za (since ASCII Z < a).
- Extract the output axes after the arrow (e.g. the ij from i,j->ij). Mentally you can reorder things, putting the right side of the arrow on the left like normal assignment. So i,j->ij becomes ij = i,j. - If there is no arrow output specified (e.g. "i,j", which differs from "i,j->" where an arrow explicitly yields a 0D scalar) then the output tensor shape defaults to the product tensor ("i,j" becomes "i,j->ij"). If any axis appears more than once in the input terms (either within the same input term or different terms), remove the axis from the output shape ("ik,ij" yields "jk", "iji" yields "j", "i,ik" yields "k", "i,i" yields scalar ""). - Multiply all comma-delimited terms for the intermediate product. ij = i,j (originally i,j->ij) becomes ij = i * j, which means product[i][j] = input0[i] * input1[j]. - Normalize all inputs to be broadcast compatible, e.g. product[i][j] = input0[i][0] * input1[0][j]. Using ONNX, this would require a mix of operators (Unsqueeze, Transpose...) whereas with DirectML, these can all be expressed directly via strides. - If any axes present in the intermediate product are missing from the final output then put those into reductionAxes (e.g. einsum('ij->j', A) is missing i, and einsum('ij,jk', A, B) with implicit output i,k is missing j). Then get the sum along the reduced axes, ReduceSum(product, axes=j, keepdims=false). - Transpose the reduced sum to the final output order (e.g. "i,j->ji" yields product[i][j] which is remapped to output[j][i]). Note this step doesn't actually need a separate tensor and can be folded into earlier steps, using output strides in the ReduceSum or reordering the input broadcast normalization. - An efficient implementation would attempt to eliminate steps that yield identity or intermediates that can be handled better by another operator (e.g. for MatMul, the product tensor can be eliminated by using the MatMul operator instead). Examples: #E = np.einsum('ii->i', A) A = np.array([[0,1,2],[3,4,5],[6,7,8]]) Areshaped = A.reshape([9])[0:9:4] # Or A.diagonal(). product = Areshaped # The multiply step is a nop given only one input term. E = np.sum(product, axis=()) # Nop given no output axes removed from product. #Z = np.einsum('ij,kj->ik', A, B) #Z = np.inner(A, B) A = np.array([[0,1],[2,3]]) B = np.array([[1,2],[3,4]]) Areshaped = A.reshape([2,1,2]) Breshaped = B.reshape([1,2,2]) Z = np.multiply(Areshaped, Breshaped) Z = Z.sum(axis=2) More examples: Given... A0 = np.array(2) B0 = np.array(3) A1 = np.array([0,1,2]) B1 = np.array([3,4,5]) A2 = np.array([[0,1],[2,3]]) B2 = np.array([[1,2],[3,4]]) A3 = np.array([[[0,1],[2,3]],[[4,5],[6,7]]]) B3 = np.array([[[1,2],[3,4]],[[5,6],[7,8]]])
Call signature | NumPy equivalent | Description
('i', A1) | A1 | returns a view of A1
('i->', A1) | sum(A1) | sums the values of A1
('i,i->i', A1, B1) | A1 * B1 | element-wise multiplication of A1 and B1
('i,i->', A1, B1) | inner(A1, B1) or dot(A1, B1) | inner product of A1 and B1
('i,i', A1, B1) | inner(A1, B1) or dot(A1, B1) | inner product of A1 and B1
('i,j->ij', A1, B1) | outer(A1, B1) | outer product of A1 and B1
('ij->ij', A2) | A2 | returns a view of A2
('ij', A2) | A2 | returns a view of A2
('ji', A2) | A2.T | view transpose of A2
('ji->ij', A2) | A2.T | view transpose of A2
('ii->i', A2) | diag(A2) | view main diagonal of A2
('ii->', A2) | trace(A2) | sums main diagonal of A2
('ij->', A2) | sum(A2) | sums the values of A2
('ij->j', A2) | sum(A2, axis=0) | sum down the columns of A2 (across rows)
('ij->i', A2) | sum(A2, axis=1) | sum horizontally along the rows of A2
('ij,ij->ij', A2, B2) | A2 * B2 | element-wise multiplication of A2 and B2
('ij,ji->ij', A2, B2) | A2 * B2.transpose() | element-wise multiplication of A2 and B2.T
('ij,jk', A2, B2) | matmul(A2, B2) or dot(A2, B2) | matrix multiplication of A2 and B2
('ij,jk->ik', A2, B2) | matmul(A2, B2) or dot(A2, B2) | matrix multiplication of A2 and B2
('bij,bjk->bik', A3, B3) | matmul(A3, B3) | matrix multiplication of A3 and B3 (a stack of 2D matrices)
('bij,bkj->bik', A3, B3) | matmul(A3, transpose(B3, (0, 2, 1))) | matrix multiplication of A3 and the transposed B3 (a stack of 2D matrices)
('ij,kj->ik', A2, B2) | inner(A2, B2) | inner product of A2 and B2
('ij,kj->ikj', A2, B2) | A2[:, None] * B2 | each row of A2 multiplied by B2
('ij,kl->ijkl', A2, B2) | A2[:, :, None, None] * B2 | each value of A2 multiplied by B2
(',ij', 3, B2) | 3 * B2 | scalar multiplied by each element of B2: array([[3, 6], [9, 12]])
('ij,j', A2, B1) | matmul(A2, B1) or dot(A2, B1) | matrix-vector multiplication of A2 and B1
('ii,ii->i', A2, B2) | diag(A2) * diag(B2) | diagonals multiplied element-wise
('ii,ii->', A2, B2) | dot(diag(A2), diag(B2)) | dot product of the diagonals
Well-known patterns can be mapped directly to dedicated operators: https://github.com/microsoft/onnxruntime/blob/4477f57ee3151287a9759bd09d269f0e258a9eda/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp#L1583-L1599
? EinSum DML_OPERATOR_GEMM/TRANSPOSE/REDUCE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
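A short NumPy sketch of the generic reproject-multiply-reduce decomposition described above, applied to the "ij,jk->ik" (matmul) case:

    import numpy as np

    A = np.array([[0., 1.], [2., 3.]])
    B = np.array([[1., 2.], [3., 4.]])

    # Reproject both inputs into the product axes (i, j, k), multiply with broadcasting,
    # then sum away the reduced axis j.
    product = A[:, :, None] * B[None, :, :]               # shape [i, j, k]
    Z = product.sum(axis=1)                                # reduce j
    assert np.allclose(Z, np.einsum('ij,jk->ik', A, B))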
Aggregate Recurrent Neural Network Y = Activation(Clip(MatMul(X, Transpose(W)) + MatMul(Initial_h, Transpose(R)) + B, -clip, +clip)) ? RNN DML_OPERATOR_RNN ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
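A minimal NumPy sketch of one step of the RNN recurrence above (rnn_step is a hypothetical helper; the shapes are illustrative):

    import numpy as np

    def rnn_step(x, h_prev, W, R, B, clip_value, activation=np.tanh):
        # One recurrence step: Y = act(clip(x · Wᵀ + h_prev · Rᵀ + B, -clip, +clip)).
        pre = x @ W.T + h_prev @ R.T + B
        return activation(np.clip(pre, -clip_value, clip_value))

    x = np.random.rand(1, 4)                    # [batch, input_size]
    h = np.zeros((1, 3))                        # [batch, hidden_size]
    W = np.random.rand(3, 4); R = np.random.rand(3, 3); B = np.zeros(3)
    y = rnn_step(x, h, W, R, B, clip_value=50.0)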
Aggregate Gated Recurrent Unit TODO: Need better summary. Iteratively apply matrix multiplication. Z = Activation1(Clip(MatMul(X, Transpose(W1)) + MatMul(Initial_h1, Transpose(R1)) + b1, -clip, +clip)) R = Activation1(Clip(MatMul(X, Transpose(W2)) + MatMul(Initial_h1, Transpose(R2)) + b2, -clip, +clip)) C = Mul(Initial_h1, R) O = Activation2(Clip(MatMul(X, Transpose(W3)) + MatMul(C, Transpose(R3)) + b3, -clip, +clip)) Y = Mul((1-Z), O) + Mul(Z, Initial_h1) W = [W1, W2, W3]; b1 = B[0, :] + B[3*hidden_size, :]; b2 = B[1, :] + B[4*hidden_size, :]; b3 = B[2, :] + B[5*hidden_size, :]; ? GRU DML_OPERATOR_GRU ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Aggregate Gated Recurrent Unit Unit ??? gruCell GRUUnit (delete) NA ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Aggregate Long Short Term Memory TODO: Need better summary. Iteratively apply matrix multiplication. I = Activation1f(Clip(MatMul(X, Transpose(W1)) + MatMul(Initial_h1, Transpose(R1)) + Mul(p, initial_c) + b1, -clip, +clip)) F = Activation1f(Clip(MatMul(X, Transpose(W2)) + MatMul(Initial_h1, Transpose(R2)) + Mul(p, initial_c) + b2, -clip, +clip)) Z = Activation2g(Clip(MatMul(X, Transpose(W3)) + MatMul(Initial_h1, Transpose(R3)) + b3, -clip, +clip)) C = Mul(initial_c, F) + Mul(I, Z) O = Activation1f(Clip(MatMul(X, Transpose(W4)) + MatMul(Initial_h1, Transpose(R4)) + Mul(p, C) + b4, -clip, +clip)) Y = Mul(Activation3h(C), O) W = [W1, W2, W3, W4]; b1 = B[0, :] + B[4*hidden_size, :]; b2 = B[1, :] + B[5*hidden_size, :]; b3 = B[2, :] + B[6*hidden_size, :]; b4 = B[3, :] + B[7*hidden_size, :]; lstm LSTM DML_OPERATOR_LSTM ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Aggregate Long Short Term Memory Unit One occurrence (a single cell step) of LSTM lstmCell NA DML_OPERATOR_LSTM ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Aggregate Multihead Attention TODO: X X DML_OPERATOR_MULTIHEAD_ATTENTION ? ? ? ? X X torch.nn.MultiheadAttention ? ? ? mlx.nn.MultiHeadAttention MultiHeadAttention ? ? ? ? Precision TBD
Elementwise, Training Dropout For each element, randomly zero it or multiply it by 1 / (1 - ratio). f(X, ratio) = select(lesser(random(0, 0.9999), ratio), 0, mul(X, recip(sub(1, ratio)))) f(x, ratio) = iif(random(0, 0.9999) < ratio, 0, 1 / (1 - ratio) * x) For inference, ratio is 0, and so it's equivalent to: f(X) = identity(X) Notes: If ratio = 1, then the output is all zeroes. If 0, then identity. Zeroing is selected randomly per element. ? Dropout NA (can use identity, since not used during inference) ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Unpredictable
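A short NumPy sketch of the dropout behavior above (dropout here is a hypothetical helper; during inference it degenerates to identity):

    import numpy as np

    def dropout(x, ratio, training=True, rng=None):
        # Training: zero each element with probability `ratio` and scale survivors by 1/(1 - ratio).
        # Inference: identity, matching the note above.
        if not training or ratio == 0.0:
            return x
        rng = rng or np.random.default_rng()
        mask = rng.random(x.shape) >= ratio
        return np.where(mask, x / (1.0 - ratio), 0.0)

    x = np.ones((2, 4))
    y = dropout(x, ratio=0.5)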
Deleted, Code Execution A TENsor Kernel ??? ? ATen experimental and deprecated NA ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Unpredictable
Elementwise Math (Deleted) Scale Signal function scaleSignal(X, scale) = mul(X, scale) mul Scale DML_OPERATOR_ELEMENT_WISE_MULTIPLY ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 1 ULP
Elementwise Math (Deleted) Image Scaler f(X) = add(mul(X, scale), reshapeToAxes(biasTensor, X.rank, [1])) // reshape bias to [1,C,1,1] f(x, scale, bias) = x * scale + bias ? ImageScaler DML_OPERATOR_VALUE_SCALE_2D ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Precision TBD
Data reorganization (Deleted) Crop Crop the tensor to the given ranges for each axis. Crop is confusing and redundant. Just use Slice. ? Crop (ONNX Slice subset)? DML_OPERATOR_SLICE ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? Exact