onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2fx/ops/tensor.py ADDED
@@ -0,0 +1,1161 @@
# SPDX-License-Identifier: Apache-2.0
"""Tensor manipulation operators."""

from typing import TYPE_CHECKING

import onnx
import torch

from ..exceptions import ConversionError
from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.names import sanitize_name
from ..utils.op_helpers import get_optional_input

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Constant and Identity operators
# =============================================================================


@register("Constant")
def constant(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Create a constant tensor."""
    value = get_attribute(node, "value", tensor_loader=builder.load_tensor)
    if value is None:
        value_float = get_attribute(node, "value_float")
        if value_float is not None:
            value = torch.tensor(value_float, dtype=torch.float32)
        value_int = get_attribute(node, "value_int")
        if value_int is not None:
            value = torch.tensor(value_int, dtype=torch.int64)
        value_floats = get_attribute(node, "value_floats")
        if value_floats is not None:
            value = torch.tensor(value_floats, dtype=torch.float32)
        value_ints = get_attribute(node, "value_ints")
        if value_ints is not None:
            value = torch.tensor(value_ints, dtype=torch.int64)

    if value is None:
        raise ConversionError(
            "Constant node has no value attribute",
            node_name=node.name,
            op_type="Constant",
        )

    output_name = node.output[0]
    safe_name = sanitize_name(output_name)
    builder._constants[safe_name] = value

    fx_node = builder.graph.get_attr(safe_name)
    fx_node.meta["onnx_op_type"] = "Constant"
    fx_node.meta["onnx_name"] = output_name
    fx_node.meta["onnx_shape"] = list(value.shape) if hasattr(value, "shape") else []
    fx_node.meta["onnx_dtype"] = value.dtype if hasattr(value, "dtype") else None
    return fx_node


@register("Identity")
def identity(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Identity operator - returns input unchanged."""
    return builder.get_value(node.input[0])


@register("Cast")
def cast(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Cast tensor to a different data type."""
    from ..utils.dtype import onnx_dtype_to_torch

    x = builder.get_value(node.input[0])
    to_dtype = get_attribute(node, "to")
    torch_dtype = onnx_dtype_to_torch(to_dtype)

    if torch_dtype is None:
        raise ConversionError(
            f"Unsupported cast target dtype: {to_dtype}",
            node_name=node.name,
            op_type="Cast",
        )

    return builder.call_function(lambda t, dtype: t.to(dtype), args=(x, torch_dtype))


@register("CastLike")
def cast_like(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Cast tensor to the same data type as the target tensor."""
    x = builder.get_value(node.input[0])
    target = builder.get_value(node.input[1])

    def _cast_like(t, target):
        return t.to(target.dtype)

    return builder.call_function(_cast_like, args=(x, target))


# =============================================================================
# Shape manipulation operators
# =============================================================================


@register("Reshape")
def reshape(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Reshape tensor to a new shape.

    ONNX Reshape semantics:
    - A value of 0 means the dimension is unchanged from the input shape
    - A value of -1 means the dimension is inferred from the remaining elements
    """
    x = builder.get_value(node.input[0])
    shape = builder.get_value(node.input[1])

    # Check allowzero attribute (default is 0, meaning 0 copies from input)
    allowzero = get_attribute(node, "allowzero", 0)

    def _reshape(t, shape, allowzero):
        if isinstance(shape, torch.Tensor):
            shape = shape.tolist()
        else:
            shape = list(shape)

        # Convert to integers (shape may contain floats from tensor operations)
        shape = [int(d) for d in shape]

        # ONNX: if allowzero=0, a value of 0 in shape means copy from input
        if not allowzero:
            for i, dim in enumerate(shape):
                if dim == 0:
                    if i < t.dim():
                        shape[i] = t.shape[i]

        return torch.reshape(t, tuple(shape))

    return builder.call_function(_reshape, args=(x, shape, allowzero))


@register("Transpose")
def transpose(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Transpose tensor dimensions."""
    x = builder.get_value(node.input[0])
    perm = get_attribute(node, "perm")
    if perm is None:
        # Default: reverse all dimensions
        return builder.call_function(lambda t: t.T, args=(x,))
    return builder.call_function(torch.permute, args=(x, perm))


@register("Squeeze", since_version=1)
def squeeze_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Remove dimensions of size 1 for opset 1-12.

    In opset < 13, axes is an attribute.
    """
    x = builder.get_value(node.input[0])

    axes = get_attribute(node, "axes")
    if axes is not None:
        # Squeeze specific dimensions
        result = x
        # Sort in reverse to maintain correct indices after each squeeze
        for axis in sorted(axes, reverse=True):
            result = builder.call_function(
                torch.squeeze, args=(result,), kwargs={"dim": axis}
            )
        return result
    return builder.call_function(torch.squeeze, args=(x,))


@register("Squeeze", since_version=13)
def squeeze_v13(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Remove dimensions of size 1 for opset 13+.

    In opset 13+, axes is an optional input (not attribute).
    """
    x = builder.get_value(node.input[0])

    # axes is an optional input in opset 13+
    axes = get_optional_input(builder, node, 1)
    if axes is not None:

        def _squeeze_dynamic(t, axes):
            if isinstance(axes, torch.Tensor):
                axes = axes.tolist()
            if isinstance(axes, list):
                if len(axes) == 1:
                    return torch.squeeze(t, dim=axes[0])
                # Multiple axes - squeeze in reverse order
                result = t
                for axis in sorted(axes, reverse=True):
                    result = torch.squeeze(result, dim=int(axis))
                return result
            return torch.squeeze(t, dim=int(axes))

        return builder.call_function(_squeeze_dynamic, args=(x, axes))

    # No axes input - squeeze all dimensions of size 1
    return builder.call_function(torch.squeeze, args=(x,))


@register("Unsqueeze", since_version=1)
def unsqueeze_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Insert dimensions of size 1 for opset 1-12.

    In opset < 13, axes is a required attribute.
    """
    x = builder.get_value(node.input[0])

    axes = get_attribute(node, "axes")
    if axes is None:
        raise ConversionError(
            "Unsqueeze requires axes attribute in opset < 13",
            node_name=node.name,
            op_type="Unsqueeze",
        )

    # Handle single axis
    if isinstance(axes, int):
        return builder.call_function(torch.unsqueeze, args=(x, axes))

    # Handle multiple axes - unsqueeze in sorted order
    result = x
    for axis in sorted(axes):
        result = builder.call_function(torch.unsqueeze, args=(result, axis))
    return result


@register("Unsqueeze", since_version=13)
def unsqueeze_v13(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Insert dimensions of size 1 for opset 13+.

    In opset 13+, axes is a required input (not attribute).
    """
    x = builder.get_value(node.input[0])

    if len(node.input) < 2 or not node.input[1]:
        raise ConversionError(
            "Unsqueeze requires axes input in opset 13+",
            node_name=node.name,
            op_type="Unsqueeze",
        )

    axes = builder.get_value(node.input[1])

    def _unsqueeze_dynamic(t, axes):
        if isinstance(axes, torch.Tensor):
            axes = axes.tolist()
        if isinstance(axes, int):
            return torch.unsqueeze(t, axes)
        # Handle multiple axes - unsqueeze in sorted order
        result = t
        for axis in sorted(axes):
            result = torch.unsqueeze(result, int(axis))
        return result

    return builder.call_function(_unsqueeze_dynamic, args=(x, axes))


@register("Flatten")
def flatten(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Flatten tensor to 2D.

    ONNX Flatten reshapes the input tensor to a 2D tensor:
    - First dimension = product of dimensions from 0 to axis-1
    - Second dimension = product of dimensions from axis to end
    """
    x = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", 1)

    def _flatten_to_2d(t, axis):
        shape = t.shape
        # Handle negative axis
        if axis < 0:
            axis = len(shape) + axis
        # Compute dimensions
        dim0 = 1
        for i in range(axis):
            dim0 *= shape[i]
        dim1 = 1
        for i in range(axis, len(shape)):
            dim1 *= shape[i]
        return t.reshape(dim0, dim1)

    return builder.call_function(_flatten_to_2d, args=(x, axis))


@register("Expand")
def expand(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Broadcast tensor to a new shape.

    ONNX Expand uses bidirectional broadcasting, which means:
    - If target dim is 1, keep the original dimension
    - The output shape is max(input_dim, target_dim) for each dimension
    """
    x = builder.get_value(node.input[0])
    shape = builder.get_value(node.input[1])

    def _expand(t, shape):
        if isinstance(shape, torch.Tensor):
            shape = tuple(int(s) for s in shape.tolist())
        # Use broadcast_shapes to compute the actual broadcast shape
        # This handles cases where target_dim=1 should preserve input_dim
        broadcast_shape = torch.broadcast_shapes(t.shape, shape)
        return t.expand(broadcast_shape)

    return builder.call_function(_expand, args=(x, shape))


# =============================================================================
# Concatenation and splitting operators
# =============================================================================


@register("Concat")
def concat(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Concatenate tensors along an axis."""
    inputs = [builder.get_value(name) for name in node.input]
    axis = get_attribute(node, "axis", 0)
    return builder.call_function(torch.cat, args=(inputs,), kwargs={"dim": axis})


@register("Split", since_version=1)
def split_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Split tensor into chunks for opset 1-12.

    In opset < 13, split sizes is an optional attribute.
    """
    x = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", 0)

    split_attr = get_attribute(node, "split")
    if split_attr is not None:
        result = builder.call_function(torch.split, args=(x, list(split_attr), axis))
    else:
        # Default: split into equal parts based on number of outputs
        result = builder.call_function(torch.chunk, args=(x, len(node.output), axis))

    # Handle multiple outputs
    for i, output_name in enumerate(node.output):
        if output_name:
            idx_node = builder.call_function(lambda t, idx: t[idx], args=(result, i))
            builder.env[output_name] = idx_node

    return result


@register("Split", since_version=13)
def split_v13(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Split tensor into chunks for opset 13+.

    In opset 13+, split sizes is an optional input.
    In opset 18+, num_outputs attribute was added.
    """
    x = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", 0)
    num_outputs = get_attribute(node, "num_outputs")  # Added in opset 18

    # split sizes is an optional input in opset 13+
    split_sizes = get_optional_input(builder, node, 1)
    if split_sizes is not None:

        def _split_with_sizes(t, sizes, dim):
            if hasattr(sizes, "tolist"):
                sizes = sizes.tolist()
            return torch.split(t, sizes, dim)

        result = builder.call_function(_split_with_sizes, args=(x, split_sizes, axis))
    elif num_outputs is not None:
        # Split into equal parts using num_outputs (opset 18+)
        result = builder.call_function(torch.chunk, args=(x, num_outputs, axis))
    else:
        # Default: split into equal parts based on number of outputs
        result = builder.call_function(torch.chunk, args=(x, len(node.output), axis))

    # Handle multiple outputs
    for i, output_name in enumerate(node.output):
        if output_name:
            idx_node = builder.call_function(lambda t, idx: t[idx], args=(result, i))
            builder.env[output_name] = idx_node

    return result


# =============================================================================
# Slicing and indexing operators
# =============================================================================


@register("Slice", since_version=1)
def slice_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Slice tensor along axes (opset 1-9).

    In opset < 10, starts, ends, and axes are attributes.
    """
    x = builder.get_value(node.input[0])
    starts = get_attribute(node, "starts")
    ends = get_attribute(node, "ends")
    axes = get_attribute(node, "axes")
    # Note: steps attribute doesn't exist in opset < 10

    return builder.call_function(
        _dynamic_slice,
        args=(x, list(starts), list(ends), list(axes) if axes else None, None),
    )


@register("Slice", since_version=10)
def slice_v10(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Slice tensor along axes (opset 10+).

    In opset 10+, starts, ends, axes, and steps are inputs.
    """
    x = builder.get_value(node.input[0])
    starts = builder.get_value(node.input[1])
    ends = builder.get_value(node.input[2])

    axes = get_optional_input(builder, node, 3)
    steps = get_optional_input(builder, node, 4)

    # Use torch.narrow for simple cases, or dynamic slicing
    return builder.call_function(
        _dynamic_slice,
        args=(x, starts, ends, axes, steps),
    )


def _dynamic_slice(x, starts, ends, axes=None, steps=None):
    """Helper function for dynamic slicing with support for negative steps."""
    import torch

    # Convert to lists if tensors
    if isinstance(starts, torch.Tensor):
        starts = starts.tolist()
    if isinstance(ends, torch.Tensor):
        ends = ends.tolist()
    if axes is not None and isinstance(axes, torch.Tensor):
        axes = axes.tolist()
    if steps is not None and isinstance(steps, torch.Tensor):
        steps = steps.tolist()

    if axes is None:
        axes = list(range(len(starts)))
    if steps is None:
        steps = [1] * len(starts)

    # Handle negative steps by flipping, slicing with positive step, then flipping back
    # We process each axis separately to handle this correctly
    result = x
    for start, end, axis, step in zip(starts, ends, axes, steps):
        dim_size = result.size(axis)

        if step < 0:
            # For negative steps, ONNX semantics:
            # start defaults to dim_size - 1, end defaults to -dim_size - 1
            # We iterate from start down to end (exclusive) with abs(step)

            # Handle special ONNX sentinel values and negative indices
            if start >= dim_size:
                start = dim_size - 1
            elif start < 0:
                start = max(-1, dim_size + start)

            if end < -dim_size:
                end = -1  # Sentinel for "before the beginning"
            elif end < 0:
                end = dim_size + end

            # For negative step: we go from start down to end (exclusive)
            # Example: start=20, end=0, step=-1 means indices [20, 19, ..., 1]
            # Flip the axis, compute equivalent positive slice, then flip back

            # Compute the actual range of elements we want
            # start > end for negative step, so we want indices from end+1 to start (inclusive)
            actual_start = end + 1 if end >= 0 else 0
            actual_end = start + 1 if start >= 0 else dim_size

            # Clamp to valid range
            actual_start = max(0, min(actual_start, dim_size))
            actual_end = max(0, min(actual_end, dim_size))

            if actual_start >= actual_end:
                # Empty slice
                slices = [slice(None)] * result.dim()
                slices[axis] = slice(0, 0)
                result = result[tuple(slices)]
            else:
                # First slice to get the range
                slices = [slice(None)] * result.dim()
                slices[axis] = slice(int(actual_start), int(actual_end))
                result = result[tuple(slices)]

                # Then flip to reverse the order
                result = torch.flip(result, dims=[axis])

                # Apply striding if step < -1
                if step < -1:
                    abs_step = -step
                    slices = [slice(None)] * result.dim()
                    slices[axis] = slice(None, None, int(abs_step))
                    result = result[tuple(slices)]
        else:
            # Positive step - original logic
            if start < 0:
                start = max(0, dim_size + start)
            if end < 0:
                end = max(0, dim_size + end)
            # Clamp to valid range
            start = min(start, dim_size)
            end = min(end, dim_size)
            slices = [slice(None)] * result.dim()
            slices[axis] = slice(int(start), int(end), int(step))
            result = result[tuple(slices)]

    return result


@register("Gather")
def gather(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Gather elements along an axis.

    ONNX Gather behavior:
    - output shape = data.shape[:axis] + indices.shape + data.shape[axis+1:]
    - If indices is a scalar, the axis dimension is removed from the output
    - If indices is a multi-dimensional tensor, indices.shape replaces the axis dimension
    """
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    axis = get_attribute(node, "axis", 0)

    def _gather(data, indices, axis):
        indices = indices.long()

        if axis < 0:
            axis = data.dim() + axis

        # Handle scalar indices - need to squeeze the dimension after gather
        if indices.ndim == 0:
            # Scalar index: select single element along axis, removing that dimension
            return torch.index_select(data, axis, indices.unsqueeze(0)).squeeze(axis)

        # For multi-dimensional indices, we need proper ONNX Gather semantics
        # Move the gather axis to position 0
        if axis != 0:
            data = data.movedim(axis, 0)

        # Flatten indices for indexing
        indices_flat = indices.flatten()
        gathered = data[indices_flat]  # [num_indices, ...]

        # Reshape to restore indices dimensions
        new_shape = list(indices.shape) + list(data.shape[1:])
        gathered = gathered.view(new_shape)

        # Move the original leading dimensions back
        if axis != 0:
            # Permute dimensions to restore original order
            # Current: [idx..., prefix..., suffix...]
            # Target: [prefix..., idx..., suffix...]
            num_idx_dims = indices.ndim
            num_prefix_dims = axis

            perm = (
                list(range(num_idx_dims, num_idx_dims + num_prefix_dims))
                + list(range(num_idx_dims))
                + list(range(num_idx_dims + num_prefix_dims, gathered.ndim))
            )
            gathered = gathered.permute(perm)

        return gathered

    return builder.call_function(_gather, args=(x, indices, axis))


@register("GatherElements")
def gather_elements(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Gather elements using indices with same rank as input."""
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    axis = get_attribute(node, "axis", 0)
    return builder.call_function(torch.gather, args=(x, axis, indices))


@register("GatherND")
def gather_nd(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Gather slices using n-dimensional indices."""
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    batch_dims = get_attribute(node, "batch_dims", 0)

    def _gather_nd(data, indices, batch_dims=0):
        # Simplified GatherND implementation
        indices = indices.long()
        if batch_dims == 0:
            # Flatten indices to list of coordinate tuples
            idx_shape = indices.shape
            indices_flat = indices.reshape(-1, idx_shape[-1])
            result = torch.stack([data[tuple(idx)] for idx in indices_flat])
            return result.reshape(idx_shape[:-1] + data.shape[indices.shape[-1] :])
        else:
            raise NotImplementedError("batch_dims > 0 not yet supported for GatherND")

    return builder.call_function(_gather_nd, args=(x, indices, batch_dims))


@register("ScatterElements")
def scatter_elements(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Scatter elements using indices."""
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    updates = builder.get_value(node.input[2])
    axis = get_attribute(node, "axis", 0)
    reduction = get_attribute(node, "reduction", "none")
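    # Illustrative: with axis=0, data=[1, 2], indices=[0, 0], updates=[5, 7] and
    # reduction="add", the result is [13, 2] (both updates accumulate into index 0).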

    def _scatter_elements(data, axis, idx, upd, reduction):
        # Handle negative axis
        if axis < 0:
            axis = data.ndim + axis

        # Handle negative indices by converting to positive
        dim_size = data.shape[axis]
        idx = torch.where(idx < 0, idx + dim_size, idx)

        # Map ONNX reduction to PyTorch reduce argument
        if reduction == "none":
            return data.scatter(axis, idx, upd)
        elif reduction == "add":
            return data.scatter_add(axis, idx, upd)
        elif reduction == "mul":
            return data.scatter_reduce(axis, idx, upd, reduce="prod")
        elif reduction == "max":
            return data.scatter_reduce(axis, idx, upd, reduce="amax")
        elif reduction == "min":
            return data.scatter_reduce(axis, idx, upd, reduce="amin")
        else:
            raise ValueError(f"Unsupported reduction: {reduction}")

    return builder.call_function(
        _scatter_elements, args=(x, axis, indices, updates, reduction)
    )


@register("Scatter")
def scatter(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Scatter (deprecated, replaced by ScatterElements in opset 11)."""
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    updates = builder.get_value(node.input[2])
    axis = get_attribute(node, "axis", 0)

    def _scatter(data, axis, idx, upd):
        # Handle negative axis
        if axis < 0:
            axis = data.ndim + axis

        # Handle negative indices by converting to positive
        dim_size = data.shape[axis]
        idx = torch.where(idx < 0, idx + dim_size, idx)

        return data.scatter(axis, idx, upd)

    return builder.call_function(_scatter, args=(x, axis, indices, updates))


# =============================================================================
# Tiling and padding operators
# =============================================================================


@register("Tile")
def tile(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Tile tensor by repeating."""
    x = builder.get_value(node.input[0])
    repeats = builder.get_value(node.input[1])

    def _tile(t, reps):
        if isinstance(reps, torch.Tensor):
            reps = tuple(int(r) for r in reps.tolist())
        return torch.tile(t, reps)

    return builder.call_function(_tile, args=(x, repeats))


def _pad_impl(x, pads, mode, constant_value):
    """Helper function for Pad operator.

    Converts ONNX pad format to PyTorch format and applies padding.
    ONNX: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
    PyTorch: [xn_begin, xn_end, ..., x1_begin, x1_end]
    """
    import torch
    import torch.nn.functional as F

    if isinstance(pads, torch.Tensor):
        pads = pads.tolist()

    n = len(pads) // 2
    # Reverse and interleave
    torch_pads = []
    for i in range(n - 1, -1, -1):
        torch_pads.extend([int(pads[i]), int(pads[i + n])])

    mode_map = {"constant": "constant", "reflect": "reflect", "edge": "replicate"}
    torch_mode = mode_map.get(mode, "constant")

    if torch_mode == "constant":
        return F.pad(x, torch_pads, mode=torch_mode, value=float(constant_value))

    # For non-constant modes (reflect, replicate), PyTorch only supports padding
    # the last N dimensions, so zero padding on the leading input dimensions is trimmed.
    # torch_pads is ordered as [last_dim_begin, last_dim_end, ..., first_dim_begin, first_dim_end],
    # so the pairs to drop are the trailing zero pairs of torch_pads.
    while len(torch_pads) > 2 and torch_pads[-1] == 0 and torch_pads[-2] == 0:
        torch_pads = torch_pads[:-2]

    return F.pad(x, torch_pads, mode=torch_mode)


@register("Pad", since_version=1)
def pad_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Pad tensor (opset 1-10).

    In opset < 11, pads and value are attributes.
    """
    x = builder.get_value(node.input[0])
    pads = list(get_attribute(node, "pads"))
    mode = get_attribute(node, "mode", "constant")
    constant_value = get_attribute(node, "value", 0.0)

    return builder.call_function(_pad_impl, args=(x, pads, mode, constant_value))


@register("Pad", since_version=11)
def pad_v11(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Pad tensor (opset 11+).

    In opset 11+, pads, constant_value, and axes are inputs.
    """
    x = builder.get_value(node.input[0])
    pads = builder.get_value(node.input[1])
    mode = get_attribute(node, "mode", "constant")

    constant_value = get_optional_input(builder, node, 2, default=0.0)

    # Note: axes input (opset 18+) is not yet supported
    # If needed, would require reordering pads based on axes

    return builder.call_function(_pad_impl, args=(x, pads, mode, constant_value))


# =============================================================================
# Shape operators
# =============================================================================


@register("Shape")
def shape(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Get tensor shape."""
    x = builder.get_value(node.input[0])
    start = get_attribute(node, "start", 0)
    end = get_attribute(node, "end")
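    # Illustrative: for a (2, 3, 4) input, start=1 with no end yields tensor([3, 4]).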

    def _get_shape(t, start, end):
        shape = torch.tensor(t.shape, dtype=torch.int64)
        if end is None:
            return shape[start:]
        return shape[start:end]

    return builder.call_function(_get_shape, args=(x, start, end))


@register("Size")
def size(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Get total number of elements."""
    x = builder.get_value(node.input[0])
    return builder.call_function(
        lambda t: torch.tensor(t.numel(), dtype=torch.int64), args=(x,)
    )


@register("ConstantOfShape")
def constant_of_shape(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Create tensor filled with constant value."""
    shape = builder.get_value(node.input[0])
    value = get_attribute(node, "value", tensor_loader=builder.load_tensor)

    if value is not None:
        fill_value = (
            value.item() if hasattr(value, "item") else float(value.flatten()[0])
        )
        dtype = value.dtype
    else:
        fill_value = 0.0
        dtype = torch.float32

    def _constant_of_shape(shape, fill_value, dtype):
        if isinstance(shape, torch.Tensor):
            shape = shape.tolist()
        return torch.full(shape, fill_value, dtype=dtype)

    return builder.call_function(_constant_of_shape, args=(shape, fill_value, dtype))


# =============================================================================
# Tensor generation operators
# =============================================================================


@register("Range")
def range_(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a range of values."""
    start = builder.get_value(node.input[0])
    limit = builder.get_value(node.input[1])
    delta = builder.get_value(node.input[2])

    def _range(start, limit, delta):
        # Extract scalar values
        st = start.item() if isinstance(start, torch.Tensor) else start
        lim = limit.item() if isinstance(limit, torch.Tensor) else limit
        dlt = delta.item() if isinstance(delta, torch.Tensor) else delta
        dtype = start.dtype if isinstance(start, torch.Tensor) else torch.float32
        return torch.arange(st, lim, dlt, dtype=dtype)

    return builder.call_function(_range, args=(start, limit, delta))


@register("OneHot")
def one_hot(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """One-hot encoding."""
    indices = builder.get_value(node.input[0])
    depth = builder.get_value(node.input[1])
    values = builder.get_value(node.input[2])

    axis = get_attribute(node, "axis", -1)
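    # Illustrative: indices=[1, 0], depth=3, values=[0, 1] (off, on) produce
    # [[0, 1, 0], [1, 0, 0]] along the default last axis.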

    def _one_hot(indices, depth, values, axis):
        d = depth.item() if isinstance(depth, torch.Tensor) else depth
        off_value = values[0]
        on_value = values[1]

        # Create one-hot tensor
        result = torch.nn.functional.one_hot(indices.long(), int(d))
        result = result.to(values.dtype)

        # Apply on/off values
        result = result * (on_value - off_value) + off_value

        # Move axis if needed
        if axis != -1 and axis != indices.dim():
            # Permute to move the one-hot dimension to the correct axis
            ndim = result.dim()
            if axis < 0:
                axis = ndim + axis
            perm = list(range(ndim - 1))
            perm.insert(axis, ndim - 1)
            result = result.permute(perm)

        return result

    return builder.call_function(_one_hot, args=(indices, depth, values, axis))


@register("NonZero")
def non_zero(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Find indices of non-zero elements."""
    x = builder.get_value(node.input[0])
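    # Illustrative: for x=[[1, 0], [0, 2]], ONNX NonZero returns [[0, 1], [0, 1]]
    # (shape (rank, num_nonzero)), i.e. the transpose of torch.nonzero's output.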

    def _non_zero(x):
        # ONNX returns shape (rank, num_nonzero), PyTorch returns tuple
        result = torch.nonzero(x, as_tuple=False).T
        return result.to(torch.int64)

    return builder.call_function(_non_zero, args=(x,))


@register("Unique")
def unique(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Find unique elements."""
    x = builder.get_value(node.input[0])

    axis = get_attribute(node, "axis")
    sorted_ = get_attribute(node, "sorted", 1)

    def _unique(x, axis, sorted_):
        if axis is not None:
            return torch.unique(
                x,
                sorted=bool(sorted_),
                return_inverse=True,
                return_counts=True,
                dim=axis,
            )
        return torch.unique(
            x, sorted=bool(sorted_), return_inverse=True, return_counts=True
        )

    return builder.call_function(_unique, args=(x, axis, sorted_))


@register("Trilu")
def trilu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Triangular part of matrix."""
    x = builder.get_value(node.input[0])

    k = get_optional_input(builder, node, 1, default=0)

    upper = get_attribute(node, "upper", 1)

    def _trilu(x, k, upper):
        k_val = k.item() if isinstance(k, torch.Tensor) else k
        if upper:
            return torch.triu(x, diagonal=int(k_val))
        return torch.tril(x, diagonal=int(k_val))

    return builder.call_function(_trilu, args=(x, k, upper))


@register("EyeLike")
def eye_like(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Create an identity matrix with the same shape as input.

    Note: The dtype attribute is ignored; output uses input tensor's dtype.
    """
    x = builder.get_value(node.input[0])
    k = get_attribute(node, "k", 0)

    def _eye_like(t: torch.Tensor, diag: int) -> torch.Tensor:
        n, m = t.shape[-2], t.shape[-1]
        if diag == 0:
            return torch.eye(n, m, dtype=t.dtype, device=t.device)
        # torch.eye has no diagonal offset, and torch.diagonal would extract a
        # 1-D diagonal rather than build a matrix, so place ones on the k-th
        # diagonal by comparing row and column indices.
        rows = torch.arange(n, device=t.device).unsqueeze(1)
        cols = torch.arange(m, device=t.device).unsqueeze(0)
        return (cols - rows == diag).to(t.dtype)

    return builder.call_function(_eye_like, args=(x, k))


# =============================================================================
# Scatter ND operators
# =============================================================================


@register("ScatterND")
def scatter_nd(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Scatter updates into data at indices."""
    data = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    updates = builder.get_value(node.input[2])

    reduction = get_attribute(node, "reduction", "none")
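    # Illustrative: data=[1, 2, 3, 4], indices=[[1], [3]], updates=[10, 20] give
    # [1, 10, 3, 20] with reduction="none" and [1, 12, 3, 24] with reduction="add".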

    def _scatter_nd_none(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = flat_upd[i]

        return output

    def _scatter_nd_add(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = output[data_idx] + flat_upd[i]

        return output

    def _scatter_nd_mul(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = output[data_idx] * flat_upd[i]

        return output

    def _scatter_nd_max(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = torch.maximum(output[data_idx], flat_upd[i])

        return output

    def _scatter_nd_min(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = torch.minimum(output[data_idx], flat_upd[i])

        return output

    if reduction == "add":
        return builder.call_function(_scatter_nd_add, args=(data, indices, updates))
    elif reduction == "mul":
        return builder.call_function(_scatter_nd_mul, args=(data, indices, updates))
    elif reduction == "max":
        return builder.call_function(_scatter_nd_max, args=(data, indices, updates))
    elif reduction == "min":
        return builder.call_function(_scatter_nd_min, args=(data, indices, updates))
    else:
        return builder.call_function(_scatter_nd_none, args=(data, indices, updates))


# =============================================================================
# Select and Compress operators
# =============================================================================


@register("Select")
def select_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Select elements based on indices (like advanced indexing)."""
    data = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])

    return builder.call_function(torch.index_select, args=(data, 0, indices))


@register("Compress")
def compress_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Select elements based on a boolean condition tensor."""
    data = builder.get_value(node.input[0])
    condition = builder.get_value(node.input[1])

    axis = get_attribute(node, "axis", None)

    if axis is not None:

        def _compress_axis(d: torch.Tensor, c: torch.Tensor, ax: int) -> torch.Tensor:
            # Get indices where condition is True
            indices = torch.nonzero(c, as_tuple=True)[0]
            return torch.index_select(d, ax, indices)

        return builder.call_function(_compress_axis, args=(data, condition, axis))
    else:
        # Flatten and compress
        def _compress_flat(d: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
            return d.flatten()[c.flatten().bool()]

        return builder.call_function(_compress_flat, args=(data, condition))


# =============================================================================
# TensorScatter operator (for KV cache updates in LLMs)
# =============================================================================


@register("TensorScatter", since_version=24)
def tensor_scatter(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """TensorScatter for KV cache updates.

    Updates a cache tensor at specified indices along a given axis.
    Commonly used for key/value cache updates in LLM attention.

    Inputs:
        past_cache: Cache tensor (batch_size, D1, ..., max_sequence_length, ..., Dn)
        update: Update tensor (batch_size, D1, ..., sequence_length, ..., Dn)
        write_indices (optional): Start indices per batch sample (batch_size,)

    Attributes:
        axis: Sequence dimension (default -2)
        mode: 'linear' or 'circular' (default 'linear')
    """
    past_cache = builder.get_value(node.input[0])
    update = builder.get_value(node.input[1])
    write_indices = get_optional_input(builder, node, 2)

    axis = get_attribute(node, "axis", -2)
    mode = get_attribute(node, "mode", "linear")

    def _tensor_scatter(
        cache: torch.Tensor,
        upd: torch.Tensor,
        write_idx: torch.Tensor | None,
        ax: int,
        scatter_mode: str,
    ) -> torch.Tensor:
        output = cache.clone()

        # Handle negative axis
        if ax < 0:
            ax = cache.ndim + ax

        batch_size = cache.shape[0]
        max_seq_len = cache.shape[ax]
        seq_len = upd.shape[ax]

        # Default write_indices to zeros if not provided
        if write_idx is None:
            write_idx = torch.zeros(batch_size, dtype=torch.int64, device=cache.device)

        # For each batch element, copy the update into the cache at the specified position
        for b in range(batch_size):
            start_idx = int(write_idx[b].item())

            for s in range(seq_len):
                if scatter_mode == "circular":
                    cache_idx = (start_idx + s) % max_seq_len
                else:
                    cache_idx = start_idx + s

                # Build the index tuple for the cache and update tensors
                # For cache: (b, D1, ..., cache_idx, ..., Dn)
                # For update: (b, D1, ..., s, ..., Dn)
                cache_slices = [b] + [slice(None)] * (ax - 1) + [cache_idx]
                update_slices = [b] + [slice(None)] * (ax - 1) + [s]

                output[tuple(cache_slices)] = upd[tuple(update_slices)]

        return output

    return builder.call_function(
        _tensor_scatter, args=(past_cache, update, write_indices, axis, mode)
    )