onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,524 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Quantization operators."""
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ import onnx
7
+ import torch
8
+
9
+ from ..op_registry import register
10
+ from ..utils.attributes import get_attribute
11
+ from ..utils.op_helpers import get_optional_input
12
+
13
+ if TYPE_CHECKING:
14
+ from ..graph_builder import GraphBuilder
15
+
16
+
17
+ # =============================================================================
18
+ # Basic quantization operators
19
+ # =============================================================================
20
+
21
+
22
@register("QuantizeLinear")
def quantize_linear(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Quantize a float tensor: ``y = saturate(round(x / y_scale) + y_zero_point)``.

    The output integer dtype follows ``y_zero_point``. When the zero point is
    omitted, ONNX specifies a uint8 output with an implicit zero point of 0.
    A 1-D ``y_scale`` / ``y_zero_point`` selects per-axis quantization along
    the ``axis`` attribute (default 1).

    NOTE(review): the opset-21 ``output_dtype`` / ``block_size`` / ``saturate``
    attributes and float8/int4 zero points are not handled here.
    """
    x = builder.get_value(node.input[0])
    y_scale = builder.get_value(node.input[1])
    y_zero_point = get_optional_input(builder, node, 2)
    axis = get_attribute(node, "axis", 1)

    def _quantize(
        inp: torch.Tensor, s: torch.Tensor, zp: torch.Tensor | None, ax: int
    ) -> torch.Tensor:
        def _along_axis(t: torch.Tensor) -> torch.Tensor:
            # A 1-D per-axis parameter must line up with dimension `ax` of the
            # input instead of broadcasting against the last dimension.
            if t.dim() != 1 or t.numel() == 1 or inp.dim() == 0:
                return t
            shape = [1] * inp.dim()
            shape[ax] = -1
            return t.reshape(shape)

        # torch.round rounds half to even, matching the ONNX definition.
        q = torch.round(inp / _along_axis(s))
        if zp is None:
            out_dtype = torch.uint8
        else:
            out_dtype = zp.dtype
            q = q + _along_axis(zp).float()
        info = torch.iinfo(out_dtype)
        return torch.clamp(q, info.min, info.max).to(out_dtype)

    return builder.call_function(_quantize, args=(x, y_scale, y_zero_point, axis))
45
+
46
+
47
@register("DequantizeLinear")
def dequantize_linear(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Dequantize an integer tensor: ``y = (x - x_zero_point) * x_scale``.

    A missing zero point is treated as 0. A 1-D ``x_scale`` / ``x_zero_point``
    selects per-axis dequantization along the ``axis`` attribute (default 1).
    The result is cast to ``x_scale``'s dtype, which ONNX defines as the
    output type.
    """
    x = builder.get_value(node.input[0])
    x_scale = builder.get_value(node.input[1])
    x_zero_point = get_optional_input(builder, node, 2)
    axis = get_attribute(node, "axis", 1)

    def _dequantize(
        inp: torch.Tensor, s: torch.Tensor, zp: torch.Tensor | None, ax: int
    ) -> torch.Tensor:
        def _along_axis(t: torch.Tensor) -> torch.Tensor:
            # Align a 1-D per-axis parameter with dimension `ax` of the input.
            if t.dim() != 1 or t.numel() == 1 or inp.dim() == 0:
                return t
            shape = [1] * inp.dim()
            shape[ax] = -1
            return t.reshape(shape)

        shifted = inp.float()
        if zp is not None:
            shifted = shifted - _along_axis(zp).float()
        return (shifted * _along_axis(s)).to(s.dtype)

    return builder.call_function(
        _dequantize, args=(x, x_scale, x_zero_point, axis)
    )
68
+
69
+
70
@register("DynamicQuantizeLinear")
def dynamic_quantize_linear(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Dynamic quantization of input tensor to uint8.

    Returns tuple of (y, y_scale, y_zero_point). As in the ONNX definition,
    the quantization range is ``[min(0, min(x)), max(0, max(x))]`` so that 0.0
    is exactly representable. A degenerate (zero-width) range uses a width of
    1 to avoid dividing by zero.
    """
    x = builder.get_value(node.input[0])

    def _dynamic_quantize(
        inp: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        zero = inp.new_zeros(())
        x_min = torch.minimum(inp.min(), zero)
        x_max = torch.maximum(inp.max(), zero)
        span = x_max - x_min
        # e.g. an all-zero input: keep the scale finite instead of producing NaN.
        span = torch.where(span == 0, torch.ones_like(span), span)
        scale = span / 255.0
        zero_point = torch.clamp(torch.round(-x_min / scale), 0, 255)
        y = torch.clamp(torch.round(inp / scale) + zero_point, 0, 255)
        return y.to(torch.uint8), scale, zero_point.to(torch.uint8)

    return builder.call_function(_dynamic_quantize, args=(x,))
93
+
94
+
95
+ # =============================================================================
96
+ # QLinear operators
97
+ # =============================================================================
98
+
99
+
100
@register("QLinearMatMul")
def qlinear_matmul(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Quantized MatMul with scales and zero points.

    Inputs: a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale,
    y_zero_point.

    A 1-D ``a_scale`` / ``a_zero_point`` is per-row quantization of ``a``, so
    it is aligned with the row dimension (-2). A 1-D ``b`` parameter is
    per-column and already broadcasts along the last dimension. The output
    dtype follows ``y_zero_point`` (uint8 or int8).
    """
    a = builder.get_value(node.input[0])
    a_scale = builder.get_value(node.input[1])
    a_zero_point = builder.get_value(node.input[2])
    b = builder.get_value(node.input[3])
    b_scale = builder.get_value(node.input[4])
    b_zero_point = builder.get_value(node.input[5])
    y_scale = builder.get_value(node.input[6])
    y_zero_point = builder.get_value(node.input[7])

    def _qlinear_matmul(
        a: torch.Tensor,
        a_s: torch.Tensor,
        a_zp: torch.Tensor,
        b: torch.Tensor,
        b_s: torch.Tensor,
        b_zp: torch.Tensor,
        y_s: torch.Tensor,
        y_zp: torch.Tensor,
    ) -> torch.Tensor:
        def _per_row(t: torch.Tensor) -> torch.Tensor:
            # (M,) -> (M, 1) so it scales rows of `a` rather than its K columns.
            return t.unsqueeze(-1) if t.dim() == 1 and t.numel() > 1 else t

        # Dequantize
        a_dq = (a.float() - _per_row(a_zp).float()) * _per_row(a_s)
        b_dq = (b.float() - b_zp.float()) * b_s
        # MatMul
        result = torch.matmul(a_dq, b_dq)
        # Quantize output into the range of y_zero_point's dtype
        info = torch.iinfo(y_zp.dtype)
        q = torch.round(result / y_s) + y_zp.float()
        return torch.clamp(q, info.min, info.max).to(y_zp.dtype)

    return builder.call_function(
        _qlinear_matmul,
        args=(
            a,
            a_scale,
            a_zero_point,
            b,
            b_scale,
            b_zero_point,
            y_scale,
            y_zero_point,
        ),
    )
145
+
146
+
147
@register("QLinearConv")
def qlinear_conv(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Quantized 1-D/2-D/3-D convolution with scales and zero points.

    Inputs: x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, y_zero_point, [B]

    The convolution runs in the dequantized float domain:

    * ``w_scale`` / ``w_zero_point`` may be 1-D (one value per output channel)
      and are aligned with dim 0 of ``w``.
    * ``B`` is int32 quantized with scale ``x_scale * w_scale`` and zero point
      0, so it is dequantized before being added.
    * ``auto_pad`` (SAME_UPPER / SAME_LOWER / VALID) and asymmetric ``pads``
      are honoured. Zero-padding the dequantized input is equivalent to
      padding ``x`` with ``x_zero_point``.
    * The output dtype follows ``y_zero_point`` (uint8 or int8).
    """
    x = builder.get_value(node.input[0])
    x_scale = builder.get_value(node.input[1])
    x_zero_point = builder.get_value(node.input[2])
    w = builder.get_value(node.input[3])
    w_scale = builder.get_value(node.input[4])
    w_zero_point = builder.get_value(node.input[5])
    y_scale = builder.get_value(node.input[6])
    y_zero_point = builder.get_value(node.input[7])
    bias = get_optional_input(builder, node, 8)

    # Note: kernel_shape is inferred from the weight tensor. Defaults for the
    # per-dimension attributes depend on the spatial rank, which is only known
    # from `w` at run time, so absent attributes are passed as None.
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")
    if isinstance(auto_pad, bytes):
        auto_pad = auto_pad.decode()
    dilations = get_attribute(node, "dilations", None)
    group = get_attribute(node, "group", 1)
    pads = get_attribute(node, "pads", None)
    strides = get_attribute(node, "strides", None)

    def _qlinear_conv(
        x: torch.Tensor,
        x_s: torch.Tensor,
        x_zp: torch.Tensor,
        w: torch.Tensor,
        w_s: torch.Tensor,
        w_zp: torch.Tensor,
        y_s: torch.Tensor,
        y_zp: torch.Tensor,
        bias: torch.Tensor | None,
        strides: list | None,
        pads: list | None,
        dilations: list | None,
        groups: int,
        auto_pad: str,
    ) -> torch.Tensor:
        functional = torch.nn.functional
        convs = {1: functional.conv1d, 2: functional.conv2d, 3: functional.conv3d}
        n_spatial = w.dim() - 2
        if n_spatial not in convs:
            raise NotImplementedError(
                f"QLinearConv: unsupported weight rank {w.dim()}"
            )
        stride = list(strides) if strides else [1] * n_spatial
        dilation = list(dilations) if dilations else [1] * n_spatial

        # Per-output-channel (1-D) weight parameters broadcast along dim 0 of w.
        per_channel = (-1,) + (1,) * (w.dim() - 1)
        if w_s.dim() == 1:
            w_s = w_s.reshape(per_channel)
        if w_zp.dim() == 1:
            w_zp = w_zp.reshape(per_channel)

        # Dequantize input and weight
        x_dq = (x.float() - x_zp.float()) * x_s
        w_dq = (w.float() - w_zp.float()) * w_s

        # Bias is int32 with scale x_scale * w_scale (per channel) and zp 0.
        bias_dq = None
        if bias is not None:
            bias_dq = bias.float() * (x_s * w_s).reshape(-1)

        # Resolve begin/end padding for every spatial dimension.
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            begins, ends = [], []
            for size, k, s, d in zip(x.shape[2:], w.shape[2:], stride, dilation):
                out_size = -(-size // s)  # ceil(size / s)
                total = max((out_size - 1) * s + (k - 1) * d + 1 - size, 0)
                small, large = total // 2, total - total // 2
                # The odd extra pixel goes at the end (UPPER) or start (LOWER).
                if auto_pad == "SAME_UPPER":
                    begins.append(small)
                    ends.append(large)
                else:
                    begins.append(large)
                    ends.append(small)
        elif auto_pad == "VALID":
            begins = ends = [0] * n_spatial
        else:
            flat = list(pads) if pads else [0] * (2 * n_spatial)
            begins, ends = flat[:n_spatial], flat[n_spatial:]

        if begins == ends:
            padding = begins
        else:
            # F.pad takes (begin, end) pairs starting from the LAST dimension.
            pad_arg = []
            for b, e in zip(reversed(begins), reversed(ends)):
                pad_arg.extend((b, e))
            x_dq = functional.pad(x_dq, pad_arg)
            padding = [0] * n_spatial

        # Perform convolution
        result = convs[n_spatial](
            x_dq,
            w_dq,
            bias=bias_dq,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

        # Quantize output into the range of y_zero_point's dtype
        info = torch.iinfo(y_zp.dtype)
        q = torch.round(result / y_s) + y_zp.float()
        return torch.clamp(q, info.min, info.max).to(y_zp.dtype)

    return builder.call_function(
        _qlinear_conv,
        args=(
            x,
            x_scale,
            x_zero_point,
            w,
            w_scale,
            w_zero_point,
            y_scale,
            y_zero_point,
            bias,
            list(strides) if strides is not None else None,
            list(pads) if pads is not None else None,
            list(dilations) if dilations is not None else None,
            group,
            auto_pad,
        ),
    )
234
+
235
+
236
@register("QLinearAdd")
def qlinear_add(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Quantized addition with scales and zero points (com.microsoft domain).

    Inputs: A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point

    The zero points are optional (a missing one means 0). The output has the
    same 8-bit type as ``A`` (uint8 or int8) and is saturated to that range.
    """
    a = builder.get_value(node.input[0])
    a_scale = builder.get_value(node.input[1])
    a_zero_point = get_optional_input(builder, node, 2)
    b = builder.get_value(node.input[3])
    b_scale = builder.get_value(node.input[4])
    b_zero_point = get_optional_input(builder, node, 5)
    c_scale = builder.get_value(node.input[6])
    c_zero_point = get_optional_input(builder, node, 7)

    def _qlinear_add(
        a: torch.Tensor,
        a_s: torch.Tensor,
        a_zp: torch.Tensor | None,
        b: torch.Tensor,
        b_s: torch.Tensor,
        b_zp: torch.Tensor | None,
        c_s: torch.Tensor,
        c_zp: torch.Tensor | None,
    ) -> torch.Tensor:
        def _dequant(q: torch.Tensor, s: torch.Tensor, zp: torch.Tensor | None):
            shifted = q.float() if zp is None else q.float() - zp.float()
            return shifted * s

        # Add in the float domain
        result = _dequant(a, a_s, a_zp) + _dequant(b, b_s, b_zp)
        # Quantize output
        q = torch.round(result / c_s)
        if c_zp is not None:
            q = q + c_zp.float()
        info = torch.iinfo(a.dtype)
        return torch.clamp(q, info.min, info.max).to(a.dtype)

    return builder.call_function(
        _qlinear_add,
        args=(
            a,
            a_scale,
            a_zero_point,
            b,
            b_scale,
            b_zero_point,
            c_scale,
            c_zero_point,
        ),
    )
284
+
285
+
286
@register("QLinearMul")
def qlinear_mul(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Quantized multiplication with scales and zero points (com.microsoft domain).

    Inputs: A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point

    The zero points are optional (a missing one means 0). The output has the
    same 8-bit type as ``A`` (uint8 or int8) and is saturated to that range.
    """
    a = builder.get_value(node.input[0])
    a_scale = builder.get_value(node.input[1])
    a_zero_point = get_optional_input(builder, node, 2)
    b = builder.get_value(node.input[3])
    b_scale = builder.get_value(node.input[4])
    b_zero_point = get_optional_input(builder, node, 5)
    c_scale = builder.get_value(node.input[6])
    c_zero_point = get_optional_input(builder, node, 7)

    def _qlinear_mul(
        a: torch.Tensor,
        a_s: torch.Tensor,
        a_zp: torch.Tensor | None,
        b: torch.Tensor,
        b_s: torch.Tensor,
        b_zp: torch.Tensor | None,
        c_s: torch.Tensor,
        c_zp: torch.Tensor | None,
    ) -> torch.Tensor:
        def _dequant(q: torch.Tensor, s: torch.Tensor, zp: torch.Tensor | None):
            shifted = q.float() if zp is None else q.float() - zp.float()
            return shifted * s

        # Multiply in the float domain
        result = _dequant(a, a_s, a_zp) * _dequant(b, b_s, b_zp)
        # Quantize output
        q = torch.round(result / c_s)
        if c_zp is not None:
            q = q + c_zp.float()
        info = torch.iinfo(a.dtype)
        return torch.clamp(q, info.min, info.max).to(a.dtype)

    return builder.call_function(
        _qlinear_mul,
        args=(
            a,
            a_scale,
            a_zero_point,
            b,
            b_scale,
            b_zero_point,
            c_scale,
            c_zero_point,
        ),
    )
331
+
332
+
333
@register("QLinearSigmoid")
def qlinear_sigmoid(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Quantized sigmoid (com.microsoft domain): dequantize, sigmoid, requantize.

    Inputs: X, X_scale, X_zero_point (optional), Y_scale, Y_zero_point (optional).
    A missing zero point means 0. The output has the same 8-bit type as ``X``.
    """
    x = builder.get_value(node.input[0])
    x_scale = builder.get_value(node.input[1])
    x_zero_point = get_optional_input(builder, node, 2)
    y_scale = builder.get_value(node.input[3])
    y_zero_point = get_optional_input(builder, node, 4)

    def _qlinear_sigmoid(
        x: torch.Tensor,
        x_s: torch.Tensor,
        x_zp: torch.Tensor | None,
        y_s: torch.Tensor,
        y_zp: torch.Tensor | None,
    ) -> torch.Tensor:
        # Dequantize
        shifted = x.float() if x_zp is None else x.float() - x_zp.float()
        # Sigmoid
        result = torch.sigmoid(shifted * x_s)
        # Quantize output into the range of X's dtype
        q = torch.round(result / y_s)
        if y_zp is not None:
            q = q + y_zp.float()
        info = torch.iinfo(x.dtype)
        return torch.clamp(q, info.min, info.max).to(x.dtype)

    return builder.call_function(
        _qlinear_sigmoid,
        args=(x, x_scale, x_zero_point, y_scale, y_zero_point),
    )
362
+
363
+
364
@register("QLinearLeakyRelu")
def qlinear_leaky_relu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Quantized Leaky ReLU (com.microsoft domain).

    Inputs: X, X_scale, X_zero_point (optional), Y_scale, Y_zero_point (optional).
    Attribute ``alpha`` is the negative slope (default 0.01 here). A missing
    zero point means 0. The output has the same 8-bit type as ``X``.
    """
    x = builder.get_value(node.input[0])
    x_scale = builder.get_value(node.input[1])
    x_zero_point = get_optional_input(builder, node, 2)
    y_scale = builder.get_value(node.input[3])
    y_zero_point = get_optional_input(builder, node, 4)
    alpha = get_attribute(node, "alpha", 0.01)

    def _qlinear_leaky_relu(
        x: torch.Tensor,
        x_s: torch.Tensor,
        x_zp: torch.Tensor | None,
        y_s: torch.Tensor,
        y_zp: torch.Tensor | None,
        alpha: float,
    ) -> torch.Tensor:
        # Dequantize
        shifted = x.float() if x_zp is None else x.float() - x_zp.float()
        # LeakyReLU
        result = torch.nn.functional.leaky_relu(shifted * x_s, negative_slope=alpha)
        # Quantize output into the range of X's dtype
        q = torch.round(result / y_s)
        if y_zp is not None:
            q = q + y_zp.float()
        info = torch.iinfo(x.dtype)
        return torch.clamp(q, info.min, info.max).to(x.dtype)

    return builder.call_function(
        _qlinear_leaky_relu,
        args=(x, x_scale, x_zero_point, y_scale, y_zero_point, alpha),
    )
395
+
396
+
397
@register("QLinearGlobalAveragePool")
def qlinear_global_avg_pool(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Quantized Global Average Pooling (com.microsoft domain).

    Averages every spatial dimension while keeping them as size-1 dims. The
    channel axis is 1 (NC...) by default, or last when the ``channels_last``
    attribute is non-zero. The output dtype follows ``y_zero_point``.
    """
    x = builder.get_value(node.input[0])
    x_scale = builder.get_value(node.input[1])
    x_zero_point = builder.get_value(node.input[2])
    y_scale = builder.get_value(node.input[3])
    y_zero_point = builder.get_value(node.input[4])
    channels_last = get_attribute(node, "channels_last", 0)

    def _qlinear_global_avg_pool(
        x: torch.Tensor,
        x_s: torch.Tensor,
        x_zp: torch.Tensor,
        y_s: torch.Tensor,
        y_zp: torch.Tensor,
        ch_last: int,
    ) -> torch.Tensor:
        # Dequantize
        x_dq = (x.float() - x_zp.float()) * x_s
        # Spatial dims sit between batch and channel (NHWC) or after channel (NCHW).
        dims = tuple(range(1, x.dim() - 1)) if ch_last else tuple(range(2, x.dim()))
        # mean(dim=()) would reduce everything, so skip when there is no spatial dim.
        result = x_dq.mean(dim=dims, keepdim=True) if dims else x_dq
        # Quantize output into the range of y_zero_point's dtype
        info = torch.iinfo(y_zp.dtype)
        q = torch.round(result / y_s) + y_zp.float()
        return torch.clamp(q, info.min, info.max).to(y_zp.dtype)

    return builder.call_function(
        _qlinear_global_avg_pool,
        args=(x, x_scale, x_zero_point, y_scale, y_zero_point, channels_last),
    )
428
+
429
+
430
+ # =============================================================================
431
+ # Integer arithmetic operators
432
+ # =============================================================================
433
+
434
+
435
@register("ConvInteger")
def conv_integer(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Integer 1-D/2-D/3-D convolution (returns int32).

    Zero points are subtracted before convolving. ``x_zero_point`` is a
    scalar; ``w_zero_point`` may be 1-D (one value per output channel) and is
    aligned with dim 0 of ``w``. ``auto_pad`` (SAME_UPPER / SAME_LOWER / VALID)
    and asymmetric ``pads`` are honoured. Zero-padding after the zero point is
    subtracted is equivalent to padding ``x`` with ``x_zero_point``.

    The convolution runs in float64, whose 53-bit mantissa represents int32
    accumulations exactly (float32 would round sums above 2**24).
    """
    x = builder.get_value(node.input[0])
    w = builder.get_value(node.input[1])
    x_zero_point = get_optional_input(builder, node, 2)
    w_zero_point = get_optional_input(builder, node, 3)

    # Per-dimension defaults depend on the spatial rank, which is only known
    # from `w` at run time, so absent attributes are passed as None.
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")
    if isinstance(auto_pad, bytes):
        auto_pad = auto_pad.decode()
    dilations = get_attribute(node, "dilations", None)
    group = get_attribute(node, "group", 1)
    pads = get_attribute(node, "pads", None)
    strides = get_attribute(node, "strides", None)

    def _conv_integer(
        x: torch.Tensor,
        w: torch.Tensor,
        x_zp: torch.Tensor | None,
        w_zp: torch.Tensor | None,
        strides: list | None,
        pads: list | None,
        dilations: list | None,
        groups: int,
        auto_pad: str,
    ) -> torch.Tensor:
        functional = torch.nn.functional
        convs = {1: functional.conv1d, 2: functional.conv2d, 3: functional.conv3d}
        n_spatial = w.dim() - 2
        if n_spatial not in convs:
            raise NotImplementedError(
                f"ConvInteger: unsupported weight rank {w.dim()}"
            )
        stride = list(strides) if strides else [1] * n_spatial
        dilation = list(dilations) if dilations else [1] * n_spatial

        # Subtract zero points
        x_int = x.int()
        w_int = w.int()
        if x_zp is not None:
            x_int = x_int - x_zp.int()
        if w_zp is not None:
            if w_zp.dim() == 1:
                # Per-output-channel zero point broadcasts along dim 0 of w.
                w_zp = w_zp.reshape((-1,) + (1,) * (w.dim() - 1))
            w_int = w_int - w_zp.int()

        # Resolve begin/end padding for every spatial dimension.
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            begins, ends = [], []
            for size, k, s, d in zip(x.shape[2:], w.shape[2:], stride, dilation):
                out_size = -(-size // s)  # ceil(size / s)
                total = max((out_size - 1) * s + (k - 1) * d + 1 - size, 0)
                small, large = total // 2, total - total // 2
                if auto_pad == "SAME_UPPER":
                    begins.append(small)
                    ends.append(large)
                else:
                    begins.append(large)
                    ends.append(small)
        elif auto_pad == "VALID":
            begins = ends = [0] * n_spatial
        else:
            flat = list(pads) if pads else [0] * (2 * n_spatial)
            begins, ends = flat[:n_spatial], flat[n_spatial:]

        x_acc = x_int.double()
        if begins == ends:
            padding = begins
        else:
            # F.pad takes (begin, end) pairs starting from the LAST dimension.
            pad_arg = []
            for b, e in zip(reversed(begins), reversed(ends)):
                pad_arg.extend((b, e))
            x_acc = functional.pad(x_acc, pad_arg)
            padding = [0] * n_spatial

        # Convolve in float64 (PyTorch has no integer conv), exact for int32.
        result = convs[n_spatial](
            x_acc,
            w_int.double(),
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        return result.to(torch.int32)

    return builder.call_function(
        _conv_integer,
        args=(
            x,
            w,
            x_zero_point,
            w_zero_point,
            list(strides) if strides is not None else None,
            list(pads) if pads is not None else None,
            list(dilations) if dilations is not None else None,
            group,
            auto_pad,
        ),
    )
497
+
498
+
499
@register("MatMulInteger")
def matmul_integer(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Integer matrix multiplication (returns int32).

    Zero points are optional. A 1-D ``a_zero_point`` is per-row (one value
    per row of ``a``) and is aligned with dim -2. A 1-D ``b_zero_point`` is
    per-column and already broadcasts along the last dim. The product is
    computed in float64, which holds int32 accumulations exactly.
    """
    a = builder.get_value(node.input[0])
    b = builder.get_value(node.input[1])
    a_zero_point = get_optional_input(builder, node, 2)
    b_zero_point = get_optional_input(builder, node, 3)

    def _matmul_integer(
        a: torch.Tensor,
        b: torch.Tensor,
        a_zp: torch.Tensor | None,
        b_zp: torch.Tensor | None,
    ) -> torch.Tensor:
        a_int = a.int()
        b_int = b.int()
        if a_zp is not None:
            if a_zp.dim() == 1 and a_zp.numel() > 1:
                # (M,) -> (M, 1): subtract per row, not per K column.
                a_zp = a_zp.unsqueeze(-1)
            a_int = a_int - a_zp.int()
        if b_zp is not None:
            b_int = b_int - b_zp.int()
        # float32 has a 24-bit mantissa and would round large sums.
        return torch.matmul(a_int.double(), b_int.double()).to(torch.int32)

    return builder.call_function(
        _matmul_integer,
        args=(a, b, a_zero_point, b_zero_point),
    )
onnx2fx/ops/random.py ADDED
@@ -0,0 +1,102 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Random number generation operators.
3
+
4
+ This module implements ONNX operators for generating random tensors,
5
+ including normal and uniform distributions.
6
+
7
+ Note: Window functions (HannWindow, HammingWindow, BlackmanWindow) have been
8
+ moved to signal.py as they are used for signal processing.
9
+ """
10
+
11
+ from typing import TYPE_CHECKING
12
+
13
+ import onnx
14
+ import torch
15
+
16
+ from ..op_registry import register
17
+ from ..utils.attributes import get_attribute
18
+
19
+ if TYPE_CHECKING:
20
+ from ..graph_builder import GraphBuilder
21
+
22
+
23
+ # =============================================================================
24
+ # Random number generation operators
25
+ # =============================================================================
26
+
27
+
28
@register("RandomNormal")
def random_normal(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a tensor of the given ``shape`` drawn from N(mean, scale**2).

    The output dtype comes from the ``dtype`` attribute (default FLOAT).

    Note: The seed attribute is not supported; use torch.manual_seed() instead.
    """
    mean = get_attribute(node, "mean", 0.0)
    scale = get_attribute(node, "scale", 1.0)
    shape = get_attribute(node, "shape")
    onnx_dtype = get_attribute(node, "dtype", onnx.TensorProto.FLOAT)
    float_dtypes = {
        onnx.TensorProto.FLOAT16: torch.float16,
        onnx.TensorProto.FLOAT: torch.float32,
        onnx.TensorProto.DOUBLE: torch.float64,
        onnx.TensorProto.BFLOAT16: torch.bfloat16,
    }
    if onnx_dtype not in float_dtypes:
        raise NotImplementedError(f"RandomNormal: unsupported dtype {onnx_dtype}")

    def _random_normal(m: float, s: float, sh: list, dt: torch.dtype) -> torch.Tensor:
        return torch.randn(sh, dtype=dt) * s + m

    return builder.call_function(
        _random_normal, args=(mean, scale, list(shape), float_dtypes[onnx_dtype])
    )
42
+
43
+
44
@register("RandomNormalLike")
def random_normal_like(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate N(mean, scale**2) values with the shape of the input tensor.

    The output dtype comes from the ``dtype`` attribute when it is set;
    otherwise the input's dtype is used. The seed attribute is not supported.
    """
    x = builder.get_value(node.input[0])

    mean = get_attribute(node, "mean", 0.0)
    scale = get_attribute(node, "scale", 1.0)
    onnx_dtype = get_attribute(node, "dtype", None)
    float_dtypes = {
        onnx.TensorProto.FLOAT16: torch.float16,
        onnx.TensorProto.FLOAT: torch.float32,
        onnx.TensorProto.DOUBLE: torch.float64,
        onnx.TensorProto.BFLOAT16: torch.bfloat16,
    }
    if onnx_dtype is not None and onnx_dtype not in float_dtypes:
        raise NotImplementedError(
            f"RandomNormalLike: unsupported dtype {onnx_dtype}"
        )
    dtype = None if onnx_dtype is None else float_dtypes[onnx_dtype]

    def _random_normal_like(
        t: torch.Tensor, m: float, s: float, dt: torch.dtype | None
    ) -> torch.Tensor:
        # dtype=None makes randn_like keep t's dtype.
        return torch.randn_like(t, dtype=dt) * s + m

    return builder.call_function(_random_normal_like, args=(x, mean, scale, dtype))
56
+
57
+
58
@register("RandomUniform")
def random_uniform(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a tensor of the given ``shape`` drawn from U[low, high).

    The output dtype comes from the ``dtype`` attribute (default FLOAT).
    The seed attribute is not supported.
    """
    low = get_attribute(node, "low", 0.0)
    high = get_attribute(node, "high", 1.0)
    shape = get_attribute(node, "shape")
    onnx_dtype = get_attribute(node, "dtype", onnx.TensorProto.FLOAT)
    float_dtypes = {
        onnx.TensorProto.FLOAT16: torch.float16,
        onnx.TensorProto.FLOAT: torch.float32,
        onnx.TensorProto.DOUBLE: torch.float64,
        onnx.TensorProto.BFLOAT16: torch.bfloat16,
    }
    if onnx_dtype not in float_dtypes:
        raise NotImplementedError(f"RandomUniform: unsupported dtype {onnx_dtype}")

    def _random_uniform(lo: float, hi: float, sh: list, dt: torch.dtype) -> torch.Tensor:
        return torch.rand(sh, dtype=dt) * (hi - lo) + lo

    return builder.call_function(
        _random_uniform, args=(low, high, list(shape), float_dtypes[onnx_dtype])
    )
69
+
70
+
71
@register("RandomUniformLike")
def random_uniform_like(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate U[low, high) values with the shape of the input tensor.

    The output dtype comes from the ``dtype`` attribute when it is set;
    otherwise the input's dtype is used. The seed attribute is not supported.
    """
    x = builder.get_value(node.input[0])

    low = get_attribute(node, "low", 0.0)
    high = get_attribute(node, "high", 1.0)
    onnx_dtype = get_attribute(node, "dtype", None)
    float_dtypes = {
        onnx.TensorProto.FLOAT16: torch.float16,
        onnx.TensorProto.FLOAT: torch.float32,
        onnx.TensorProto.DOUBLE: torch.float64,
        onnx.TensorProto.BFLOAT16: torch.bfloat16,
    }
    if onnx_dtype is not None and onnx_dtype not in float_dtypes:
        raise NotImplementedError(
            f"RandomUniformLike: unsupported dtype {onnx_dtype}"
        )
    dtype = None if onnx_dtype is None else float_dtypes[onnx_dtype]

    def _random_uniform_like(
        t: torch.Tensor, lo: float, hi: float, dt: torch.dtype | None
    ) -> torch.Tensor:
        # dtype=None makes rand_like keep t's dtype.
        return torch.rand_like(t, dtype=dt) * (hi - lo) + lo

    return builder.call_function(_random_uniform_like, args=(x, low, high, dtype))
83
+
84
+
85
@register("Multinomial")
def multinomial(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Sample class indices from a multinomial distribution.

    ONNX feeds unnormalized log-probabilities of shape [batch, classes].
    ``torch.multinomial`` needs non-negative weights, so they are passed
    through softmax first. ``sample_size`` samples are drawn with replacement.
    The index dtype comes from the ``dtype`` attribute (INT32 by default, or
    INT64). The seed attribute is not supported.
    """
    x = builder.get_value(node.input[0])

    sample_size = get_attribute(node, "sample_size", 1)
    onnx_dtype = get_attribute(node, "dtype", onnx.TensorProto.INT32)
    index_dtypes = {
        onnx.TensorProto.INT32: torch.int32,
        onnx.TensorProto.INT64: torch.int64,
    }
    if onnx_dtype not in index_dtypes:
        raise NotImplementedError(f"Multinomial: unsupported dtype {onnx_dtype}")

    def _multinomial(logits: torch.Tensor, n: int, dt: torch.dtype) -> torch.Tensor:
        # softmax == exp(logits) normalized, computed without overflow.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        return torch.multinomial(probs, n, replacement=True).to(dt)

    return builder.call_function(
        _multinomial, args=(x, sample_size, index_dtypes[onnx_dtype])
    )
95
+
96
+
97
@register("Bernoulli")
def bernoulli(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Sample from Bernoulli distribution.

    Each element of the input is the probability of drawing a 1. The output
    has the input's dtype unless the optional ``dtype`` attribute selects
    another one. The seed attribute is not supported.
    """
    x = builder.get_value(node.input[0])

    onnx_dtype = get_attribute(node, "dtype", None)
    if onnx_dtype is None:
        return builder.call_function(torch.bernoulli, args=(x,))

    out_dtypes = {
        onnx.TensorProto.FLOAT16: torch.float16,
        onnx.TensorProto.FLOAT: torch.float32,
        onnx.TensorProto.DOUBLE: torch.float64,
        onnx.TensorProto.BFLOAT16: torch.bfloat16,
        onnx.TensorProto.BOOL: torch.bool,
        onnx.TensorProto.UINT8: torch.uint8,
        onnx.TensorProto.INT8: torch.int8,
        onnx.TensorProto.INT16: torch.int16,
        onnx.TensorProto.INT32: torch.int32,
        onnx.TensorProto.INT64: torch.int64,
    }
    if onnx_dtype not in out_dtypes:
        raise NotImplementedError(f"Bernoulli: unsupported dtype {onnx_dtype}")

    def _bernoulli(p: torch.Tensor, dt: torch.dtype) -> torch.Tensor:
        return torch.bernoulli(p).to(dt)

    return builder.call_function(_bernoulli, args=(x, out_dtypes[onnx_dtype]))