tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1253 @@
1
+ # pylint: disable=possibly-unused-variable
2
+ from typing import Any, Sequence, cast, Literal, NamedTuple, Generator
3
+ import dataclasses, functools, io, math, types, warnings, pathlib, sys, os, struct, enum
4
+ from io import BufferedReader
5
+ from tinygrad.nn.state import TensorIO
6
+ from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
7
+ from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN
8
+ from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate
9
+ from tinygrad.device import is_dtype_supported, Device
10
+
11
+ # ***** protobuf definitions ******
12
class WireType(enum.IntEnum):
  """
  Protocol Buffer wire types; the low 3 bits of every field tag carry one of these.
  Reference: https://github.com/protocolbuffers/protobuf/blob/main/python/google/protobuf/internal/wire_format.py#L24-L29
  """
  VARINT = 0
  FIXED64 = 1
  LENGTH_DELIMITED = 2
  START_GROUP = 3
  END_GROUP = 4
  FIXED32 = 5
18
+
19
+ class AttributeType(enum.IntEnum):
20
+ """
21
+ ONNX attribute type identifiers.
22
+ Reference: https://github.com/onnx/onnx/blob/rel-1.18.0/onnx/onnx.proto3#L128-L145
23
+ """
24
+ FLOAT = 1; INT = 2; STRING = 3; TENSOR = 4; FLOATS = 6; INTS = 7; STRINGS = 8 # noqa: E702
25
+
26
+ def to_field_name(self) -> str: return {1: "f", 2: "i", 3: "s", 4: "t", 6: "floats", 7: "ints", 8: "strings"}[self.value]
27
+
28
class OnnxDataType(enum.IntEnum):
  """
  Element-type identifiers used by ONNX tensors (TensorProto.data_type, TypeProto.Tensor.elem_type).
  Reference: https://github.com/onnx/onnx/blob/rel-1.18.0/onnx/onnx.proto3#L500-L544
  """
  FLOAT = 1
  UINT8 = 2
  INT8 = 3
  UINT16 = 4
  INT16 = 5
  INT32 = 6
  INT64 = 7
  BOOL = 9
  FLOAT16 = 10
  DOUBLE = 11
  UINT32 = 12
  UINT64 = 13
  BFLOAT16 = 16

  def to_dtype(self) -> DType:
    """Look up the tinygrad dtype whose name is this member's name lowercased (e.g. FLOAT16 -> "float16")."""
    dtype_table = dtypes.fields()
    return dtype_table[self.name.lower()]
37
+
38
def dtype_fallback(dtype: DType, fallback_context: str) -> DType:
  """
  Return `dtype` unchanged if Device.DEFAULT supports it. Otherwise warn and return dtypes.default_int for integer
  dtypes or dtypes.default_float for anything else. `fallback_context` only appears in the warning text.
  Asserts that the substitute dtype is itself supported.
  """
  if not is_dtype_supported(dtype):
    if dtypes.is_int(dtype): default_dtype = dtypes.default_int
    else: default_dtype = dtypes.default_float
    warnings.warn(f"dtype {dtype} on {Device.DEFAULT} from {fallback_context} is not supported, falling back to {default_dtype}")
    assert is_dtype_supported(default_dtype), f"dtype {default_dtype} must be supported on {Device.DEFAULT}"
    return default_dtype
  return dtype
44
+
45
+ # ***** onnx spec definitions *****
46
+ class Domain(enum.Enum):
47
+ ONNX = "ai.onnx"
48
+ ONNX_ML = "ai.onnx.ml"
49
+ AI_ONNX_TRAINING = "ai.onnx.training"
50
+ AI_ONNX_PREVIEW_TRAINING = "ai.onnx.preview.training"
51
+ MICROSOFT_CONTRIB_OPS = "com.microsoft"
52
+ MICROSOFT_NCHWC = "com.microsoft.nchwc"
53
+ MICROSOFT_EXPERIMENTAL = "com.microsoft.experimental"
54
+ PYTORCH_ATEN = "org.pytorch.aten"
55
+ @classmethod
56
+ def from_onnx(cls, domain: str | None) -> "Domain": return cls.ONNX if domain is None or domain == "" else cls(domain)
57
+
58
class OpSetId(NamedTuple):
  """A (domain, version) pair naming the operator set a node's op is resolved against."""
  domain: Domain  # operator-set domain of the op
  version: int  # operator-set version number
61
+
62
@dataclasses.dataclass(frozen=True)
class OnnxValue:
  """Declared type of a graph input/output, built from a ValueInfoProto."""
  # int for a fixed dim, str for a named (symbolic) dim; may also be None when the model gives neither
  shape: tuple[str|int, ...]
  dtype: DType  # element dtype
  is_optional: bool  # the type was wrapped in an optional type
  is_sequence: bool  # the type was wrapped in a sequence type
68
+
69
@dataclasses.dataclass(frozen=True)
class OnnxNode:
  """A parsed graph node: op name, opset it resolves against, input/output value names and attribute values."""
  op: str  # ONNX op_type, e.g. "Conv"
  opset_id: OpSetId
  inputs: tuple[str, ...]  # names of the values consumed ("" marks an omitted optional input)
  outputs: tuple[str, ...]  # names of the values produced
  opts: dict[str, Any]  # attribute name -> attribute value
76
+
77
+ # ***** protobuf parsing ******
78
class PBBufferedReader(BufferedReader):
  """
  Buffered reader over the bytes of a uint8 Tensor, with decoders for Protocol Buffer wire primitives.
  Reference: https://protobuf.dev/programming-guides/encoding/
  """
  def __init__(self, tensor: Tensor):
    assert tensor.dtype == dtypes.uint8, tensor
    super().__init__(TensorIO(tensor))
    # total byte length of the source tensor
    self.len = tensor.nbytes()

  def decode_varint(self) -> int:
    """
    Decode one base-128 varint: 7 payload bits per byte, lowest group first, high bit set while more bytes follow.
    Raises EOFError if the stream ends mid-varint, ValueError after 10 bytes that all have the continuation bit set.
    Reference: https://protobuf.dev/programming-guides/encoding/#varints
    """
    acc, bits = 0, 0
    while True:
      chunk = self.read(1)
      if chunk == b"": raise EOFError("decode_varint EOF")
      byte = chunk[0]
      acc |= (byte & 0x7F) << bits
      if (byte & 0x80) == 0: return acc
      bits += 7
      if bits >= 70: raise ValueError("Varint too long")

  def read_delimited(self, use_tensor=False):
    """Read a varint length prefix and that many bytes; returns `bytes`, or a uint8 slice of the source Tensor."""
    size = self.decode_varint()
    if not use_tensor: return self.read(size)
    source = self.raw
    assert isinstance(source, TensorIO)
    start = self.tell()
    view = source._tensor[start:start + size]
    self.seek(size, os.SEEK_CUR)
    return view

  def read_string(self) -> str:
    """Length-prefixed UTF-8 string."""
    return self.read_delimited().decode("utf-8")

  def read_bytes(self) -> Tensor:
    """Length-prefixed bytes as a uint8 Tensor slice."""
    return self.read_delimited(use_tensor=True)

  def read_float(self) -> float:
    """One little-endian float32 (FIXED32 payload)."""
    (value,) = struct.unpack("<f", self.read(4))
    return value

  def read_packed_floats(self) -> Tensor:
    """Packed float payload, returned undecoded as a uint8 Tensor slice."""
    return self.read_delimited(use_tensor=True)

  def read_int64(self) -> int:
    """Varint passed through truncate[dtypes.int64] (presumably wrapping values >= 2**63 to negative int64)."""
    return truncate[dtypes.int64](self.decode_varint())

  def read_packed_int64s(self) -> list[int]:
    """Length-prefixed run of varints; they have no fixed width, so each is decoded in turn until the run ends."""
    size = self.decode_varint()
    stop = self.tell() + size
    values: list[int] = []
    while self.tell() < stop: values.append(self.read_int64())
    return values

  def skip_field(self, wire_type: WireType) -> None:
    """Advance past the payload of a field whose tag was already read; group wire types are rejected."""
    if wire_type == WireType.VARINT: self.decode_varint()
    elif wire_type == WireType.FIXED64: self.seek(8, os.SEEK_CUR)
    elif wire_type == WireType.FIXED32: self.seek(4, os.SEEK_CUR)
    elif wire_type == WireType.LENGTH_DELIMITED: self.seek(self.decode_varint(), os.SEEK_CUR)
    else: raise ValueError(f"Unknown wire type: {wire_type}")
125
+
126
+ class OnnxPBParser:
127
+ """
128
+ ONNX protobuf parser.
129
+ Reference: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
130
+ """
131
+ def __init__(self, inp: Tensor|str|pathlib.Path, load_external_data: bool=True):
132
+ self.file_path: pathlib.Path|None = None
133
+ self.load_external_data = load_external_data
134
+ if not isinstance(inp, Tensor):
135
+ self.file_path = pathlib.Path(inp)
136
+ self.tensor = Tensor(self.file_path)
137
+ else: self.tensor = inp
138
+ self.reader = PBBufferedReader(self.tensor)
139
+
140
+ def parse(self) -> dict:
141
+ """Parses the ONNX model into a nested dictionary. """
142
+ return self._parse_ModelProto()
143
+
144
+ def _parse_message(self, end_pos: int) -> Generator[tuple[int, WireType], None, None]:
145
+ while self.reader.tell() < end_pos:
146
+ tag = self.reader.decode_varint()
147
+ yield tag >> 3, WireType(tag & 0x07)
148
+
149
+ def _decode_end_pos(self) -> int:
150
+ str_len = self.reader.decode_varint()
151
+ start_pos = self.reader.tell()
152
+ return start_pos + str_len
153
+
154
+ def _parse_ModelProto(self) -> dict:
155
+ """Entry point for parsing the ONNX model."""
156
+ obj: dict[str, Any] = {"opset_import": []}
157
+ for fid, wire_type in self._parse_message(self.reader.len):
158
+ match fid:
159
+ case 4: obj["domain"] = self.reader.read_string()
160
+ case 5: obj["model_version"] = self.reader.read_int64()
161
+ case 7: obj["graph"] = self._parse_GraphProto()
162
+ case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto())
163
+ case _: self.reader.skip_field(wire_type)
164
+
165
+ # update opset version
166
+ opset_imports = {Domain.from_onnx(x.get('domain')):x.get('version', 1) for x in obj["opset_import"]}
167
+ for n in obj["graph"]["node"]:
168
+ n_ = n["parsed_node"]
169
+ n["parsed_node"] = OnnxNode(n_.op, OpSetId(n_.opset_id.domain, opset_imports.get(n_.opset_id.domain, 1)), n_.inputs, n_.outputs, n_.opts)
170
+ return obj
171
+
172
+ def _parse_GraphProto(self) -> dict:
173
+ obj: dict[str, Any] = {"node": [], "initializer": [], "input": [], "output": []}
174
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
175
+ match fid:
176
+ case 1: obj["node"].append(self._parse_NodeProto())
177
+ case 2: obj["name"] = self.reader.read_string()
178
+ case 5: obj["initializer"].append(self._parse_TensorProto())
179
+ case 11: obj["input"].append(self._parse_ValueInfoProto())
180
+ case 12: obj["output"].append(self._parse_ValueInfoProto())
181
+ case _: self.reader.skip_field(wire_type)
182
+ return obj
183
+
184
+ def _parse_NodeProto(self) -> dict:
185
+ obj: dict[str, Any] = {"input": [], "output": [], "attribute": [], "domain": None}
186
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
187
+ match fid:
188
+ case 1: obj["input"].append(self.reader.read_string())
189
+ case 2: obj["output"].append(self.reader.read_string())
190
+ case 3: obj["name"] = self.reader.read_string()
191
+ case 4: obj["op_type"] = self.reader.read_string()
192
+ case 5: obj["attribute"].append(self._parse_AttributeProto())
193
+ case 6: obj["doc_string"] = self.reader.read_string()
194
+ case 7: obj["domain"] = self.reader.read_string()
195
+ case _: self.reader.skip_field(wire_type)
196
+
197
+ # parse node
198
+ attributes = {attr_dict["name"]: attr_dict[AttributeType(attr_dict["type"]).to_field_name()] for attr_dict in obj["attribute"]}
199
+ opset_id = OpSetId(Domain.from_onnx(obj.get('domain')), 1) # default version, to be updated later in _parse_ModelProto
200
+ obj["parsed_node"] = OnnxNode(obj["op_type"], opset_id, tuple(obj["input"]), tuple(obj["output"]), attributes)
201
+ return obj
202
+
203
+ def _parse_TensorProto(self) -> dict:
204
+ obj: dict[str, Any] = {"dims": []}
205
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
206
+ match fid:
207
+ case 1: obj["dims"].append(self.reader.read_int64())
208
+ case 2: obj["data_type"] = self.reader.read_int64()
209
+ case 4: obj["float_data"] = self.reader.read_packed_floats()
210
+ case 5: obj["int32_data"] = self.reader.read_packed_int64s()
211
+ case 7: obj["int64_data"] = self.reader.read_packed_int64s()
212
+ case 8: obj["name"] = self.reader.read_string()
213
+ case 9: obj["raw_data"] = self.reader.read_bytes()
214
+ case 10: obj["double_data"] = self.reader.read_packed_floats()
215
+ case 11: obj["uint64_data"] = self.reader.read_packed_int64s()
216
+ case 13: obj.setdefault("external_data", []).append(self._parse_StringStringEntryProto())
217
+ case 14: obj["data_location"] = self.reader.read_int64()
218
+ case _: self.reader.skip_field(wire_type)
219
+
220
+ # load external data
221
+ if self.load_external_data and obj.get("data_location", 0) == 1:
222
+ if "external_data" not in obj: raise ValueError("no external_data")
223
+ location, length, offset = None, None, 0
224
+ for kv in obj["external_data"]:
225
+ if kv["key"] == "location": location = kv["value"]
226
+ elif kv["key"] == "offset": offset = int(kv["value"])
227
+ elif kv["key"] == "length": length = int(kv["value"])
228
+ if location is None: raise ValueError("no location in external_data")
229
+
230
+ if self.file_path is None:
231
+ if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"):
232
+ self.file_path = pathlib.Path(self.tensor.device[5:])
233
+ else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load")
234
+ ext_path = self.file_path.parent.joinpath(location)
235
+ if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}")
236
+
237
+ ext_tensor = Tensor(ext_path)
238
+ obj["raw_data"] = ext_tensor[offset:offset+length] if length is not None else ext_tensor[offset:]
239
+ obj["data_location"] = 0
240
+
241
+ # parse tensor
242
+ to_dtype = dtype_fallback(true_dtype := OnnxDataType(obj['data_type']).to_dtype(), "buffer parse")
243
+ shape = tuple(obj['dims'])
244
+ present_fields = [field for field in ['float_data', 'int32_data', 'int64_data', 'double_data', 'uint64_data', 'raw_data'] if field in obj]
245
+ assert len(present_fields) == 1, f"only 1 data field is allowed from {obj=}"
246
+ data = obj[present_fields[0]]
247
+ if not isinstance(data, Tensor):
248
+ obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape)
249
+ return obj
250
+ assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data
251
+ data = data.bitcast(true_dtype).reshape(shape)
252
+ data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT)
253
+ # const folding
254
+ if shape == ():
255
+ if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
256
+ data = Tensor(data.item(), dtype=to_dtype).reshape(shape)
257
+ obj["parsed_tensor"] = data
258
+ return obj
259
+
260
+ def _parse_AttributeProto(self) -> dict:
261
+ obj: dict[str, Any] = {"floats": [], "ints": [], "strings": []}
262
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
263
+ match fid:
264
+ case 1: obj["name"] = self.reader.read_string()
265
+ case 2: obj["f"] = self.reader.read_float()
266
+ case 3: obj["i"] = self.reader.read_int64()
267
+ case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8")
268
+ case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor']
269
+ case 7: obj["floats"].append(self.reader.read_float())
270
+ case 8: obj["ints"].append(self.reader.read_int64())
271
+ case 9: obj["strings"].append(self.reader.read_bytes().data().tobytes().decode("utf8"))
272
+ case 20: obj["type"] = self.reader.read_int64()
273
+ case _: self.reader.skip_field(wire_type)
274
+ obj["floats"], obj["ints"], obj["strings"] = tuple(obj["floats"]), tuple(obj["ints"]), tuple(obj["strings"])
275
+ return obj
276
+
277
+ def _parse_ValueInfoProto(self) -> dict:
278
+ obj: dict[str, Any] = {}
279
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
280
+ match fid:
281
+ case 1: obj["name"] = self.reader.read_string()
282
+ case 2: obj["type"] = self._parse_TypeProto()
283
+ case _: self.reader.skip_field(wire_type)
284
+
285
+ # parse type
286
+ if "type" not in obj: return {**obj, "parsed_type": None}
287
+ type_obj = obj["type"]
288
+ if is_optional := "optional_type" in type_obj: type_obj = type_obj["optional_type"]["elem_type"]
289
+ if is_sequence := "sequence_type" in type_obj: type_obj = type_obj["sequence_type"]["elem_type"]
290
+ assert "tensor_type" in type_obj, type_obj
291
+ shape_dims = type_obj['tensor_type'].get('shape', {}).get('dim', [])
292
+ obj['parsed_type'] = OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
293
+ OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
294
+ return obj
295
+
296
+ def _parse_TypeProto(self) -> dict:
297
+ obj: dict[str, Any] = {}
298
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
299
+ match fid:
300
+ case 1: obj["tensor_type"] = self._parse_TypeProtoTensor()
301
+ case 4: obj["sequence_type"] = self._parse_TypeProtoSequence()
302
+ case 9: obj["optional_type"] = self._parse_TypeProtoOptional()
303
+ case _: self.reader.skip_field(wire_type)
304
+ return obj
305
+
306
+ def _parse_TypeProtoTensor(self) -> dict:
307
+ obj: dict[str, Any] = {}
308
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
309
+ match fid:
310
+ case 1: obj["elem_type"] = self.reader.read_int64()
311
+ case 2: obj["shape"] = self._parse_TensorShapeProto()
312
+ case _: self.reader.skip_field(wire_type)
313
+ return obj
314
+
315
+ def _parse_TypeProtoSequence(self) -> dict:
316
+ obj = {}
317
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
318
+ match fid:
319
+ case 1: obj["elem_type"] = self._parse_TypeProto()
320
+ case _: self.reader.skip_field(wire_type)
321
+ return obj
322
+
323
+ def _parse_TypeProtoOptional(self) -> dict:
324
+ obj = {}
325
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
326
+ match fid:
327
+ case 1: obj["elem_type"] = self._parse_TypeProto()
328
+ case _: self.reader.skip_field(wire_type)
329
+ return obj
330
+
331
+ def _parse_TensorShapeProto(self) -> dict:
332
+ obj: dict[str, Any] = {"dim": []}
333
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
334
+ match fid:
335
+ case 1: obj["dim"].append(self._parse_TensorShapeProtoDimension())
336
+ case _: self.reader.skip_field(wire_type)
337
+ return obj
338
+
339
+ def _parse_TensorShapeProtoDimension(self) -> dict:
340
+ obj: dict[str, Any] = {}
341
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
342
+ match fid:
343
+ case 1: obj["dim_value"] = self.reader.read_int64()
344
+ case 2: obj["dim_param"] = self.reader.read_string()
345
+ case _: self.reader.skip_field(wire_type)
346
+ return obj
347
+
348
+ def _parse_StringStringEntryProto(self) -> dict:
349
+ obj: dict[str, Any] = {}
350
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
351
+ match fid:
352
+ case 1: obj["key"] = self.reader.read_string()
353
+ case 2: obj["value"] = self.reader.read_string()
354
+ case _: self.reader.skip_field(wire_type)
355
+ return obj
356
+
357
+ def _parse_OperatorSetIdProto(self) -> dict:
358
+ obj: dict[str, Any] = {}
359
+ for fid, wire_type in self._parse_message(self._decode_end_pos()):
360
+ match fid:
361
+ case 1: obj["domain"] = self.reader.read_string()
362
+ case 2: obj["version"] = self.reader.read_int64()
363
+ case _: self.reader.skip_field(wire_type)
364
+ return obj
365
+
366
+ # ***** python const *****
367
# op name -> positions of inputs that the op implementation needs as plain Python values instead of Tensors
required_input_python_consts: dict[str, tuple[int, ...]] = dict(
  Tile=(1,), Range=(0, 1, 2), Expand=(1,), Reshape=(1,), Squeeze=(1,), Unsqueeze=(1,), Trilu=(1,), ConstantOfShape=(0,),
  CumSum=(1,), TopK=(1,), Pad=(1, 2, 3), MaxUnpool=(2,), Dropout=(1, 2), CenterCropPad=(1,), OneHot=(1,), Compress=(1,),
  ImageDecoder=(0,), AffineGrid=(1,), Resize=(1, 2, 3), Upsample=(1,), Split=(1,), Slice=(1, 2, 3, 4),
)
# every Reduce* op takes its axes (input 1) as a constant
required_input_python_consts.update({f"Reduce{r}": (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")})
required_input_python_consts.update(dict.fromkeys(("Adam", "Adagrad", "Momentum"), (1,)))
374
+
375
cache_misses = 0  # last observed miss count of _cached_to_python_const (used for DEBUG >= 3 reporting)
@functools.cache
def _cached_to_python_const(t:Tensor):
  """
  Convert `t` to a Python value, memoized by functools.cache: uint8 tensors become `bytes`, tensors with a zero-sized
  dimension become `[]`, anything else goes through `tolist()`. The cache is unbounded and holds every argument alive.
  """
  if t.dtype == dtypes.uint8: return t.data().tobytes()
  return [] if 0 in t.shape else t.tolist()
381
+
382
+ # Tensor -> python value cache for parameters
383
def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes:
  """
  Convert input number `idx` of `op` to a Python value when that op needs it as a constant.

  Only positions listed in required_input_python_consts[op] are converted, and only if `t` is a Tensor; anything else
  is returned unchanged. Conversion goes through the memoized _cached_to_python_const, and with DEBUG >= 3 each
  conversion that missed that cache is printed.
  """
  if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t
  global cache_misses
  ret = _cached_to_python_const(t)
  misses = _cached_to_python_const.cache_info().misses
  if misses > cache_misses and DEBUG >= 3: print(f"Cache miss for {t}")
  # refresh on every call (not only when printing) so that raising DEBUG later cannot report a stale miss
  cache_misses = misses
  return ret
391
+
392
+ # ***** runner ******
393
# DEBUGONNX: per-node trace verbosity used as OnnxRunner.__call__'s default `debug`
# ONNXLIMIT: node index after which OnnxRunner.__call__ stops and returns that node's outputs (-1 never matches)
debug, limit = int(getenv("DEBUGONNX", "0")), int(getenv("ONNXLIMIT", "-1"))
395
class OnnxRunner:
  """
  `OnnxRunner` executes an ONNX model using Tinygrad.

  Args:
    model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor.
  """
  def __init__(self, model_path: Tensor | str | pathlib.Path):
    model = OnnxPBParser(model_path, load_external_data=True).parse()
    graph = model["graph"]
    # any node from a training domain marks the whole model as a training graph
    self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
    # value name -> value; initializers are preloaded and "" (an omitted optional input) resolves to None
    self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}}
    # graph inputs that are not initializers must be supplied by the caller
    self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
    self.graph_outputs = tuple(o["name"] for o in graph["output"])
    self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"])

    # Tensor.training is switched to self.is_training only while __call__ runs; construction leaves it untouched
    self.old_training = Tensor.training

    # symbolic dim name -> size, bound by the first input that uses that name
    self.variable_dims: dict[str, int] = {}
    self.onnx_ops = onnx_ops

  def _parse_input(self, name: str, value: Any, spec: OnnxValue):
    """
    Validate one user-supplied input against its declared spec and convert it to a Tensor (a list of Tensors for
    sequence inputs). Shape mismatches raise RuntimeError; dtype mismatches only warn.
    """
    if spec.is_optional and value is None: return None
    if spec.is_sequence:
      if not isinstance(value, Sequence): raise RuntimeError(f"input {name} received {value}, expected a sequence type")
      sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
      if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for input {name} sequence must be homogeneous")
      if not all(t.dtype is spec.dtype for t in sequence): warnings.warn(f"Dtypes for input {name} sequence aren't all {spec.dtype}")
      return sequence
    dtype = _from_np_dtype(value.dtype) if is_numpy_ndarray(value) else spec.dtype
    tensor = Tensor(value, dtype=dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value
    if tensor.dtype is not spec.dtype: warnings.warn(f"input {name} has mismatch on dtype. Expected {spec.dtype}, received {tensor.dtype}.")
    for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)):
      # a dim declared with neither dim_value nor dim_param is parsed as None: the model places no constraint on it
      if onnx_dim is None: continue
      if isinstance(onnx_dim, str):
        # the first input using a symbolic dim binds its size; later uses must agree
        onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
      if user_dim_input != onnx_dim: raise RuntimeError(f"input {name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
    return tensor

  def _select_op(self, op:str, required_opset:OpSetId) -> types.FunctionType:
    """Pick the implementation of `op`: the plain function if unversioned, else the newest version <= the required one."""
    if op not in self.onnx_ops: raise NotImplementedError(f"{op=} is not supported")
    # return default implementation if no opset_id is specified
    if isinstance(impl := self.onnx_ops[op], types.FunctionType): return impl
    # match domain and select implementation with latest compatible version
    eligible_ops = {impl_opset.version:impl_fxn for impl_opset,impl_fxn in impl.items()
                    if impl_opset.domain == required_opset.domain and impl_opset.version <= required_opset.version}
    if not eligible_ops: raise NotImplementedError(f"{op=} is not supported for domain {required_opset.domain} and version {required_opset.version}")
    return eligible_ops[max(eligible_ops.keys())]

  def get_empty_input_data(self, device:str|None=None, dtype:DType|None=None) -> dict[str, Tensor]:
    """Uninitialized tensors shaped like each graph input (assumes every declared dim is a concrete int)."""
    return {name:Tensor.empty(*spec.shape, device=device, dtype=dtype or spec.dtype) for name, spec in self.graph_inputs.items()}

  def to(self, device:str|None):
    """Move every stored Tensor (graph values and Tensor-valued node attributes) to `device`; returns self."""
    self.graph_values = {k:v.to(device) if isinstance(v, Tensor) else v for k,v in self.graph_values.items()}
    self.graph_nodes = tuple(OnnxNode(n.op, n.opset_id, tuple(n.inputs), tuple(n.outputs),
                                      {k:v.to(device) if isinstance(v, Tensor) else v for k,v in n.opts.items()}) for n in self.graph_nodes)
    return self

  def __call__(self, inputs:dict[str, Any], debug=debug):
    """
    Run the graph on `inputs` (graph input name -> value) and return {output name: value}.
    If the module-level ONNXLIMIT matches a node index, execution stops after that node and its outputs are returned.
    """
    # enable Tensor.training for training graphs on every call, and always restore the caller's setting
    self.old_training = Tensor.training
    Tensor.training = self.is_training
    try:
      for name, input_spec in self.graph_inputs.items():
        if name not in inputs: raise RuntimeError(f"Please provide input data for {name}")
        self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)

      for num, node in enumerate(self.graph_nodes):
        inps = [to_python_const(self.graph_values[name], node.op, i) for i,name in enumerate(node.inputs)]
        # copy so the per-call extras added below are not written back into the stored node
        opts = dict(node.opts)

        # provide additional opts
        if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
        if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values

        if debug >= 1: print(f"{num}: op '{node.op}' opt {opts}")
        if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
        ret = self._select_op(node.op, node.opset_id)(*inps, **opts)
        ret = ret if isinstance(ret, tuple) else (ret,)
        if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret)))

        self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True)))

        if num == limit: return {name:self.graph_values[name] for name in node.outputs}
      return {name:self.graph_values[name] for name in self.graph_outputs}
    finally:
      Tensor.training = self.old_training
479
+
480
+ ####################
481
+ ##### ONNX OPS #####
482
+ ####################
483
+ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]:
484
+ # ***** helper functions *****
485
+ def _resolve_const(x: Sequence[ConstType]|ConstType): return get_single_element(x) if isinstance(x, Sequence) else x
486
+
487
+ def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
488
+
489
+ # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
490
+ def _onnx_pads_to_tiny_pads(pads): return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
491
+
492
+ AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
493
+ # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
494
+ def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
495
+ if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
496
+ return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
497
+
498
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
  """Turn ONNX pool/conv padding attributes into tinygrad's flat pad tuple (last spatial axis first).

  p_, k_, d_, s_ are the ONNX pads, kernel_shape, dilations and strides (an int or a per-dim list).
  VALID -> no padding; NOTSET -> the explicit pads; SAME_UPPER/SAME_LOWER -> padding that gives
  output = ceil(input / stride), with the odd leftover placed at the end (UPPER) or at the begin (LOWER).
  """
  if auto_pad == "VALID": return [0]*(len(k_)*2)
  i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(v, len(k_)*2) for v in (s_, d_, p_))
  if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
  o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
  # the total padding must cover the dilated (effective) kernel extent (k-1)*d+1, not just k
  total = [(o-1)*s + (k-1)*d + 1 - i for o,i,k,s,d in zip(o_, i_, k_, s_, d_)]
  return _onnx_pads_to_tiny_pads(_auto_pad(total, auto_pad))
504
+
505
def _clamp_cast(x:Tensor, dtype:DType):
  # clamp into dtype's representable range first so out-of-range values saturate rather than depending on cast overflow
  lo, hi = dtypes.min(dtype), dtypes.max(dtype)
  return x.clamp(lo, hi).cast(dtype)
506
+
507
def _prepare_quantize(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0):
  """Reshape scale (and a tensor zero_point) so they broadcast against x.

  Covers per-tensor (a single element), per-axis (block_size == 0) and blocked (block_size > 0) quantization.
  Returns (scale, zero_point); an int zero_point is returned unchanged.
  """
  if axis < 0: axis += x.ndim
  # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_quantize_linear.py#L31
  def _broadcastable(t:Tensor) -> Tensor:
    if t.numel() == 1: return t
    if block_size != 0: return t.repeat_interleave(block_size, axis)
    # per-axis: keep the length along `axis`, size 1 everywhere else
    return t.reshape([t.shape[0] if d == axis else 1 for d in range(x.ndim)])
  sc = _broadcastable(scale)
  zp = _broadcastable(zero_point) if isinstance(zero_point, Tensor) else zero_point
  return (sc, zp)
515
+
516
def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
  # convert each input to int and remove its zero point, then apply op to the centered integer values
  centered = []
  for inp, zp in zip(inputs, zero_points):
    centered.append(inp.int() - zp)
  return op(*centered, **opts)
519
+
520
def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
  """Run op on zero-point-adjusted ints, then requantize the result into out_zero_point's dtype."""
  int_result = _op_integer(op, inputs, zero_points, **opts)
  assert dtypes.is_int(int_result.dtype), "quantized op should've done math in int"
  # the product of the input scales rescales the int result; dividing by out_scale maps it to the output grid
  requantized = (int_result * prod(scales) / out_scale).round() + out_zero_point
  return _clamp_cast(requantized, out_zero_point.dtype)
526
+
527
def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
  """Dequantize the inputs, run op in floating point, then requantize into out_zero_point's dtype."""
  dequantized = [(inp.int() - zp) * sc for inp, zp, sc in zip(inputs, zero_points, scales)]
  float_result = op(*dequantized, **opts)
  assert dtypes.is_float(float_result.dtype), "op should've done math in float"
  return _clamp_cast((float_result / out_scale).round() + out_zero_point, out_zero_point.dtype)
534
+
535
def _onnx_training(input_group_size):
  """Decorator factory for ONNX training ops whose inputs are R, T, then input_group_size interleaved lists.

  The variadic inputs are split into len(inputs)//input_group_size groups by striding (group p takes
  inputs[p::n]). func is called once per group with a detached R, and the per-call output tuples are
  regrouped output-major into one flat tuple.
  """
  def __decorator(func):
    def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
      R = R.detach()
      n_groups = len(inputs) // input_group_size
      results = []
      for p in range(n_groups):
        results.append(func(R, T, *inputs[p::n_groups], **kwargs))
      return tuple(flatten(zip(*results)))
    return ___wrapper
  return __decorator
544
+
545
+ # ***** Property/Graph Ops *****
546
def Identity(x:Tensor): return x
def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
             value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
  """Materialize the Constant op's attribute; checked in order value, value_float(s), value_int(s)."""
  if value is not None: return value
  # scalar/list attributes become non-trainable tensors: floats as float32, ints as int64
  candidates = ((value_float, dtypes.float32, False), (value_floats, dtypes.float32, True),
                (value_int, dtypes.int64, False), (value_ints, dtypes.int64, True))
  for payload, dt, is_list in candidates:
    if payload is not None: return Tensor(list(payload) if is_list else payload, dtype=dt, requires_grad=False)
  if value_string is not None or value_strings is not None or sparse_value is not None:
    raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')

def Range(start:float|int|list[float|int], limit:float|int|list[float|int], delta:float|int|list[float|int]):
  # each bound may arrive as a scalar or a one-element sequence
  lo, hi, step = (_resolve_const(v) for v in (start, limit, delta))
  return Tensor.arange(start=lo, stop=hi, step=step)
559
+
560
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
  """Decode an encoded image (any format Pillow can open) into a uint8 tensor of shape (height, width, channels).

  pixel_format: "RGB" or "BGR" (3 channels) or "Grayscale" (1 channel).
  Raises ImportError if Pillow is missing and ValueError for any other pixel_format.
  """
  try: import PIL.Image
  except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
  img = PIL.Image.open(io.BytesIO(encoded_stream))
  # PIL's img.size is (width, height) but tobytes() is row-major, so the tensor must be shaped (height, width, C)
  if pixel_format in ("RGB", "BGR"):
    # convert() normalizes palette/RGBA/grayscale sources to exactly 3 channels so the reshape is valid
    rgb = Tensor(img.convert("RGB").tobytes(), dtype=dtypes.uint8).reshape(img.height, img.width, 3)
    return rgb.flip(-1) if pixel_format == "BGR" else rgb
  if pixel_format == "Grayscale": return Tensor(img.convert("L").tobytes(), dtype=dtypes.uint8).reshape(img.height, img.width, 1)
  raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
568
+
569
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
  """Return a tensor shaped like 2-D x with ones on the k-th diagonal and zeros elsewhere.

  k > 0 selects a diagonal above the main one and k < 0 one below; an out-of-range k gives all zeros.
  dtype is an ONNX data type id and defaults to x.dtype.
  """
  out_dtype = dtype_fallback(OnnxDataType(dtype).to_dtype(), "EyeLike op") if dtype is not None else x.dtype
  rows, cols = cast(tuple[int, int], x.shape)
  # element (i, j) is 1 iff j - i == k; this works for square and non-square shapes and either sign of k
  row_idx = Tensor.arange(rows, device=x.device).reshape(rows, 1)
  col_idx = Tensor.arange(cols, device=x.device).reshape(1, cols)
  return (row_idx + k == col_idx).cast(out_dtype)
572
+
573
def OptionalHasElement(x:Tensor|None=None):
  # an absent optional and an empty tensor both report "no element"
  present = False if x is None else x.numel() > 0
  return Tensor(present)
def OptionalGetElement(x:Tensor|None=None):
  if x is None: return Tensor([])
  return x
def ConstantOfShape(shape:list[int], value:Tensor|None=None):
  # the default fill is a float32 zero; shape [0] yields an empty 1-D tensor of the fill's dtype
  fill = Tensor(0, dtype=dtypes.float32) if value is None else value
  return Tensor([], dtype=fill.dtype) if shape == [0] else fill.expand(shape)

def Size(data:Tensor): return data.numel()
def Shape(data:Tensor, end:int|None=None, start:int=0):
  dims = data.shape[start:end]
  return Tensor(dims, dtype=dtypes.int64)
582
+
583
+ # ***** Unary Ops (math) *****
584
def Not(x:Tensor): return x.logical_not()
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): # noqa: A002 # pylint: disable=redefined-builtin
  # with neither bound given, Clip is the identity
  if min is None and max is None: return x
  return x.clip(min, max)
def IsInf(x:Tensor, detect_negative:int=1, detect_positive:int=1):
  # the flags are forwarded in (detect_positive, detect_negative) order
  return x.isinf(bool(detect_positive), bool(detect_negative))
587
+
588
+ # ***** Unary Ops (activation) *****
589
# Softmax's default axis differs by opset: 1 before opset 13, -1 from opset 13 on
def softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {OpSetId(Domain.ONNX, 1):softmax_1, OpSetId(Domain.ONNX, 13):softmax_13}
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:str|None=None):
  if approximate == "tanh": return x.gelu()  # Tensor.gelu is the tanh approximation
  return 0.5 * x * (1 + (x/math.sqrt(2)).erf())  # exact erf form
def BiasGelu(x: Tensor, bias: Tensor, approximate: str | None = None) -> Tensor: return Gelu(x + bias, approximate)
def FastGelu(x:Tensor, bias:Tensor|None=None):
  # this is tanh approximated
  if bias is None: return x.gelu()
  return (x + bias).gelu()
def PRelu(X:Tensor, slope:Tensor): return (X > 0).where(X, X * slope)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leaky_relu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
601
+
602
# ***** Binary Ops (broadcasted) *****
603
def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y  # legacy broadcast/axis attributes are accepted and ignored
def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
def Div(x:Tensor,y:Tensor):
  # integer inputs use truncating division; everything else uses true division
  mode = 'trunc' if dtypes.is_int(x.dtype) else None
  return x.div(y, rounding_mode=mode)
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def And(x:Tensor,y:Tensor):
  # for bools: where x == y the answer is x itself; otherwise exactly one operand is False
  return (x==y).where(x, False)
def Or(x:Tensor,y:Tensor):
  # for bools: where x == y the answer is x itself; otherwise exactly one operand is True
  return (x==y).where(x, True)
def Xor(x:Tensor,y:Tensor): return x.bool().bitwise_xor(y.bool())
def BitwiseAnd(x:Tensor,y:Tensor): return x & y
def BitwiseOr(x:Tensor,y:Tensor): return x | y
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
def BitwiseNot(x:Tensor): return ~x
def Mod(x:Tensor,y:Tensor,fmod=0):
  # fmod=1 computes x - trunc(x/y)*y, so the result takes the sign of x; fmod=0 uses Tensor's % operator
  if fmod: return x - x.div(y, rounding_mode="trunc") * y
  return x % y
619
+
620
+ # ***** Casting Ops *****
621
+ # TODO: saturate
622
def Cast(x:Tensor, to:int, saturate:int=1):
  # `to` is an ONNX data type id; saturate is accepted but not applied
  target = dtype_fallback(OnnxDataType(to).to_dtype(), "Cast op")
  return x.cast(target)
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1):
  # only target_type's dtype is used, never its values
  return x.cast(target_type.dtype)
624
+
625
+ # ***** Reduce Ops *****
626
# variadic element-wise ops: fold all inputs pairwise from left to right
def Max(*data_0:Tensor): return functools.reduce(lambda acc, t: acc.maximum(t), data_0)
def Min(*data_0:Tensor): return functools.reduce(lambda acc, t: acc.minimum(t), data_0)
def Sum(*data_0:Tensor): return functools.reduce(lambda acc, t: acc.add(t), data_0)
def Mean(*data_0:Tensor):
  total = Sum(*data_0)
  return total / len(data_0)
630
# all Reduce* ops share ONNX's convention: axes=None/[] reduces every dim unless noop_with_empty_axes (see _axes)
def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  reduce_axes = _axes(axes, noop_with_empty_axes)
  return data.max(reduce_axes, keepdim=keepdims)
def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  reduce_axes = _axes(axes, noop_with_empty_axes)
  return data.min(reduce_axes, keepdim=keepdims)
def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  reduce_axes = _axes(axes, noop_with_empty_axes)
  return data.sum(reduce_axes, keepdim=keepdims)
def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  reduce_axes = _axes(axes, noop_with_empty_axes)
  return data.mean(reduce_axes, keepdim=keepdims)
def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  squared = data.square()
  return ReduceSum(squared, axes, keepdims, noop_with_empty_axes)
def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  reduce_axes = _axes(axes, noop_with_empty_axes)
  return data.prod(reduce_axes, keepdim=keepdims)
def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  magnitudes = data.abs()
  return ReduceSum(magnitudes, axes, keepdims, noop_with_empty_axes)
def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  sum_sq = ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes)
  return sum_sq.sqrt()
def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  total = ReduceSum(data, axes, keepdims, noop_with_empty_axes)
  return total.log()
648
def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  """log(sum(exp(data))) over axes, using the max-shift trick so exp cannot overflow for large inputs.

  Axis handling follows the other Reduce* ops (see _axes).
  """
  # shift by the per-slice max, kept broadcastable against data; the shift does not change the math result
  m = data.max(_axes(axes, noop_with_empty_axes), keepdim=True).detach()
  # an infinite max would give inf - inf = nan; shifting those slices by 0 still yields the correct +/-inf
  if dtypes.is_float(m.dtype): m = m.isinf().where(0, m)
  shifted_sum = ReduceSum((data - m).exp(), axes, keepdims, noop_with_empty_axes)
  # reducing m's kept size-1 dims with ReduceMax gives it exactly the output shape for either keepdims value
  return shifted_sum.log() + ReduceMax(m, axes, keepdims, noop_with_empty_axes)
650
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
  """int64 index of the maximum along axis; with select_last_index, x is searched flipped so the last occurrence wins."""
  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
  """int64 index of the minimum along axis, computed as ArgMax of an order-reversed copy of x."""
  # plain negation does not reverse the order for unsigned ints (0 stays smallest, others wrap), bools, or INT_MIN,
  # so use ~x for integers (-x-1 when signed, MAX-x when unsigned) and logical_not for bools
  if dtypes.is_float(x.dtype): inverted = -x
  elif dtypes.is_int(x.dtype): inverted = ~x
  else: inverted = x.logical_not()
  return ArgMax(inverted, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
655
+
656
+ # ***** Movement Ops *****
657
def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
  # a 0 in shape copies the matching input dim, unless allowzero makes it a literal 0
  new_shape = [(0 if allowzero else data.shape[i]) if s == 0 else s for i, s in enumerate(shape)]
  return data.reshape(new_shape)
def Flatten(x:Tensor, axis:int=1):
  # dims [0, axis) collapse into the first output dim, the remaining dims into the second
  return x.reshape(prod(x.shape[0:axis]), -1)
def Expand(x:Tensor, shape:list[int]):
  target = _broadcast_shape(x.shape, tuple(shape))
  return x.expand(target)
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def Transpose(x:Tensor, perm:list[int]|None=None):
  # without perm the dims are reversed
  order = perm if perm else list(range(x.ndim)[::-1])
  return x.permute(order=order)
663
+
664
def Squeeze(data:Tensor, axes:list[int]|None=None):
  """Remove the size-1 dims listed in axes (negatives count from the end); with axes=None remove every size-1 dim."""
  if axes is None: return data.squeeze()
  # normalize negatives against the input rank before squeezing: once the last dim is gone, a raw -2 would point
  # at a different dim, so e.g. axes=[-1, -2] previously hit the wrong one. Squeezing from highest to lowest
  # keeps the remaining indices valid.
  dims = sorted((a + data.ndim if a < 0 else a for a in axes), reverse=True)
  return functools.reduce(lambda d, dim: d.squeeze(dim), dims, data)
def Unsqueeze(data:Tensor, axes:list[int]):
  """Insert size-1 dims at the positions in axes, which index the OUTPUT (negatives count from the output's end)."""
  out_ndim = data.ndim + len(axes)
  # negative axes refer to the output rank, not the rank at the moment of each insertion
  dims = sorted(a + out_ndim if a < 0 else a for a in axes)
  return functools.reduce(lambda d, dim: d.unsqueeze(dim), dims, data)
667
+
668
def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
  """Strided slice; axes defaults to every dim, steps to 1, and unlisted dims are kept whole."""
  axes = axes or list(range(data.ndim))
  steps = steps or [1]*data.ndim
  slices = [slice(0, size, 1) for size in data.shape]
  for i, axis in enumerate(axes):
    slices[axis] = slice(starts[i], ends[i], steps[i])
  return data[tuple(slices)]
676
+
677
def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
  """Split data along axis, either by the explicit `split` sizes or into num_outputs chunks.

  With num_outputs, every chunk has size ceil(dim / num_outputs) and the last chunk takes whatever remains
  (so it is smaller when the dim is not evenly divisible), e.g. 7 into 3 -> [3, 3, 1].
  Raises ZeroDivisionError if neither split nor num_outputs is given.
  """
  sz = data.shape[axis]
  if split is None:
    chunk = -(-sz // num_outputs)  # ceil division
    split = [chunk] * (num_outputs - 1) + [sz - chunk * (num_outputs - 1)]
  return data.split(split, axis)
681
+
682
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
        mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
  """Pad x; pads lists every begin pad for `axes` and then every end pad. ONNX modes map to tinygrad modes."""
  fill = constant_value or value
  axes = axes or list(range(x.ndim))
  n_axes = len(axes)
  # scatter the per-axis pads into a full-rank (all begins, all ends) layout; axes wrap modulo ndim
  full = [0] * (x.ndim*2)
  for i, axis in enumerate(axes):
    ax = axis % x.ndim
    full[ax], full[ax + x.ndim] = pads[i], pads[i + n_axes]
  tiny_mode = {"edge":"replicate", "wrap":"circular"}.get(mode, mode)
  return x.pad(padding=_onnx_pads_to_tiny_pads(full), mode=tiny_mode, value=fill)
689
+
690
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
  """Center-crop or pad each listed axis of t (default: every axis) to the size given in shape.

  For an odd size difference the extra element is removed from / added at the end of the axis.
  """
  shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
  pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
  for s, x in zip(shape, axes or range(t.ndim)):
    tx = t.shape[x]
    if s < tx:
      # start at floor((tx - s) / 2); the old tx//2 - (s+1)//2 was one element too early when tx and s were both odd
      start = (tx - s) // 2
      shrink_arg[x] = (start, start + s)
    elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
  return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
698
+
699
+ # ***** Processing Ops *****
700
def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
                dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
  # resolve ONNX padding attributes into tinygrad's pad order, then average-pool over the kernel window
  resolved = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
  window = tuple(kernel_shape)
  return X.avg_pool2d(window, strides, dilations, resolved, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
704
+
705
def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
            storage_order:int=0, strides:list[int]|int=1):
  """Max pooling over the trailing len(kernel_shape) dims; returns (values, int64 indices of the max elements).

  storage_order=1 requests the indices linearized column-major over the spatial dims instead of row-major.
  """
  pool_pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
  out = X.max_pool2d(tuple(kernel_shape), strides, dilations, pool_pads, ceil_mode=ceil_mode, return_indices=True)
  ret, idx = cast(tuple[Tensor, Tensor], out)
  idx = idx.cast(dtypes.int64)
  if storage_order:
    # re-linearize each index column-major. Transposing the index tensor instead would swap the output's spatial
    # shape, and its values are only right for symmetric cases.
    spatial = [cast(int, s) for s in X.shape[-len(kernel_shape):]]
    numel = prod(spatial)
    rem = idx % numel
    col = idx - rem  # keep any offset past the spatial block unchanged (zero if indices are spatial-only)
    for i in reversed(range(len(spatial))):
      col = col + (rem % spatial[i]) * prod(spatial[:i])
      rem = rem // spatial[i]
    idx = col
  return ret, idx
711
+
712
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
         kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
  # kernel_shape is optional; fall back to the spatial dims of W
  kernel = kernel_shape or W.shape[2:]
  padding = _resolve_pool_pads(X, pads, kernel, dilations, strides, auto_pad)
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
716
+
717
def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
                  kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
                  strides:list[int]|int=1):
  """Transposed convolution.

  An explicit output_shape determines the pads (split by auto_pad). Otherwise explicit pads are used, or,
  when pads is absent, zeros for NOTSET and auto_pad-split pads targeting input*stride for the SAME modes.
  """
  in_spatial, kernel = X.shape[2:], (kernel_shape or W.shape[2:])
  n = len(in_spatial)
  stride_t, dil_t, outpad_t = (make_tuple(v, n) for v in (strides, dilations, output_padding))
  if output_shape is not None:
    # total padding so that stride*(in-1) + output_padding + dilated_kernel - pad == requested size
    totals = [s*(i-1) + op + ((k-1)*d+1) - o for s, i, op, k, d, o in zip(stride_t, in_spatial, outpad_t, kernel, dil_t, output_shape)]
    pads = _auto_pad(totals, auto_pad)
  if pads is None:
    target = output_shape or [X.shape[j+2] * stride_t[j] for j in range(len(stride_t))]
    totals = [stride_t[j]*(in_spatial[j]-1) + outpad_t[j] + ((kernel[j]-1)*dil_t[j]+1) - target[j] for j in range(n)]
    pads = _auto_pad(totals, auto_pad) if auto_pad != "NOTSET" else [0] * n * 2
  return X.conv_transpose2d(W, B, group, stride_t, dil_t, _onnx_pads_to_tiny_pads(pads), outpad_t)
732
+
733
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0,
              strides:list[int]|int=1):
  # scatter pooled values xT back to the positions in xI; lists become tuples for the tinygrad call
  kernel = tuple(kernel_shape) if kernel_shape is not None else ()
  pads_: int | tuple[int, ...] = tuple(pads) if isinstance(pads, list) else pads
  out_size = None if outshape is None else tuple(outshape)
  return Tensor.max_unpool2d(xT, xI, kernel, strides, 1, pads_, out_size)

def GlobalAveragePool(X:Tensor):
  spatial = tuple(range(2, X.ndim))  # every dim after the first two
  return X.mean(axis=spatial, keepdim=True)
def GlobalMaxPool(X:Tensor):
  spatial = tuple(range(2, X.ndim))
  return X.max(axis=spatial, keepdim=True)
741
+
742
def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
  """alpha * (A' @ B') + beta * C, where A' and B' are A.transpose(transA) and B.transpose(transB)."""
  # NOTE(review): relies on Tensor.transpose(0) being a no-op and transpose(1) swapping dims 1 and 0 — confirm
  product = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is None: return product
  if broadcast == 0: bias = C
  else:
    # legacy broadcast attribute: align C's dims with the trailing dims of the product
    bias = C.reshape([-1 if i < len(C.shape) else 1 for i in range(product.ndim)][::-1])
  return product + beta * bias

def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)  # equation comes first in tinygrad's call
748
+
749
def CumSum(X:Tensor, axis:int|list[int], exclusive:int=0, reverse:int=0):
  """Cumulative sum along axis; exclusive shifts the sums by one position, reverse accumulates from the end."""
  axis = X._resolve_dim(_resolve_const(axis))
  if reverse: X = X.flip(axis)
  if exclusive:
    # shift by one along axis: pad one element in front, then keep the original length
    front = tuple((1,0) if i == axis else None for i in range(X.ndim))
    keep = tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))
    X = X.pad(front).shrink(keep)
  summed = X.cumsum(axis)
  return summed.flip(axis) if reverse else summed

def Trilu(x:Tensor, k:int|list[int]=0, upper:int=1):
  diagonal = _resolve_const(k)
  if upper: return x.triu(diagonal)
  return x.tril(diagonal)
759
+
760
def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
           axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
           extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
  """ONNX Resize: resample X along `axes` (default: all dims) with mode "nearest", "linear" or "cubic".

  X is permuted so that `axes` come last. Every dim after the first two of the permuted tensor is resizable;
  the leading entries of `scales`/`sizes` must leave the first two dims unchanged (asserted).
  `sizes` wins when given (optionally adjusted by keep_aspect_ratio_policy); otherwise `scales` is required.
  `roi` and `extrapolation_value` are accepted but not read here; `antialias` raises NotImplementedError.
  """
  def _apply_transformation(input_sz, output_sz, scale_dim, mode):
    # map output indices 0..output_sz-1 to fractional input coordinates per coordinate_transformation_mode
    index = Tensor.arange(output_sz, requires_grad=False, device=X.device)
    if mode == "half_pixel": return (index + 0.5) / scale_dim - 0.5
    if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else Tensor.zeros_like(index)
    if mode == "asymmetric": return index / scale_dim
    if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else Tensor.zeros_like(index)
    if mode == "half_pixel_symmetric":
      output_dim_scaled = input_sz * scale_dim
      return (input_sz / 2) * (1 - (output_sz / output_dim_scaled)) + (index + 0.5) / scale_dim - 0.5
    raise ValueError(f"invalid {coordinate_transformation_mode=}")

  if antialias: raise NotImplementedError("antialias is not implemented")
  axes = axes or list(range(X.ndim))
  perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
  # we pre-permute the axes and permute back after resize
  # the permute aligns X's axes to scales, sizes, and roi
  X = X.permute(*perm)

  input_shape = cast(tuple[int, ...], X.shape[2:])
  # both checks cover everything before the spatial dims, using the spatial rank rather than a fixed 2,
  # so inputs with more than two spatial dims can resize each of them via `sizes`
  n_spatial = len(input_shape)
  if scales is not None: assert all(sc==1 for sc in scales[:-n_spatial]), "resizing batch_size dim or channel dim not supported"
  if sizes is not None:
    assert tuple(sizes[:-n_spatial]) == tuple(X.shape[X.ndim-len(sizes):-n_spatial]), "resizing batch_size dim or channel dim not supported"

  # NOTE(review): if `axes` names fewer dims than input_shape has, these slices are shorter than input_shape and
  # the zips below pair them with the wrong spatial dims — TODO confirm intended handling of partial axes
  scales, sizes = (None if scales is None else scales[-len(input_shape):]), (None if sizes is None else sizes[-len(input_shape):])
  if sizes is not None:
    if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
      # a single uniform scale: the smallest (not_larger) or largest (not_smaller) per-dim ratio, rounded half up
      scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
      scale = scale_fxn(sz / sh for sz,sh in zip(sizes, input_shape))
      sizes, scales = [int(scale * sh + 0.5) for sh in input_shape], [scale]*len(input_shape)
    else: scales = [sz / sh for sz, sh in zip(sizes, input_shape)]
  else:
    assert scales is not None, "either sizes or scales must be provided"
    sizes = [int(sc * sh) for sc, sh in zip(scales, input_shape)]

  # nothing to resample: just undo the permute
  if all(sz == sh for sz, sh in zip(sizes, input_shape)): return X.permute(*argsort(perm)) if perm else X

  indexes = []
  for input_sz, output_sz, scale in zip(input_shape, sizes, scales):
    indexes.append(_apply_transformation(input_sz, output_sz, scale, coordinate_transformation_mode))

  # nearest and linear only sample inside the input, so clamp coordinates to [0, size-1]
  if mode in ["nearest", "linear"]: indexes = [idx.clip(0, sz-1) for idx, sz in zip(indexes, input_shape)]

  if mode == "nearest":
    mode_operations = {
      "round_prefer_floor": lambda idx: (idx - 0.5).ceil(),
      "round_prefer_ceil": lambda idx: (idx + 0.5).floor(),
      "floor": lambda idx: idx.floor(),
      "ceil": lambda idx: idx.ceil()
    }
    if nearest_mode not in mode_operations: raise ValueError(f"invalid {nearest_mode=}")
    indexes = [mode_operations[nearest_mode](idx).int() for idx in indexes]
    # index all spatial dims at once with the meshgrid of per-dim integer indices
    X = X[(..., *Tensor.meshgrid(*indexes))]

  if mode == "linear":
    expand = list(X.shape)
    # separable: interpolate one spatial dim at a time between its floor and ceil neighbours
    for i in range(-len(sizes), 0):
      reshape, index = [1] * X.ndim, indexes[i]
      reshape[i] = expand[i] = sizes[i]
      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
      X = X.gather(i, low).lerp(X.gather(i, high), perc)

  if mode == "cubic":
    A = cubic_coeff_a

    # Keys weights
    # see piecewise function in: https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
    def W0_1(x:Tensor): return polyN(x, [A + 2, -(A + 3), 0, 1])
    def W1_2(x: Tensor): return polyN(x, [A, -5 * A, 8 * A, -4 * A])

    expand = list(X.shape)
    for i in range(-len(sizes), 0):
      input_sz = cast(int, X.shape[i])
      reshape, index = [1] * X.ndim, indexes[i]
      reshape[i] = expand[i] = sizes[i]

      p = index.floor().int()
      ratio = index - p # in [0, 1]

      # Neighbor indices
      idx0, idx1, idx2, idx3 = [p + d for d in [-1, 0, 1, 2]]
      # Weights of distance from index and neighbor indices
      c0, c1, c2, c3 = W1_2(ratio+1), W0_1(ratio), W0_1(-(ratio-1)), W1_2(-(ratio-2))

      if exclude_outside:
        # zero the weights of out-of-range neighbours, then renormalize so the weights still sum to 1
        c0 = ((idx0 >= 0) & (idx0 < input_sz)).where(c0, 0)
        c1 = ((idx1 >= 0) & (idx1 < input_sz)).where(c1, 0)
        c2 = ((idx2 >= 0) & (idx2 < input_sz)).where(c2, 0)
        c3 = ((idx3 >= 0) & (idx3 < input_sz)).where(c3, 0)

        total = c0 + c1 + c2 + c3
        c0, c1, c2, c3 = c0 / (total + 1e-9), c1 / (total + 1e-9), c2 / (total + 1e-9), c3 / (total + 1e-9)

      # Reshape and expand
      expanded_indices = [y.clip(0, input_sz - 1).reshape(reshape).expand(expand) for y in [idx0, idx1, idx2, idx3]]
      expanded_coeffs = [y.reshape(reshape).expand(expand) for y in [c0, c1, c2, c3]]

      # Gather values and apply coefficients
      gathered_values = [X.gather(i, idx) for idx in expanded_indices]
      X = sum(v * c for v, c in zip(gathered_values, expanded_coeffs))
  return X.permute(*argsort(perm)) if perm else X
862
def Upsample(X, scales, mode):
  # deprecated op: equivalent to Resize driven by scales only
  return Resize(X=X, scales=scales, mode=mode)

def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1): # noqa: A002 # pylint: disable=redefined-builtin
  # K may be a scalar or a one-element list; indices are returned as int64
  k = _resolve_const(K)
  values, indices = X.topk(k, axis, bool(largest), bool(sorted))
  return values, indices.cast(dtypes.int64)
867
+
868
+ # ***** Neural Network Ops *****
869
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                       training_mode:int=0, spatial=1, is_test=0):
  """Batch normalization over channel dim 1 of X shaped (N, C, D1, ..., Dk).

  Inference: normalize with input_mean / input_var and return only the output.
  training_mode: normalize with per-channel batch statistics (computed on a detached X) and also return the
  momentum-blended running mean and variance.
  """
  if training_mode:
    # per-channel statistics reduce over every dim except 1, for any input rank (previously fixed to 4-D)
    reduce_axes = (0, *range(2, X.ndim))
    stat_shape = [1, -1] + [1] * (X.ndim - 2)
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=reduce_axes)
    y = (x_detached - current_mean.reshape(shape=stat_shape))
    current_var = (y*y).mean(axis=reduce_axes)  # biased (population) variance
    current_invstd = current_var.add(epsilon).rsqrt()

    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  return X.batchnorm(scale, B, input_mean, (input_var + epsilon).rsqrt())
883
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
  # normalize each (sample, group) slice over all its elements, then apply scale/bias along dim 1
  normed = x.reshape(x.shape[0], num_groups, -1).layernorm(eps=epsilon).reshape(x.shape)
  affine_shape = (1, -1, *[1] * (normed.ndim-2))
  return normed * scale.reshape(*affine_shape) + bias.reshape(*affine_shape)
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
  # instance norm is group norm with one group per channel
  n_channels = cast(int, x.shape[1])
  return GroupNormalization(x, scale, bias, num_groups=n_channels, epsilon=epsilon)
888
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor|None=None, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
  """Normalize x over dims [axis, ndim), then apply scale and the optional bias (B is optional in ONNX).

  Returns (output, mean, inv_std_dev); mean and inv_std_dev keep the reduced dims.
  """
  assert stash_type == 1, "only float32 is supported"
  axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  mean = x.mean(axis=axes, keepdim=True)
  inv_std = (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
  out = x.layernorm(axes, epsilon).mul(scale)
  if bias is not None: out = out.add(bias)  # a missing bias previously crashed in .add(None)
  return out, mean, inv_std
893
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
  """Layer norm of x + skip (+ bias), scaled by gamma and optionally shifted by beta.

  Returns (output, None, None, pre-normalization sum); the two None slots are outputs this implementation
  does not produce (presumably mean / inv_std_dev — TODO confirm against the contrib op spec).
  """
  summed = x + skip
  if bias is not None: summed = summed + bias
  normed = summed.layernorm(eps=epsilon) * gamma
  return (normed if beta is None else normed + beta), None, None, summed
899
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
                            segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
                            position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
  """Sum word, position and (optionally) segment embeddings, then layer-normalize with gamma/beta.

  Returns (output, None, embedding_sum); mask / mask_index_type are not supported (asserted).
  """
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  assert (segment_ids is None) is (segment_embedding is None)
  assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
  ids_shape = input_ids.shape
  use_segments = (segment_embedding is not None and segment_ids is not None)

  def lookup(ids:Tensor, table_size, table:Tensor) -> Tensor:
    # one-hot the ids along a new last dim and matmul with the table: a row gather expressed as a matmul
    return ids.unsqueeze(-1).expand(*ids.shape, table_size)._one_hot_along_dim(table_size) @ table

  # bert embedding layer
  if position_ids is None: position_ids = Tensor.arange(ids_shape[1], requires_grad=False).unsqueeze(0).expand(*ids_shape)
  total = lookup(input_ids, word_embedding.shape[0], word_embedding) + lookup(position_ids, position_embedding.shape[0], position_embedding)
  if use_segments: total = total + lookup(segment_ids, segment_embedding.shape[0], segment_embedding)
  out = total.layernorm(eps=epsilon) * gamma + beta
  return out, None, total
924
def MeanVarianceNormalization(x:Tensor, axis:list[int]|None=None):
  # zero-mean, unit-(population)-variance over `axis` (default [0, 2, 3]); 1e-9 guards against division by zero
  axes = [0,2,3] if axis is None else axis
  centered = x - x.mean(axes, keepdim=True)
  return centered / (x.std(axes, keepdim=True, correction=0) + 1e-9)
927
+
928
def OneHot(indices:Tensor, depth:float|int|list[int|float], values:Tensor, axis:int=-1):
  """One-hot encode indices along a new `axis`; values holds [off_value, on_value]. Negative indices count back from depth."""
  # Scalar or Rank 1 tensor containing exactly one element
  n_classes = int(_resolve_const(depth))
  idx = indices.int()
  idx = (idx < 0).where(idx + n_classes, idx)
  hot = idx.unsqueeze(axis)._one_hot_along_dim(n_classes, dim=axis)
  return hot.where(values[1], values[0])
934
+
935
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
  # CRD splits channels as (c h1 w1); any other mode (DCR is the default) as (h1 w1 c)
  if mode=="CRD": pattern = "b (c h1 w1) h w -> b c (h h1) (w w1)"
  else: pattern = "b (h1 w1 c) h w -> b c (h h1) (w w1)"
  return X.rearrange(pattern, h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
  pattern = "b c (h h1) (w w1) -> b (h1 w1 c) h w"
  return X.rearrange(pattern, h1=blocksize, w1=blocksize)
939
+
940
+ # Reimplemented here because you need legacy RNG for passing ONNX tests.
941
def dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
  """Dropout returning (output, keep-mask); outside training it is the identity with an all-True mask."""
  import numpy as np
  if not training_mode: return data, data.full_like(True, dtype=dtypes.bool)
  if seed is None:
    rand = data.rand_like(requires_grad=False)
  else:
    # a seeded draw uses numpy's legacy RandomState (see the note above about passing ONNX tests)
    rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), requires_grad=False, dtype=data.dtype, device=data.device)
  keep = rand >= ratio
  # kept values are scaled by 1/(1-ratio)
  return data * keep / (1.0 - ratio), keep
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def dropout_6(data:Tensor, ratio:float=0.5, is_test=0):
  return dropout_7(data, ratio, training_mode=not is_test)
Dropout = {OpSetId(Domain.ONNX, 6):dropout_6, OpSetId(Domain.ONNX, 7):dropout_7}
953
+
954
def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
  """Local response normalization across channels (dim 1) for x shaped (N, C, D1, ..., Dk).

  Each element is divided by (bias + alpha * average of x**2 over a `size`-channel window) ** beta, where
  the window is padded by (size-1)//2 channels before and size//2 after.
  """
  # fold all spatial dims into one so the channel window is a (size, 1) 2-D average pool for any input rank;
  # for 4-D input this equals the former 'b c h w -> b 1 c (h w)' rearrange
  squares = (x**2).reshape(x.shape[0], 1, x.shape[1], -1)
  pooled_x = squares.pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
  return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
957
+
958
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
  # x is expected to already hold log-probabilities
  return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
  # log-softmax over dim 1, then NLL; the log-probabilities are returned as the second output
  log_probs = scores.log_softmax(1)
  loss = log_probs.nll_loss(labels, weights, ignore_index, reduction)
  return loss, log_probs
963
+
964
def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
  """Build a sampling grid by applying the batched affine matrices theta to a normalized base grid.

  size is (N, C, *spatial); the result is reshaped to (N, *spatial, -1).
  """
  N, _, *spatial_dims = size
  def generate_grid(steps):
    # align_corners spans exactly [-1, 1]; otherwise both endpoints are pulled in by 1/steps
    lo, hi = (-1, 1) if align_corners else (-1+1/steps, 1-1/steps)
    return Tensor.linspace(lo, hi, steps, device=theta.device)
  grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
  # homogeneous coordinates: the grid axes in reverse order plus a trailing column of ones
  base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
  n_points = prod(spatial_dims)
  base_grid = base_grid.reshape(1, n_points, len(grids)+1).expand(N, -1, -1)
  return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
973
+
974
def attention_contrib(x:Tensor, weights:Tensor, bias:Tensor|None=None, mask_index:Tensor|None=None, past:Tensor|None=None,
                      attention_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int=0, mask_filter_value:float=-10000.0,
                      num_heads:int|None=None, past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None,
                      rotary_embedding_dim:int|None=None, scale:float|None=None, unidirectional:int=0):
  """com.microsoft Attention: fused QKV projection followed by multi-head scaled dot-product attention.

  x is (batch, seq_len, hidden). `weights`/`bias` project it to a concatenated q|k|v, split by qkv_hidden_sizes
  (equal thirds of weights.shape[1] by default).
  Returns (output, present):
  - output is (batch, seq_len, heads * v_head_size).
  - present stacks the full (past + current) k and v when `past` is given, else None.
  do_rotary and attention_bias are not implemented. past_sequence_length, past_present_share_buffer and
  rotary_embedding_dim are accepted but unused.
  """
  # test presence with `is None`: `not attention_bias` would ask a (multi-element) Tensor for its truth value
  assert not do_rotary and attention_bias is None, "TODO"
  if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3
  qkv = x.linear(weights, bias)
  q, k, v = qkv.split(qkv_hidden_sizes, dim=2)

  batch_size, seq_len, _ = x.shape
  assert num_heads is not None, "num_heads must be provided"
  q_head_size, k_head_size, v_head_size = (sz // num_heads for sz in qkv_hidden_sizes)
  # split heads: (batch, seq, heads*head_size) -> (batch, heads, seq, head_size)
  q, k, v = (t.reshape(batch_size, seq_len, num_heads, hsz).transpose(1, 2) for t, hsz in zip((q, k, v), (q_head_size, k_head_size, v_head_size)))

  present = None
  if past is not None:
    # past[0] / past[1] hold the cached keys / values; extend them along the sequence axis
    k, v = past[0].cat(k, dim=2), past[1].cat(v, dim=2)
    present = k.stack(v)

  if scale is None: scale = 1.0 / math.sqrt(q_head_size)
  attn_scores = q @ k.transpose(-1, -2) * scale
  # keys cover past + current tokens, so masks are built over total_len rather than seq_len
  total_len = attn_scores.shape[-1]

  if mask_index is not None:
    assert 4 >= mask_index.ndim >= 1, f"{mask_index.ndim=}"
    assert isinstance(batch_size, int), f"{batch_size=}"
    if mask_index.ndim != 1: mask = mask_index.bool()
    else:
      if mask_index.shape[0] == batch_size:
        # per-batch valid length: key positions [0, mask_index[b]) are attended
        mask = Tensor.arange(total_len, requires_grad=False, device=mask_index.device).unsqueeze(0) < mask_index.unsqueeze(1)
      elif mask_index.shape[0] == 2*batch_size:
        # first half: exclusive end positions, second half: inclusive start positions
        end_positions, start_positions = mask_index[:batch_size], mask_index[batch_size:]
        arange = Tensor.arange(total_len, requires_grad=False, device=mask_index.device).unsqueeze(0)
        mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1))
      else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented")
    # insert singleton axes after batch until the mask broadcasts against (batch, heads, seq, total_len)
    while mask.ndim < 4: mask = mask.unsqueeze(1)
    attn_scores = mask.where(attn_scores, mask_filter_value)

  if unidirectional:
    # query i sits at absolute position past_len + i and may see keys 0..past_len+i: shift the triangle by past_len
    causal_mask = Tensor.ones((seq_len, total_len), dtype=dtypes.bool, device=x.device).tril(total_len - seq_len)
    attn_scores = causal_mask.where(attn_scores, mask_filter_value)

  output = attn_scores.softmax(-1) @ v
  # merge heads back: (batch, heads, seq, v_head_size) -> (batch, seq, heads * v_head_size)
  output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
  return output, present
1019
+
1020
def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None,
                   is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None,
                   softcap:float=0.0, softmax_precision:int|None=None):
  """ONNX Attention: scaled dot-product attention with optional KV cache, grouped KV heads, masks and softcap.

  Q/K/V are 4D (batch, heads, seq, head_size), or 3D with q_num_heads/kv_num_heads given.
  Returns (output, present_key, present_value, qk_matmul_output). The last value is chosen by qk_matmul_output_mode:
  - 0: scaled QK^T
  - 1: after the causal/attn masks
  - 2: after softcap
  - 3: after softmax
  softmax_precision is an ONNX data type id (1=float32, 10=float16, 11=double, 16=bfloat16).
  """
  input_shape_len = Q.ndim
  if input_shape_len == 3:
    assert q_num_heads is not None and kv_num_heads is not None
    # NOTE(review): a direct reshape of (batch, seq, hidden) into (batch, heads, seq, head_size), not a split-heads
    # reshape+transpose; presumably mirrors the ONNX reference implementation this was tested against -- confirm
    Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1)
    K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1)
    V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1)

  # append the new keys/values to the cache along the sequence axis; the result is returned as the present state
  if past_key is not None: K = past_key.cat(K, dim=2)
  if past_value is not None: V = past_value.cat(V, dim=2)
  present_key, present_value = K, V

  _q_heads, _kv_heads = q_num_heads or Q.shape[1], kv_num_heads or K.shape[1]
  if _q_heads != _kv_heads:
    # grouped-query attention: replicate kv heads to match the query head count
    # NOTE(review): Tensor.repeat presumably tiles heads ([h0, h1, h0, h1]); confirm this grouping matches the spec
    K = K.repeat((1, _q_heads // _kv_heads, 1, 1))
    V = V.repeat((1, _q_heads // _kv_heads, 1, 1))

  effective_scale = scale if scale is not None else 1.0 / (cast(int, Q.shape[-1]) ** 0.5)
  scores = (Q @ K.transpose(-1, -2)) * effective_scale
  qk_matmul_return_val = scores

  if is_causal:
    causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool, requires_grad=False).tril(0)
    scores = scores.masked_fill(causal_mask.logical_not(), -float("inf"))

  if attn_mask is not None:
    # boolean masks mark allowed positions; other masks are additive biases
    mask_to_add = attn_mask.where(0, -float("inf")) if attn_mask.dtype == dtypes.bool else attn_mask
    scores = scores + mask_to_add
  if qk_matmul_output_mode == 1: qk_matmul_return_val = scores

  if softcap > 0.0: scores = (scores / softcap).tanh() * softcap
  if qk_matmul_output_mode == 2: qk_matmul_return_val = scores

  if softmax_precision:
    softmax_dtypes = {1: dtypes.float32, 10: dtypes.float16, 11: dtypes.float64, 16: dtypes.bfloat16}
    scores = scores.cast(softmax_dtypes[softmax_precision])
  qk_softmax = scores.softmax(-1).cast(Q.dtype)
  if qk_matmul_output_mode == 3: qk_matmul_return_val = qk_softmax

  output = (qk_softmax @ V).cast(Q.dtype)
  # 3D inputs get a 3D output: (batch, heads, seq, head_size) -> (batch, seq, heads*head_size)
  if input_shape_len == 3: output = output.permute(0, 2, 1, 3).reshape(Q.shape[0], Q.shape[2], -1)
  return output, present_key, present_value, qk_matmul_return_val
1062
# Attention has two variants keyed by (domain, version): the standard ONNX op and the com.microsoft contrib op.
# Being a dict, it is registered in the op table as a "version op".
Attention = {OpSetId(Domain.ONNX, 1): attention_onnx, OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1): attention_contrib}
1063
+
1064
def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5):
  """ONNX RMSNormalization: X / sqrt(mean(X**2) + epsilon) * scale, with the mean over axes [axis, X.ndim)."""
  start = axis if axis >= 0 else axis + X.ndim
  reduce_axes = tuple(range(start, X.ndim))
  inv_rms = X.square().mean(axis=reduce_axes, keepdim=True).add(epsilon).rsqrt()
  return X * inv_rms * scale
1067
+
1068
def RotaryEmbedding(X:Tensor, cos_cache:Tensor, sin_cache:Tensor, position_ids:Tensor|None=None, interleaved:int=0, num_heads:int|None=None,
                    rotary_embedding_dim:int=0):
  """ONNX RotaryEmbedding: rotate the first `rotary_embedding_dim` features of each head (all of them when 0).

  X is 4D (batch, heads, seq, head_size), or 3D (batch, seq, hidden) with num_heads given.
  The output has X's original shape.
  With position_ids, the caches are lookup tables indexed per token. Without them, the caches are used as-is and must
  broadcast against (batch, seq, 1, rot_dim/2) after the slicing below.
  """
  original_input_shape = X.shape

  # work in (batch, seq, heads, head_size)
  if X.ndim == 4: X = X.permute(0, 2, 1, 3)
  elif X.ndim == 3:
    assert num_heads is not None, "num_heads must be provided for 3D input"
    X = X.unflatten(-1, (num_heads, X.shape[-1] // num_heads))

  head_size = cast(int, X.shape[-1])
  rot_dim = rotary_embedding_dim or head_size
  x_rotate, x_pass = X[..., :rot_dim], X[..., rot_dim:]

  # without position_ids, the caches are already laid out per (batch, seq): do not slice their leading axis.
  # Truncating it (e.g. to head_size) would drop batch rows.
  cos = cos_cache if position_ids is None else cos_cache[position_ids]
  sin = sin_cache if position_ids is None else sin_cache[position_ids]
  # keep rot_dim/2 frequencies and add a heads axis for broadcasting
  cos = cos[..., :rot_dim//2].unsqueeze(2)
  sin = sin[..., :rot_dim//2].unsqueeze(2)

  # the rotated pairs are either adjacent (interleaved) or the two halves of the rotated slice
  x1, x2 = (x_rotate[..., ::2], x_rotate[..., 1::2]) if interleaved else x_rotate.chunk(2, dim=-1)
  real = x1 * cos - x2 * sin
  imag = x1 * sin + x2 * cos
  x_rotated = real.stack(imag, dim=-1).flatten(start_dim=-2) if interleaved else real.cat(imag, dim=-1)

  output = x_rotated.cat(x_pass, dim=-1)
  return output.flatten(start_dim=2) if len(original_input_shape) == 3 else output.permute(0, 2, 1, 3)
1093
+
1094
+ # ***** Indexing Ops *****
1095
def ArrayFeatureExtractor(x:Tensor, indices:Tensor):
  """Select the entries of `x` at `indices` along its last axis."""
  return x[(Ellipsis, indices)]
1096
+
1097
def Gather(x:Tensor, indices:Tensor, axis:int=0):
  """ONNX Gather: take entries of `x` along `axis` at `indices`.

  The output shape is x.shape[:axis] + indices.shape + x.shape[axis+1:]. Negative axis and negative index values
  count from the end.
  """
  # normalize a negative axis: both paths below compare dimension positions against `axis` and slice x.shape around it
  axis = x._resolve_dim(axis)
  if indices.numel() < 9:  # NOTE: fewer kernels for small indices, but the kernel count grows with the number of indices
    ret_shape = x.shape[:axis] + indices.shape + x.shape[axis+1:]
    if indices.ndim > 1: indices = indices.flatten()
    index_consts = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices)
    index_consts = [x.shape[axis]+i if i<0 else i for i in index_consts]
    # one shrink per index: full range on every dim except `axis`, which is narrowed to [i, i+1)
    args = [[(0,d) if j != axis else (i,i+1) for j, d in enumerate(x.shape)] for i in index_consts]
    return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  # NOTE: faster gather with a fixed number of kernels, but it exceeds the kernel limit for openpilot
  return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
1107
def Scatter(*args, **kwargs):
  """Deprecated ONNX Scatter; forwards every argument to ScatterElements."""
  return ScatterElements(*args, **kwargs)
1108
+
1109
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
  """ONNX GatherND: the last axis of `indices` holds index tuples into `x`.

  The first `batch_dims` axes are shared batch axes of x and indices.
  """
  def _split_last(idx:Tensor):
    # one index tensor per indexed dimension of x
    return tuple(t.squeeze(-1) for t in idx.split(1, -1))

  if batch_dims == 0: return x[_split_last(indices)]
  orig_x_shape, orig_i_shape = x.shape, indices.shape
  # collapse the (shared) leading batch dims into a single axis of size b
  b = math.prod(orig_x_shape[:batch_dims])
  x = x.reshape(b, *orig_x_shape[batch_dims:])
  indices = indices.reshape(b, *orig_i_shape[batch_dims:])
  # batch position of every index tuple, so each tuple only reads from its own batch row
  b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
  ret = x[(b_idx, *_split_last(indices))]
  return ret.reshape(*orig_x_shape[:batch_dims], *orig_i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
1119
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul", "max", "min"]='none'):
  """ONNX ScatterND: write (or reduce with add/mul/max/min) `updates` into a copy of `x`.

  The target positions are the index tuples stored in the last axis of `indices`. Updates are applied one slice of
  indices' first axis at a time, in order.
  """
  assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
  if reduction not in ("none", "add", "mul", "max", "min"): raise NotImplementedError(f"reduction {reduction!r} is not supported")
  x = x.contiguous()
  for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
    # split the last axis into one index tensor per indexed dimension of x
    i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
    u = u.squeeze(0)
    if reduction == "none": x[i] = u
    elif reduction == "add": x[i] += u
    elif reduction == "mul": x[i] *= u
    elif reduction == "max": x[i] = x[i].maximum(u)
    else: x[i] = x[i].minimum(u)
  return x
1130
+
1131
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
  """ONNX ScatterElements: Tensor.scatter, or Tensor.scatter_reduce when a reduction is requested."""
  # wrap negative indices by the size of the scatter axis
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  if reduction != "none":
    onnx_to_tiny = {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}
    return x.scatter_reduce(axis, indices, updates, cast(Literal["sum", "prod", "amin", "amax"], onnx_to_tiny[reduction]))
  return x.scatter(axis, indices, updates)
1136
def GatherElements(x:Tensor, indices:Tensor, axis:int):
  """ONNX GatherElements: Tensor.gather along `axis`, with negative indices wrapped first."""
  wrapped = (indices < 0).where(x.shape[axis], 0) + indices
  return Tensor.gather(x, axis, wrapped)
1139
+
1140
def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
  """ONNX Compress: keep the slices of `inp` along `axis` whose `condition` entry is true.

  When axis is None, the input is flattened first. `condition` may be shorter than the axis; missing entries count
  as false.
  """
  if axis is None:
    inp = inp.flatten()
    axis = 0
  axis = inp._resolve_dim(axis)
  # compress in python: positions of the true entries
  keep = [i for i, cond in enumerate(condition) if cond]
  # nothing selected -> zero-length slice. An empty Tensor([]) gets a non-integer default dtype in tinygrad, so it
  # can't serve as an index.
  sel = Tensor(keep, device=inp.device) if keep else slice(0, 0)
  return inp[tuple(sel if i == axis else slice(None) for i in range(inp.ndim))]
1147
+
1148
+ # ***** Quantization Ops *****
1149
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
  """ONNX QuantizeLinear: clamp-cast(round(x / y_scale) + y_zero_point) to the output dtype.

  The output dtype is chosen in this order:
  - y_zero_point's dtype, when it is a Tensor;
  - otherwise the `output_dtype` attribute, when non-zero;
  - otherwise uint8.
  `saturate` is accepted but not used here.
  """
  if isinstance(y_zero_point, Tensor): out_dtype = y_zero_point.dtype
  elif output_dtype == 0: out_dtype = dtypes.uint8
  else: out_dtype = dtype_fallback(OnnxDataType(output_dtype).to_dtype(), "QuantizeLinear op")
  y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
  scaled = x / y_scale
  if out_dtype == dtypes.uchar:
    # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff
    ret = _clamp_cast((scaled + 0.4999999 + y_zero_point).int(), out_dtype)
  else:
    ret = _clamp_cast((scaled.round() + y_zero_point), out_dtype)
  return ret.contiguous()
1160
+
1161
def DynamicQuantizeLinear(x: Tensor):
  """ONNX DynamicQuantizeLinear (uint8 output only).

  Derives scale and zero point from the data range of x, which always includes 0.
  Returns (y, y_scale, y_zero_point).
  """
  u8 = dtypes.uint8
  qmin, qmax = dtypes.min(u8), dtypes.max(u8)
  # range = max(max(x), 0) + max(-min(x), 0), written with maxima only
  upper, neg_lower = x.max().maximum(0), ((-x).max()).maximum(0)
  scale = (upper + neg_lower) / (qmax - qmin)
  zero_point = _clamp_cast((qmin - x.min() / scale).round(), u8)
  y = _clamp_cast((x / scale).round() + zero_point, u8)
  return y, scale, zero_point
1168
+
1169
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
  """ONNX DequantizeLinear: (x - x_zero_point) * x_scale, computed on int32 and cast to x_scale's dtype."""
  x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
  centered = x.int() - x_zero_point
  return (centered * x_scale).cast(x_scale.dtype)
1172
+
1173
def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor, w:Tensor, w_scale:Tensor, w_zero_point:Tensor, y_scale:Tensor,
                y_zero_point:Tensor, B:Tensor|None=None, **opts):
  """Quantized Conv: runs Conv through _qlinearop_quantized with per-operand scales/zero points and output y_scale/y_zero_point."""
  operands, zero_points, scales = [x, w], [x_zero_point, w_zero_point], [x_scale, w_scale]
  return _qlinearop_quantized(Conv, operands, zero_points, scales, y_scale, y_zero_point, B=B, **opts)
1176
+
1177
def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, y_scale:Tensor,
                  y_zero_point:Tensor) -> Tensor:
  """Quantized MatMul: runs Tensor.matmul through _qlinearop_quantized with per-operand scales/zero points."""
  operands, zero_points, scales = [a, b], [a_zero_point, b_zero_point], [a_scale, b_scale]
  return _qlinearop_quantized(Tensor.matmul, operands, zero_points, scales, y_scale, y_zero_point)
1180
+
1181
def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
  """Quantized Add: runs Tensor.add through _qlinearop_float (presumably in float, then requantized to c_scale/c_zero_point)."""
  operands, zero_points, scales = [a, b], [a_zero_point, b_zero_point], [a_scale, b_scale]
  return _qlinearop_float(Tensor.add, operands, zero_points, scales, c_scale, c_zero_point)
1183
+
1184
def QLinearMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
  """Quantized Mul: runs Tensor.mul through _qlinearop_quantized with output scale/zero point c_scale/c_zero_point."""
  operands, zero_points, scales = [a, b], [a_zero_point, b_zero_point], [a_scale, b_scale]
  return _qlinearop_quantized(Tensor.mul, operands, zero_points, scales, c_scale, c_zero_point)
1186
+
1187
def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int=0):
  """Quantized GlobalAveragePool via _qlinearop_float (presumably pooled in float, then requantized to y_scale/y_zero_point).

  Only the channels-first layout is supported. channels_last defaults to 0, so models that omit the attribute work.
  """
  assert channels_last == 0, "TODO NHWC"
  return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)
1190
+
1191
def ConvInteger(x: Tensor, w: Tensor, x_zero_point:Tensor|None = None, w_zero_point:Tensor|None = None, B: Tensor | None = None, **opts) -> Tensor:
  """Integer Conv: runs Conv through _op_integer with input/weight zero points (0 when absent)."""
  # build the zero defaults per call on the operands' devices. A default Tensor(0) in the signature would be one shared
  # Tensor, created whenever this op table is built and on the default device. Explicit None (omitted input) also works.
  if x_zero_point is None: x_zero_point = Tensor(0, device=x.device)
  if w_zero_point is None: w_zero_point = Tensor(0, device=w.device)
  return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})
1193
+
1194
def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor|None = None, b_zero_point: Tensor|None = None) -> Tensor:
  """Integer MatMul: runs Tensor.matmul through _op_integer with per-operand zero points (0 when absent)."""
  # build the zero defaults per call on the operands' devices. A default Tensor(0) in the signature would be one shared
  # Tensor, created whenever this op table is built and on the default device. Explicit None (omitted input) also works.
  if a_zero_point is None: a_zero_point = Tensor(0, device=A.device)
  if b_zero_point is None: b_zero_point = Tensor(0, device=B.device)
  return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])
1196
+
1197
+ # ***** Training Ops *****
1198
+ # NOTE: ONNX training ops are purely functional and don't need persistent optimizer state, but we can still reuse the optim.py code
1199
@_onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
  """ONNX Adagrad training op.

  inputs are (X, G, H): parameter, gradient and accumulated squared gradient. The learning rate R is decayed by
  1 / (1 + T * decay_factor). Returns [X, H] after the update.
  """
  X, G, H = [t.detach() for t in inputs]
  g_reg = norm_coefficient * X + G
  H.assign(H + g_reg.square())
  lr = R / (1 + T * decay_factor)
  X.assign(X.detach() - lr * (g_reg / (H.sqrt() + epsilon)))
  return [X, H]
1208
+
1209
@_onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
         norm_coefficient_post:float=0.0):
  """ONNX Adam training op, implemented by running one step of tinygrad's nn.optim.Adam.

  inputs are (X, G, V, H). V and H are installed as the optimizer's first/second moment state (opt.m / opt.v), and
  R as its learning rate. T is the step count used to set the bias-correction terms b1_t / b2_t.
  Returns [X, V, H] with X additionally scaled by (1 - norm_coefficient_post).
  """
  from tinygrad.nn.optim import Adam as TinyAdam
  X, G, V, H = inputs
  G, V, H = G.detach(), V.detach(), H.detach()
  # the optimizer reads the gradient from X.grad; fold in the norm_coefficient * X regularization term
  X.grad = norm_coefficient * X.detach() + G
  opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
  # reuse the caller-provided moment tensors as the optimizer state
  opt.m, opt.v, opt.lr = [V], [H], R
  # need no-op for m_hat and v_hat if T == 0
  if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
  else:
    # `T-1` since it's applied again at the start of `_step`
    opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
    opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
  opt.step()
  X = (1 - norm_coefficient_post) * X
  return [X, V, H]
1227
+
1228
@_onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
  """ONNX Momentum training op.

  inputs are (X, G, V): parameter, gradient and velocity. The step uses V for mode "standard"; any other mode uses
  the nesterov form G_reg + alpha * V. Returns [X, V] after the update.
  """
  X, G, V = [t.detach() for t in inputs]
  g_reg = norm_coefficient * X + G
  # NOTE: on the first step (T == 0) the gradient weight is 1 instead of beta; this is why SGD can't be reused for nesterov
  V.assign(alpha * V + g_reg * (beta if T > 0 else 1))
  if mode == "standard": step = V
  else: step = g_reg + alpha * V
  X.assign(X - R * step)
  return [X, V]
1237
+
1238
def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
  """ONNX Gradient: backpropagate from the intermediate tensor named `y`; return the gradients of `inputs` as a tuple."""
  target = intermediate_tensors[y]
  target.backward()
  return tuple(t.grad for t in inputs)
1241
+
1242
return {
  # Tensor ops: ONNX op name -> the Tensor method named by its lower-cased spelling (e.g. "Relu" -> Tensor.relu)
  **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
  "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
  "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Round", "Erf")},
  # Implemented ops: every plain function in this scope that is not "_"-private and starts with a capital letter.
  # Lower-case functions (e.g. attention_onnx, dropout_6) are not registered directly; they appear through the dicts below.
  **{name:obj for name,obj in locals().items() if isinstance(obj, types.FunctionType) and not name.startswith("_") and name[0].isupper()},
  # Version ops: every dict bound in this scope (e.g. Dropout, Attention), each mapping OpSetId -> implementation.
  # NOTE: both comprehensions scan locals(), so any new function- or dict-valued local added to this scope becomes an op entry.
  # On a name collision, later groups override earlier ones.
  **{name:obj for name,obj in locals().items() if isinstance(obj, dict)},
}
1252
+
1253
# the op-name -> implementation table, built once when this statement runs
onnx_ops = get_onnx_ops()