onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/signal.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Signal processing operators.
|
|
3
|
+
|
|
4
|
+
This module implements specialized ONNX operators for signal processing,
|
|
5
|
+
including STFT, mel spectrograms, window functions, and non-maximum suppression.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
import onnx
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from ..op_registry import register
|
|
14
|
+
from ..utils.attributes import get_attribute
|
|
15
|
+
from ..utils.op_helpers import get_optional_input
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from ..graph_builder import GraphBuilder
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# =============================================================================
|
|
22
|
+
# STFT (Short-time Fourier Transform)
|
|
23
|
+
# =============================================================================
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@register("STFT", since_version=17)
def stft(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lower an ONNX ``STFT`` node onto ``torch.stft``.

    Inputs:
        signal: ``[batch_size, signal_length, 1]`` for real data or
            ``[batch_size, signal_length, 2]`` for (real, imaginary) pairs.
        frame_step: scalar hop length between successive frames.
        window (optional): 1D window tensor.
        frame_length (optional): scalar FFT size; when it is absent the
            window length is used instead.

    Attributes:
        onesided: int (default 1); non-zero requests a one-sided spectrum.

    Output:
        ``[batch_size, frames, dft_unique_bins, 2]`` with the real and
        imaginary components in the last dimension.

    Raises (at graph execution time):
        ValueError: if neither ``frame_length`` nor ``window`` is supplied.
    """
    signal = builder.get_value(node.input[0])
    frame_step = builder.get_value(node.input[1])
    window = get_optional_input(builder, node, 2)
    frame_length = get_optional_input(builder, node, 3)
    onesided = get_attribute(node, "onesided", 1)

    def _stft(signal, frame_step, window, frame_length, onesided):
        def _to_int(value):
            # Scalars may show up as single-element tensors or plain numbers.
            if isinstance(value, torch.Tensor):
                return int(value.item())
            return int(value)

        # A trailing dimension of 2 marks (real, imag) pairs; otherwise the
        # trailing dimension is squeezed away to get [batch, signal_length].
        complex_signal = signal.shape[-1] == 2
        if not complex_signal:
            waveform = signal.squeeze(-1)
        else:
            waveform = torch.complex(signal[..., 0], signal[..., 1])

        hop = _to_int(frame_step)

        # The FFT size comes from frame_length when given, else from the window.
        if frame_length is not None:
            fft_size = _to_int(frame_length)
        elif window is not None:
            fft_size = window.shape[0]
        else:
            raise ValueError("Either frame_length or window must be provided for STFT")

        spectrum = torch.stft(
            waveform,
            n_fft=fft_size,
            hop_length=hop,
            win_length=fft_size,
            window=window,
            center=False,  # ONNX does not pad the signal
            # A complex signal is never transformed one-sided.
            onesided=bool(onesided) and not complex_signal,
            return_complex=True,
        )

        # [batch, bins, frames] (complex) -> [batch, frames, bins, 2] (real).
        return torch.view_as_real(spectrum.permute(0, 2, 1))

    return builder.call_function(
        _stft,
        args=(signal, frame_step, window, frame_length, onesided),
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# =============================================================================
|
|
124
|
+
# MelWeightMatrix operator
|
|
125
|
+
# =============================================================================
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@register("MelWeightMatrix", since_version=17)
def mel_weight_matrix(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Build the triangular mel filterbank described by ONNX ``MelWeightMatrix``.

    The resulting matrix maps linearly spaced DFT/STFT bins onto mel-scaled
    bands, using mel(f) = 2595 * log10(1 + f / 700).

    Inputs:
        num_mel_bins: number of mel bands (scalar).
        dft_length: size of the original DFT (scalar).
        sample_rate: samples per second of the signal (scalar).
        lower_edge_hertz: lowest frequency covered by the mel bands (scalar).
        upper_edge_hertz: highest frequency covered by the mel bands (scalar).

    Attributes:
        output_datatype: ONNX data type of the result (default: 1 = FLOAT).

    Output:
        Tensor of shape ``[dft_length // 2 + 1, num_mel_bins]``.
    """
    from ..utils.dtype import onnx_dtype_to_torch

    num_mel_bins = builder.get_value(node.input[0])
    dft_length = builder.get_value(node.input[1])
    sample_rate = builder.get_value(node.input[2])
    lower_edge_hertz = builder.get_value(node.input[3])
    upper_edge_hertz = builder.get_value(node.input[4])

    # Fall back to float32 when the ONNX type has no torch counterpart.
    output_dtype = onnx_dtype_to_torch(get_attribute(node, "output_datatype", 1))
    if output_dtype is None:
        output_dtype = torch.float32

    def _mel_weight_matrix(
        num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz, dtype
    ):
        def _scalar(value):
            # Unwrap single-element tensors; pass plain numbers through.
            return value.item() if isinstance(value, torch.Tensor) else value

        band_count = int(_scalar(num_mel_bins))
        fft_size = int(_scalar(dft_length))
        rate = int(_scalar(sample_rate))
        hz_low = float(_scalar(lower_edge_hertz))
        hz_high = float(_scalar(upper_edge_hertz))

        # band_count triangles need band_count + 2 edge points.
        edges = torch.arange(0, band_count + 2, dtype=torch.float32)

        # Space the edge points evenly on the mel axis ...
        mel_low = 2595.0 * torch.log10(torch.tensor(1.0 + hz_low / 700.0))
        mel_high = 2595.0 * torch.log10(torch.tensor(1.0 + hz_high / 700.0))
        mel_step = (mel_high - mel_low) / (edges.shape[0])
        edges = edges * mel_step + mel_low

        # ... map them back to Hz ...
        edges = 700.0 * (torch.pow(torch.tensor(10.0), edges / 2595.0) - 1.0)

        # ... and finally to integer FFT bin indices.
        edges = (((fft_size + 1) * edges) // rate).to(torch.int64)
        bin_edges = edges.tolist()

        # One row per one-sided DFT bin, one column per mel band.
        weights = torch.zeros(fft_size // 2 + 1, band_count, dtype=torch.float32)

        triples = zip(bin_edges, bin_edges[1:], bin_edges[2:])
        for band, (left, center, right) in enumerate(triples):
            # Rising slope from the left edge up to the center bin.
            rise = center - left
            if rise == 0:
                weights[center, band] = 1.0
            else:
                for j in range(left, center + 1):
                    weights[j, band] = float(j - left) / float(rise)

            # Falling slope from the center bin towards the right edge.
            fall = right - center
            if fall > 0:
                for j in range(center, right):
                    weights[j, band] = float(right - j) / float(fall)

        return weights.to(dtype)

    return builder.call_function(
        _mel_weight_matrix,
        args=(
            num_mel_bins,
            dft_length,
            sample_rate,
            lower_edge_hertz,
            upper_edge_hertz,
            output_dtype,
        ),
    )
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# =============================================================================
|
|
255
|
+
# Window function operators
|
|
256
|
+
# =============================================================================
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@register("HannWindow", since_version=17)
def hann_window(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lower an ONNX ``HannWindow`` node onto ``torch.hann_window``.

    Inputs:
        size: scalar tensor holding the window length.

    Attributes:
        periodic: 1 (default) for a periodic window, 0 for a symmetric one.
        output_datatype: ONNX TensorProto data type of the result
            (default: 1 = FLOAT).
    """
    from ..utils.dtype import onnx_dtype_to_torch

    size = builder.get_value(node.input[0])
    is_periodic = bool(get_attribute(node, "periodic", 1))
    torch_dtype = onnx_dtype_to_torch(get_attribute(node, "output_datatype", 1))

    def _hann_window(
        window_length: torch.Tensor, periodic: bool, dtype: torch.dtype
    ) -> torch.Tensor:
        # The length arrives as a tensor at run time; torch wants a Python int.
        return torch.hann_window(
            int(window_length.item()), periodic=periodic, dtype=dtype
        )

    return builder.call_function(_hann_window, args=(size, is_periodic, torch_dtype))
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
@register("HammingWindow", since_version=17)
def hamming_window(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a Hamming window.

    Inputs:
        size: scalar holding the window length.

    Attributes:
        periodic: If 1, returns periodic window. If 0, returns symmetric window.
        output_datatype: ONNX TensorProto data type for output (default: FLOAT).

    Note:
        ONNX defines the Hamming coefficients as alpha = 25/46 and
        beta = 1 - alpha (= 21/46), which differ from PyTorch's defaults
        (0.54, 0.46), so they are handed to ``torch.hamming_window``
        explicitly. The exact fractions are used rather than truncated
        decimals so that no avoidable rounding error is introduced.
    """
    from ..utils.dtype import onnx_dtype_to_torch

    size = builder.get_value(node.input[0])
    periodic = get_attribute(node, "periodic", 1)
    output_datatype = get_attribute(node, "output_datatype", 1)  # Default: FLOAT

    dtype = onnx_dtype_to_torch(output_datatype)

    # ONNX HammingWindow coefficients, kept at full double precision.
    alpha = 25.0 / 46.0
    beta = 1.0 - alpha

    def _hamming_window(
        window_length: torch.Tensor, periodic: bool, dtype: torch.dtype
    ) -> torch.Tensor:
        # Accept a plain number as well as a single-element tensor, matching
        # how the other handlers in this module unwrap scalar inputs.
        length = int(
            window_length.item()
            if isinstance(window_length, torch.Tensor)
            else window_length
        )
        return torch.hamming_window(
            length, periodic=periodic, alpha=alpha, beta=beta, dtype=dtype
        )

    return builder.call_function(_hamming_window, args=(size, bool(periodic), dtype))
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@register("BlackmanWindow", since_version=17)
def blackman_window(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lower an ONNX ``BlackmanWindow`` node onto ``torch.blackman_window``.

    Inputs:
        size: scalar tensor holding the window length.

    Attributes:
        periodic: 1 (default) for a periodic window, 0 for a symmetric one.
        output_datatype: ONNX TensorProto data type of the result
            (default: 1 = FLOAT).
    """
    from ..utils.dtype import onnx_dtype_to_torch

    size = builder.get_value(node.input[0])
    is_periodic = bool(get_attribute(node, "periodic", 1))
    torch_dtype = onnx_dtype_to_torch(get_attribute(node, "output_datatype", 1))

    def _blackman_window(
        window_length: torch.Tensor, periodic: bool, dtype: torch.dtype
    ) -> torch.Tensor:
        # The length arrives as a tensor at run time; torch wants a Python int.
        return torch.blackman_window(
            int(window_length.item()), periodic=periodic, dtype=dtype
        )

    return builder.call_function(
        _blackman_window, args=(size, is_periodic, torch_dtype)
    )
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# =============================================================================
|
|
345
|
+
# Non-maximum suppression
|
|
346
|
+
# =============================================================================
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
@register("NonMaxSuppression")
def nms(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Non-maximum suppression for object detection.

    Inputs:
        boxes: indexed as ``boxes[batch]`` -> ``[num_boxes, 4]``.
        scores: indexed as ``scores[batch, class]`` -> ``[num_boxes]``.
        max_output_boxes_per_class (optional): scalar cap on the number of
            boxes kept per batch and class. ONNX defaults it to 0, which
            means "no output", so an omitted input yields an empty result.
        iou_threshold (optional): scalar, defaults to 0.0.
        score_threshold (optional): scalar; boxes whose score is not strictly
            greater are dropped. Defaults to ``-inf``.

    Attributes:
        center_point_box: 0 (default) when boxes are two diagonal corners,
            non-zero when they are ``[x_center, y_center, width, height]``.

    Output:
        int64 tensor of shape ``[num_selected, 3]`` whose rows are
        ``[batch_index, class_index, box_index]``.
    """
    boxes = builder.get_value(node.input[0])
    scores = builder.get_value(node.input[1])

    max_output = get_optional_input(builder, node, 2)
    iou_threshold = get_optional_input(builder, node, 3)
    score_threshold = get_optional_input(builder, node, 4)

    # Set defaults if not provided
    if iou_threshold is None:
        iou_threshold = 0.0
    if score_threshold is None:
        score_threshold = float("-inf")

    center_point_box = get_attribute(node, "center_point_box", 0)

    def _nms(
        boxes, scores, max_output, iou_threshold, score_threshold, center_point_box
    ):
        from torchvision.ops import nms as tv_nms

        def _scalar(value):
            # Unwrap single-element tensors; pass plain numbers / None through.
            return value.item() if isinstance(value, torch.Tensor) else value

        # ONNX: max_output_boxes_per_class defaults to 0, meaning no output.
        # Negative values are clamped so they cannot slice from the end.
        max_out = _scalar(max_output)
        max_out = 0 if max_out is None else max(int(max_out), 0)
        if max_out == 0:
            return torch.zeros((0, 3), dtype=torch.int64)

        iou_th = _scalar(iou_threshold)
        score_th = _scalar(score_threshold)

        results = []
        for batch_idx in range(boxes.shape[0]):
            batch_boxes = boxes[batch_idx]  # [num_boxes, 4]

            if center_point_box:
                # Convert center format to corner format.
                cx, cy, w, h = (
                    batch_boxes[:, 0],
                    batch_boxes[:, 1],
                    batch_boxes[:, 2],
                    batch_boxes[:, 3],
                )
                batch_boxes = torch.stack(
                    [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=1
                )
            else:
                # ONNX accepts either diagonal pair of corners, while
                # torchvision needs (min, min, max, max); order them here.
                # IoU does not depend on which axis is called x or y.
                first = batch_boxes[:, 0:2]
                second = batch_boxes[:, 2:4]
                batch_boxes = torch.cat(
                    [torch.minimum(first, second), torch.maximum(first, second)],
                    dim=1,
                )

            for class_idx in range(scores.shape[1]):
                class_scores = scores[batch_idx, class_idx]  # [num_boxes]

                # Filter by score threshold
                mask = class_scores > score_th
                if not mask.any():
                    continue

                # tv_nms returns positions within the filtered subset, sorted
                # by decreasing score; keep at most max_out of them.
                keep = tv_nms(batch_boxes[mask], class_scores[mask], iou_th)
                keep = keep[:max_out]

                # Map the subset positions back to indices into the full list.
                original_indices = torch.where(mask)[0][keep]
                results.extend(
                    [batch_idx, class_idx, box_idx]
                    for box_idx in original_indices.tolist()
                )

        if not results:
            return torch.zeros((0, 3), dtype=torch.int64)
        return torch.tensor(results, dtype=torch.int64)

    return builder.call_function(
        _nms,
        args=(
            boxes,
            scores,
            max_output,
            iou_threshold,
            score_threshold,
            center_point_box,
        ),
    )
|
onnx2fx/ops/string.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""String operators.
|
|
3
|
+
|
|
4
|
+
This module implements ONNX operators for string processing.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
import onnx
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from ..op_registry import register
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from ..graph_builder import GraphBuilder
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# =============================================================================
|
|
19
|
+
# String operators
|
|
20
|
+
# =============================================================================
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@register("StringNormalizer")
def string_normalizer(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """String normalization for numpy string arrays.

    Removes stopwords and then applies the requested case change.

    Attributes:
        case_change_action: "LOWER", "UPPER" or "NONE" (default); any other
            value leaves the case untouched.
        is_case_sensitive: non-zero to match stopwords exactly; 0 (default)
            to compare lower-cased strings.
        stopwords: list of strings to drop (default: empty).
        locale: read and forwarded, but not used by this implementation.

    Output:
        For 2D input, an object array built from the processed rows; a row
        whose elements were all dropped becomes ``[""]``, so a ``[1, C]``
        input yields ``[[""]]`` as ONNX requires. For any other rank the
        input is flattened and a 1D array is returned, ``[""]`` when empty.
    """
    from ..utils.attributes import get_attribute

    x = builder.get_value(node.input[0])
    case_change_action = get_attribute(node, "case_change_action", "NONE")
    is_case_sensitive = get_attribute(node, "is_case_sensitive", 0)
    stopwords = get_attribute(node, "stopwords", [])
    locale = get_attribute(node, "locale", "")

    def _string_normalizer(
        arr, case_action: str, case_sensitive: int, stops: list, loc: str
    ):
        import numpy as np

        # Build the stopword lookup once instead of once per row.
        if case_sensitive:
            stop_set = set(stops)

            def _is_stopword(s):
                return s in stop_set

        else:
            stop_set = {s.lower() for s in stops}

            def _is_stopword(s):
                # Non-string entries are never treated as stopwords.
                return isinstance(s, str) and s.lower() in stop_set

        def _normalize(tokens):
            """Filter stopwords, change case; return [""] if nothing is left."""
            # Stopwords are matched before the case change is applied.
            if stops:
                kept = [s for s in tokens if not _is_stopword(s)]
            else:
                kept = list(tokens)

            if case_action == "LOWER":
                kept = [s.lower() if isinstance(s, str) else s for s in kept]
            elif case_action == "UPPER":
                kept = [s.upper() if isinstance(s, str) else s for s in kept]

            # ONNX: when every element is dropped, a single empty string is
            # emitted instead of an empty tensor.
            return kept or [""]

        if len(arr.shape) == 2:
            # Process each row separately for 2D arrays.
            rows = [_normalize(arr[row_idx]) for row_idx in range(arr.shape[0])]
            if rows:
                return np.array(rows, dtype=object)
            return np.array([[]], dtype=object)

        # 1D case
        flat = arr.flatten()
        if not stops and case_action not in ("LOWER", "UPPER"):
            # Nothing to change: return the flattened input with its dtype.
            if len(flat) == 0:
                return np.array([""], dtype=object)
            return flat
        return np.array(_normalize(flat), dtype=object)

    return builder.call_function(
        _string_normalizer,
        args=(x, case_change_action, is_case_sensitive, stopwords, locale),
    )
|