onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2fx/ops/signal.py ADDED
@@ -0,0 +1,444 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Signal processing operators.
3
+
4
+ This module implements specialized ONNX operators for signal processing,
5
+ including STFT, mel spectrograms, window functions, and non-maximum suppression.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING
9
+
10
+ import onnx
11
+ import torch
12
+
13
+ from ..op_registry import register
14
+ from ..utils.attributes import get_attribute
15
+ from ..utils.op_helpers import get_optional_input
16
+
17
+ if TYPE_CHECKING:
18
+ from ..graph_builder import GraphBuilder
19
+
20
+
21
+ # =============================================================================
22
+ # STFT (Short-time Fourier Transform)
23
+ # =============================================================================
24
+
25
+
26
@register("STFT", since_version=17)
def stft(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Short-time Fourier Transform (ONNX STFT, opset 17).

    Inputs:
        signal: [batch_size, signal_length, 1] for real or
            [batch_size, signal_length, 2] for complex input
        frame_step: scalar hop length
        window (optional): 1D window tensor
        frame_length (optional): scalar FFT size

    Attributes:
        onesided: int (default 1), whether to return one-sided output

    Output:
        [batch_size, frames, dft_unique_bins, 2] with real and imaginary parts
    """
    signal = builder.get_value(node.input[0])
    frame_step = builder.get_value(node.input[1])

    # Optional inputs: window tensor and explicit FFT size.
    window = get_optional_input(builder, node, 2)
    frame_length = get_optional_input(builder, node, 3)

    # onesided attribute defaults to 1 in the ONNX spec.
    onesided = get_attribute(node, "onesided", 1)

    def _stft(signal, frame_step, window, frame_length, onesided):
        def _to_int(value):
            # Accept either a 0-d tensor or a plain Python number.
            return int(value.item() if isinstance(value, torch.Tensor) else value)

        # ONNX packs the signal as [batch, length, components]; a trailing
        # dimension of 2 marks a complex signal.
        complex_input = signal.shape[-1] == 2
        if complex_input:
            flattened = torch.complex(signal[..., 0], signal[..., 1])
        else:
            flattened = signal.squeeze(-1)

        hop_length = _to_int(frame_step)

        # FFT size: an explicit frame_length wins, otherwise the window length.
        if frame_length is not None:
            n_fft = _to_int(frame_length)
        elif window is not None:
            n_fft = window.shape[0]
        else:
            raise ValueError("Either frame_length or window must be provided for STFT")

        # A complex signal has no conjugate symmetry, so force two-sided output
        # regardless of the attribute.
        use_onesided = bool(onesided) and not complex_input

        # torch.stft yields complex [batch, bins, frames]; center=False matches
        # ONNX, which performs no implicit padding of the signal.
        spectrum = torch.stft(
            flattened,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            window=window,
            center=False,
            onesided=use_onesided,
            return_complex=True,
        )

        # [batch, bins, frames] -> [batch, frames, bins], then expand the
        # complex values into a trailing real/imaginary pair.
        spectrum = spectrum.permute(0, 2, 1)
        return torch.view_as_real(spectrum)

    return builder.call_function(
        _stft,
        args=(signal, frame_step, window, frame_length, onesided),
    )
121
+
122
+
123
+ # =============================================================================
124
+ # MelWeightMatrix operator
125
+ # =============================================================================
126
+
127
+
128
@register("MelWeightMatrix", since_version=17)
def mel_weight_matrix(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a MelWeightMatrix for mel spectrogram computation.

    This operator generates a weight matrix that can be used to convert a linearly
    sampled frequency spectra (from DFT or STFT) into mel-scaled frequency bins.

    The mel scale is defined as: mel(f) = 2595 * log10(1 + f/700)

    Inputs:
        num_mel_bins: The number of bands in the mel spectrum (scalar, int32/int64)
        dft_length: The size of the original DFT (scalar, int32/int64)
        sample_rate: Samples per second of the input signal (scalar, int32/int64)
        lower_edge_hertz: Lower bound frequency for mel spectrum (scalar, float)
        upper_edge_hertz: Upper bound frequency for mel spectrum (scalar, float)

    Attributes:
        output_datatype: The data type of the output tensor (default: 1 = FLOAT)

    Output:
        The Mel Weight Matrix with shape [floor(dft_length/2) + 1, num_mel_bins]
    """
    from ..utils.dtype import onnx_dtype_to_torch

    num_mel_bins = builder.get_value(node.input[0])
    dft_length = builder.get_value(node.input[1])
    sample_rate = builder.get_value(node.input[2])
    lower_edge_hertz = builder.get_value(node.input[3])
    upper_edge_hertz = builder.get_value(node.input[4])

    # Get output data type (default is 1 = FLOAT); fall back to float32 if
    # the ONNX dtype is not representable in torch.
    output_datatype = get_attribute(node, "output_datatype", 1)
    output_dtype = onnx_dtype_to_torch(output_datatype)
    if output_dtype is None:
        output_dtype = torch.float32

    def _mel_weight_matrix(
        num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz, dtype
    ):
        # Convert inputs to Python scalars (each may be a 0-d tensor or a number).
        n_mels = int(
            num_mel_bins.item()
            if isinstance(num_mel_bins, torch.Tensor)
            else num_mel_bins
        )
        n_fft = int(
            dft_length.item() if isinstance(dft_length, torch.Tensor) else dft_length
        )
        sr = int(
            sample_rate.item() if isinstance(sample_rate, torch.Tensor) else sample_rate
        )
        f_min = float(
            lower_edge_hertz.item()
            if isinstance(lower_edge_hertz, torch.Tensor)
            else lower_edge_hertz
        )
        f_max = float(
            upper_edge_hertz.item()
            if isinstance(upper_edge_hertz, torch.Tensor)
            else upper_edge_hertz
        )

        # Number of spectrogram bins (one-sided DFT)
        num_spectrogram_bins = n_fft // 2 + 1

        # Create frequency bin indices (n_mels + 2 points for n_mels triangular filters)
        frequency_bins = torch.arange(0, n_mels + 2, dtype=torch.float32)

        # Convert edge frequencies to mel scale: mel(f) = 2595 * log10(1 + f/700)
        low_frequency_mel = 2595.0 * torch.log10(torch.tensor(1.0 + f_min / 700.0))
        high_frequency_mel = 2595.0 * torch.log10(torch.tensor(1.0 + f_max / 700.0))

        # Calculate mel step.  NOTE(review): the divisor is the full point count
        # (n_mels + 2), matching the ONNX reference implementation rather than
        # the more common (n_mels + 1) — confirm against the ONNX backend tests.
        mel_step = (high_frequency_mel - low_frequency_mel) / (frequency_bins.shape[0])

        # Convert evenly spaced mel points back to Hz
        frequency_bins = frequency_bins * mel_step + low_frequency_mel

        # Convert mel frequencies back to Hz (inverse of the mel formula)
        frequency_bins = 700.0 * (
            torch.pow(torch.tensor(10.0), frequency_bins / 2595.0) - 1.0
        )

        # Convert Hz frequencies to integer FFT bin indices (floor division).
        # NOTE(review): indices are not clamped to num_spectrogram_bins; this
        # assumes upper_edge_hertz <= sample_rate / 2 — confirm with callers.
        frequency_bins = ((n_fft + 1) * frequency_bins) // sr
        frequency_bins = frequency_bins.to(torch.int64)

        # Create the filterbank matrix: one triangular filter per mel band.
        output = torch.zeros(num_spectrogram_bins, n_mels, dtype=torch.float32)

        for i in range(n_mels):
            lower_frequency_value = frequency_bins[i].item()  # left
            center_frequency_point = frequency_bins[i + 1].item()  # center
            higher_frequency_point = frequency_bins[i + 2].item()  # right

            # Rising edge: ramp 0 -> 1 from the left bin to the center bin.
            low_to_center = center_frequency_point - lower_frequency_value
            if low_to_center == 0:
                # Degenerate filter: left and center coincide; put full weight there.
                output[center_frequency_point, i] = 1.0
            else:
                for j in range(lower_frequency_value, center_frequency_point + 1):
                    output[j, i] = float(j - lower_frequency_value) / float(
                        low_to_center
                    )

            # Falling edge: ramp 1 -> 0 from the center bin to the right bin.
            center_to_high = higher_frequency_point - center_frequency_point
            if center_to_high > 0:
                for j in range(center_frequency_point, higher_frequency_point):
                    output[j, i] = float(higher_frequency_point - j) / float(
                        center_to_high
                    )

        # Cast to the requested ONNX output dtype only at the end.
        return output.to(dtype)

    return builder.call_function(
        _mel_weight_matrix,
        args=(
            num_mel_bins,
            dft_length,
            sample_rate,
            lower_edge_hertz,
            upper_edge_hertz,
            output_dtype,
        ),
    )
252
+
253
+
254
+ # =============================================================================
255
+ # Window function operators
256
+ # =============================================================================
257
+
258
+
259
@register("HannWindow", since_version=17)
def hann_window(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a Hann window.

    Inputs:
        size: scalar window length (0-d tensor or constant-folded int).

    Attributes:
        periodic: If 1, returns periodic window. If 0, returns symmetric window.
        output_datatype: ONNX TensorProto data type for output (default: FLOAT).
    """
    from ..utils.dtype import onnx_dtype_to_torch

    size = builder.get_value(node.input[0])
    periodic = get_attribute(node, "periodic", 1)
    output_datatype = get_attribute(node, "output_datatype", 1)  # Default: FLOAT

    dtype = onnx_dtype_to_torch(output_datatype)
    if dtype is None:
        # Unmapped ONNX dtype: fall back to float32, consistent with
        # MelWeightMatrix handling.
        dtype = torch.float32

    def _hann_window(window_length, periodic: bool, dtype: torch.dtype) -> torch.Tensor:
        # The length may arrive as a 0-d tensor or as a plain Python int
        # (constant-folded graphs); accept both.
        length = int(
            window_length.item()
            if isinstance(window_length, torch.Tensor)
            else window_length
        )
        return torch.hann_window(length, periodic=periodic, dtype=dtype)

    return builder.call_function(_hann_window, args=(size, bool(periodic), dtype))
282
+
283
+
284
@register("HammingWindow", since_version=17)
def hamming_window(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a Hamming window.

    Inputs:
        size: scalar window length (0-d tensor or constant-folded int).

    Attributes:
        periodic: If 1, returns periodic window. If 0, returns symmetric window.
        output_datatype: ONNX TensorProto data type for output (default: FLOAT).

    Note:
        ONNX specifies Hamming coefficients alpha = 25/46 and beta = 1 - alpha
        (approximately 0.543478 / 0.456522), which differ from PyTorch's
        defaults (0.54, 0.46), so they are passed explicitly.
    """
    from ..utils.dtype import onnx_dtype_to_torch

    size = builder.get_value(node.input[0])
    periodic = get_attribute(node, "periodic", 1)
    output_datatype = get_attribute(node, "output_datatype", 1)  # Default: FLOAT

    dtype = onnx_dtype_to_torch(output_datatype)
    if dtype is None:
        # Unmapped ONNX dtype: fall back to float32, consistent with
        # MelWeightMatrix handling.
        dtype = torch.float32

    # Exact ONNX coefficients (alpha = 25/46); avoids the rounding error of
    # the 6-digit decimal approximations.
    alpha = 25.0 / 46.0
    beta = 1.0 - alpha

    def _hamming_window(
        window_length, periodic: bool, dtype: torch.dtype
    ) -> torch.Tensor:
        # The length may arrive as a 0-d tensor or as a plain Python int
        # (constant-folded graphs); accept both.
        length = int(
            window_length.item()
            if isinstance(window_length, torch.Tensor)
            else window_length
        )
        return torch.hamming_window(
            length, periodic=periodic, alpha=alpha, beta=beta, dtype=dtype
        )

    return builder.call_function(_hamming_window, args=(size, bool(periodic), dtype))
317
+
318
+
319
@register("BlackmanWindow", since_version=17)
def blackman_window(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a Blackman window.

    Inputs:
        size: scalar window length (0-d tensor or constant-folded int).

    Attributes:
        periodic: If 1, returns periodic window. If 0, returns symmetric window.
        output_datatype: ONNX TensorProto data type for output (default: FLOAT).
    """
    from ..utils.dtype import onnx_dtype_to_torch

    size = builder.get_value(node.input[0])
    periodic = get_attribute(node, "periodic", 1)
    output_datatype = get_attribute(node, "output_datatype", 1)  # Default: FLOAT

    dtype = onnx_dtype_to_torch(output_datatype)
    if dtype is None:
        # Unmapped ONNX dtype: fall back to float32, consistent with
        # MelWeightMatrix handling.
        dtype = torch.float32

    def _blackman_window(
        window_length, periodic: bool, dtype: torch.dtype
    ) -> torch.Tensor:
        # The length may arrive as a 0-d tensor or as a plain Python int
        # (constant-folded graphs); accept both.
        length = int(
            window_length.item()
            if isinstance(window_length, torch.Tensor)
            else window_length
        )
        return torch.blackman_window(length, periodic=periodic, dtype=dtype)

    return builder.call_function(_blackman_window, args=(size, bool(periodic), dtype))
342
+
343
+
344
+ # =============================================================================
345
+ # Non-maximum suppression
346
+ # =============================================================================
347
+
348
+
349
@register("NonMaxSuppression")
def nms(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Non-maximum suppression for object detection.

    Inputs:
        boxes: [num_batches, num_boxes, 4]
        scores: [num_batches, num_classes, num_boxes]
        max_output_boxes_per_class (optional): scalar. Per the ONNX spec this
            defaults to 0, which means NO boxes are output.
        iou_threshold (optional): scalar, defaults to 0.
        score_threshold (optional): scalar; when absent, no score filtering.

    Attributes:
        center_point_box: 0 -> boxes are [y1, x1, y2, x2] given as ANY diagonal
            pair of corners; 1 -> boxes are [x_center, y_center, width, height].

    Output:
        [num_selected, 3] int64 tensor of (batch_index, class_index, box_index).
    """
    boxes = builder.get_value(node.input[0])
    scores = builder.get_value(node.input[1])

    max_output = get_optional_input(builder, node, 2)
    iou_threshold = get_optional_input(builder, node, 3)
    score_threshold = get_optional_input(builder, node, 4)

    # Set defaults if not provided
    if iou_threshold is None:
        iou_threshold = 0.0
    if score_threshold is None:
        score_threshold = float("-inf")

    center_point_box = get_attribute(node, "center_point_box", 0)

    def _nms(
        boxes, scores, max_output, iou_threshold, score_threshold, center_point_box
    ):
        from torchvision.ops import nms as tv_nms

        batch_size = boxes.shape[0]
        num_classes = scores.shape[1]

        iou_th = (
            iou_threshold.item()
            if isinstance(iou_threshold, torch.Tensor)
            else iou_threshold
        )
        score_th = (
            score_threshold.item()
            if isinstance(score_threshold, torch.Tensor)
            else score_threshold
        )
        max_out = (
            max_output.item() if isinstance(max_output, torch.Tensor) else max_output
        )
        # ONNX: max_output_boxes_per_class defaults to 0, meaning no output.
        max_out = 0 if max_out is None else int(max_out)
        if max_out <= 0:
            return torch.zeros((0, 3), dtype=torch.int64)

        results = []
        for batch_idx in range(batch_size):
            batch_boxes = boxes[batch_idx]  # [num_boxes, 4]

            if center_point_box:
                # Convert [cx, cy, w, h] to corner format [x1, y1, x2, y2].
                cx, cy, w, h = (
                    batch_boxes[:, 0],
                    batch_boxes[:, 1],
                    batch_boxes[:, 2],
                    batch_boxes[:, 3],
                )
                batch_boxes = torch.stack(
                    [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=1
                )
            else:
                # ONNX corner format is [y1, x1, y2, x2] supplied as ANY
                # diagonal pair; normalize ordering before calling torchvision.
                # (IoU is invariant under a consistent x/y axis swap, so
                # passing swapped axes to tv_nms yields identical results.)
                y_min = torch.minimum(batch_boxes[:, 0], batch_boxes[:, 2])
                y_max = torch.maximum(batch_boxes[:, 0], batch_boxes[:, 2])
                x_min = torch.minimum(batch_boxes[:, 1], batch_boxes[:, 3])
                x_max = torch.maximum(batch_boxes[:, 1], batch_boxes[:, 3])
                batch_boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1)

            for class_idx in range(num_classes):
                class_scores = scores[batch_idx, class_idx]  # [num_boxes]

                # Filter by score threshold before running NMS.
                mask = class_scores > score_th
                if not mask.any():
                    continue

                filtered_boxes = batch_boxes[mask]
                filtered_scores = class_scores[mask]

                # Apply NMS, then enforce the per-class output cap.
                keep = tv_nms(filtered_boxes, filtered_scores, iou_th)
                keep = keep[:max_out]

                # Map surviving positions back to original box indices.
                original_indices = torch.where(mask)[0][keep]

                for idx in original_indices:
                    results.append([batch_idx, class_idx, idx.item()])

        if len(results) == 0:
            return torch.zeros((0, 3), dtype=torch.int64)
        return torch.tensor(results, dtype=torch.int64)

    return builder.call_function(
        _nms,
        args=(
            boxes,
            scores,
            max_output,
            iou_threshold,
            score_threshold,
            center_point_box,
        ),
    )
onnx2fx/ops/string.py ADDED
@@ -0,0 +1,126 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """String operators.
3
+
4
+ This module implements ONNX operators for string processing.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING
8
+
9
+ import onnx
10
+ import torch
11
+
12
+ from ..op_registry import register
13
+
14
+ if TYPE_CHECKING:
15
+ from ..graph_builder import GraphBuilder
16
+
17
+
18
+ # =============================================================================
19
+ # String operators
20
+ # =============================================================================
21
+
22
+
23
@register("StringNormalizer")
def string_normalizer(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """String normalization for numpy string arrays.

    Per the ONNX StringNormalizer spec:

    * Stopwords are removed first; matching honours ``is_case_sensitive``.
    * ``case_change_action`` ("LOWER" / "UPPER" / "NONE") is applied after
      filtering.
    * If every element is removed, the output contains a single empty string
      (this also applies per row of a 2D input).

    NOTE(review): the ``locale`` attribute is accepted but unused; Python's
    default str.lower()/str.upper() is applied regardless of locale.
    """
    from ..utils.attributes import get_attribute

    x = builder.get_value(node.input[0])
    case_change_action = get_attribute(node, "case_change_action", "NONE")
    is_case_sensitive = get_attribute(node, "is_case_sensitive", 0)
    stopwords = get_attribute(node, "stopwords", [])
    locale = get_attribute(node, "locale", "")

    def _string_normalizer(
        arr, case_action: str, case_sensitive: int, stops: list, loc: str
    ):
        import numpy as np

        def _filter(tokens):
            # Drop stopwords; matching is case-insensitive unless requested.
            if not stops:
                return list(tokens)
            if case_sensitive:
                stop_set = set(stops)
                return [s for s in tokens if s not in stop_set]
            stop_set_lower = {s.lower() for s in stops}
            return [
                s
                for s in tokens
                if not isinstance(s, str) or s.lower() not in stop_set_lower
            ]

        def _recase(tokens):
            # Apply the requested case change, leaving non-str entries alone.
            if case_action == "LOWER":
                return [s.lower() if isinstance(s, str) else s for s in tokens]
            if case_action == "UPPER":
                return [s.upper() if isinstance(s, str) else s for s in tokens]
            return list(tokens)

        if arr.ndim == 2:
            # Process each row independently (ONNX allows [1, C] inputs).
            rows = []
            for row in arr:
                processed = _recase(_filter(row))
                if not processed:
                    # ONNX: a fully-filtered row still yields one empty string.
                    processed = [""]
                rows.append(processed)
            if rows:
                return np.array(rows, dtype=object)
            return np.array([[]], dtype=object)

        # 1D case: filter first (so case-insensitive stopword matching sees
        # the original strings), then re-case.
        processed = _recase(_filter(arr.flatten()))

        # Handle empty result - ONNX spec says return [''] for empty.
        if not processed:
            return np.array([""], dtype=object)
        return np.array(processed, dtype=object)

    return builder.call_function(
        _string_normalizer,
        args=(x, case_change_action, is_case_sensitive, stopwords, locale),
    )