onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/recurrent.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Recurrent neural network operators."""
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from ..op_registry import register
|
|
10
|
+
from ..utils.attributes import get_attribute
|
|
11
|
+
from ..utils.op_helpers import get_optional_input
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from ..graph_builder import GraphBuilder
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# =============================================================================
|
|
18
|
+
# Recurrent neural network operators
|
|
19
|
+
# =============================================================================
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@register("LSTM")
def lstm(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """LSTM (Long Short-Term Memory) operator.

    Computes a one-layer LSTM by unrolling the recurrence over time inside a
    single ``call_function`` node.

    ONNX LSTM Inputs:
    - X: input tensor [seq_length, batch_size, input_size] (layout=0)
         or [batch_size, seq_length, input_size] (layout=1)
    - W: weight tensor [num_directions, 4*hidden_size, input_size]
    - R: recurrence weight [num_directions, 4*hidden_size, hidden_size]
    - B (optional): bias [num_directions, 8*hidden_size]
    - sequence_lens (optional): [batch_size]
    - initial_h (optional): [num_directions, batch_size, hidden_size] (layout=0)
         or [batch_size, num_directions, hidden_size] (layout=1)
    - initial_c (optional): same shape convention as initial_h
    - P (optional): peephole weights [num_directions, 3*hidden_size]

    ONNX LSTM Outputs:
    - Y (optional): [seq_length, num_directions, batch_size, hidden_size]
         or [batch_size, seq_length, num_directions, hidden_size] (layout=1)
    - Y_h (optional): [num_directions, batch_size, hidden_size]
         or [batch_size, num_directions, hidden_size] (layout=1)
    - Y_c (optional): same shape convention as Y_h

    Equations (f=Sigmoid, g=Tanh, h=Tanh):
    - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
    - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
    - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
    - Ct = ft (.) Ct-1 + it (.) ct
    - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
    - Ht = ot (.) h(Ct)

    Supported attributes: ``hidden_size``, ``direction``, ``layout``,
    ``input_forget`` and ``clip`` (applied to the input of every gate
    activation). The ``activations`` attribute is not read; the default
    Sigmoid/Tanh/Tanh functions are always used.

    When ``sequence_lens`` is given, steps at or beyond a batch entry's
    length leave its state untouched and produce zeros in Y.

    Args:
        builder: Graph builder used to resolve inputs and emit the node.
        node: The ONNX LSTM node being converted.

    Returns:
        The FX node produced by ``builder.call_function``. At runtime it
        evaluates to the requested outputs: a single tensor when exactly one
        (or none) of Y/Y_h/Y_c is requested, otherwise a tuple of the
        requested ones in Y, Y_h, Y_c order.
    """
    # Required inputs
    x = builder.get_value(node.input[0])
    w = builder.get_value(node.input[1])
    r = builder.get_value(node.input[2])

    # Optional inputs
    b = get_optional_input(builder, node, 3)
    sequence_lens = get_optional_input(builder, node, 4)
    initial_h = get_optional_input(builder, node, 5)
    initial_c = get_optional_input(builder, node, 6)
    peepholes = get_optional_input(builder, node, 7)

    # Attributes
    hidden_size = get_attribute(node, "hidden_size")
    direction = get_attribute(node, "direction", "forward")
    layout = get_attribute(node, "layout", 0)
    input_forget = get_attribute(node, "input_forget", 0)
    # Captured by closure so the call_function argument list stays unchanged.
    clip = get_attribute(node, "clip", None)
    # NOTE(review): "activations" (and activation_alpha/beta) are still not
    # honoured; non-default activation functions silently fall back to
    # Sigmoid/Tanh/Tanh.

    # An output is requested when its slot exists and is not the empty name.
    output_y = len(node.output) > 0 and node.output[0] != ""
    output_y_h = len(node.output) > 1 and node.output[1] != ""
    output_y_c = len(node.output) > 2 and node.output[2] != ""

    def _lstm_impl(
        x,
        w,
        r,
        b,
        sequence_lens,
        initial_h,
        initial_c,
        peepholes,
        hidden_size,
        direction,
        layout,
        input_forget,
        output_y,
        output_y_h,
        output_y_c,
    ):
        def _clip(pre_activation):
            # ONNX applies the clip threshold to the input of activations.
            if clip is None:
                return pre_activation
            return torch.clamp(pre_activation, -clip, clip)

        # Work in sequence-first form. With layout=1 the initial states are
        # batch-first as well, so they must be transposed together with X.
        if layout == 1:
            x = x.transpose(0, 1)
            if initial_h is not None:
                initial_h = initial_h.transpose(0, 1)
            if initial_c is not None:
                initial_c = initial_c.transpose(0, 1)

        seq_length, batch_size, _ = x.shape
        if hidden_size is None:
            # Fall back to the recurrence weight's last dimension.
            hidden_size = r.shape[-1]
        num_directions = 2 if direction == "bidirectional" else 1

        def _rows(tensor, k):
            # k-th block of hidden_size entries along the first dimension.
            return tensor[k * hidden_size : (k + 1) * hidden_size]

        # Missing initial states default to zeros.
        if initial_h is None:
            initial_h = torch.zeros(
                num_directions, batch_size, hidden_size, dtype=x.dtype, device=x.device
            )
        if initial_c is None:
            initial_c = torch.zeros(
                num_directions, batch_size, hidden_size, dtype=x.dtype, device=x.device
            )

        all_y = []
        all_y_h = []
        all_y_c = []

        for dir_idx in range(num_directions):
            # ONNX gate order in W, R and each half of B: input, output,
            # forget, cell.
            w_dir = w[dir_idx]
            w_i, w_o, w_f, w_c = (_rows(w_dir, k) for k in range(4))

            r_dir = r[dir_idx]
            r_i, r_o, r_f, r_c = (_rows(r_dir, k) for k in range(4))

            # B = [Wb_i, Wb_o, Wb_f, Wb_c, Rb_i, Rb_o, Rb_f, Rb_c]
            if b is not None:
                b_dir = b[dir_idx]
                wb_i, wb_o, wb_f, wb_c = (_rows(b_dir, k) for k in range(4))
                rb_i, rb_o, rb_f, rb_c = (_rows(b_dir, k) for k in range(4, 8))
            else:
                wb_i = wb_o = wb_f = wb_c = rb_i = rb_o = rb_f = rb_c = 0

            # P = [Pi, Po, Pf]
            if peepholes is not None:
                p_dir = peepholes[dir_idx]
                p_i, p_o, p_f = (_rows(p_dir, k) for k in range(3))
            else:
                p_i = p_o = p_f = 0

            h_t = initial_h[dir_idx]  # [batch_size, hidden_size]
            c_t = initial_c[dir_idx]  # [batch_size, hidden_size]

            is_reverse = direction == "reverse" or (
                direction == "bidirectional" and dir_idx == 1
            )
            if is_reverse:
                time_steps = range(seq_length - 1, -1, -1)
            else:
                time_steps = range(seq_length)

            outputs = []
            for t in time_steps:
                x_t = x[t]  # [batch_size, input_size]

                i_t = torch.sigmoid(
                    _clip(x_t @ w_i.T + h_t @ r_i.T + p_i * c_t + wb_i + rb_i)
                )

                if input_forget:
                    # Coupled input/forget gate.
                    f_t = 1 - i_t
                else:
                    f_t = torch.sigmoid(
                        _clip(x_t @ w_f.T + h_t @ r_f.T + p_f * c_t + wb_f + rb_f)
                    )

                c_tilde = torch.tanh(
                    _clip(x_t @ w_c.T + h_t @ r_c.T + wb_c + rb_c)
                )
                c_next = f_t * c_t + i_t * c_tilde

                # The output gate peeks at the *new* cell state.
                o_t = torch.sigmoid(
                    _clip(x_t @ w_o.T + h_t @ r_o.T + p_o * c_next + wb_o + rb_o)
                )
                h_next = o_t * torch.tanh(c_next)

                if sequence_lens is None:
                    h_t, c_t = h_next, c_next
                    outputs.append(h_next)
                else:
                    # Entries whose sequence already ended keep their state
                    # and emit zeros for this step.
                    valid = (sequence_lens > t).reshape(-1, 1)
                    h_t = torch.where(valid, h_next, h_t)
                    c_t = torch.where(valid, c_next, c_t)
                    outputs.append(
                        torch.where(valid, h_next, torch.zeros_like(h_next))
                    )

            # Reverse passes were collected last-step-first; restore time order.
            if is_reverse:
                outputs = outputs[::-1]

            all_y.append(torch.stack(outputs, dim=0))  # [seq, batch, hidden]
            all_y_h.append(h_t)
            all_y_c.append(c_t)

        y = torch.stack(all_y, dim=1)  # [seq, num_directions, batch, hidden]
        y_h = torch.stack(all_y_h, dim=0)  # [num_directions, batch, hidden]
        y_c = torch.stack(all_y_c, dim=0)  # [num_directions, batch, hidden]

        if layout == 1:
            y = y.permute(2, 0, 1, 3)  # [batch, seq, num_directions, hidden]
            y_h = y_h.transpose(0, 1)  # [batch, num_directions, hidden]
            y_c = y_c.transpose(0, 1)  # [batch, num_directions, hidden]

        # NOTE(review): unrequested outputs are dropped rather than padded, so
        # the result is positionally packed; confirm the builder binds it to
        # the non-empty output names.
        requested = [
            value
            for wanted, value in ((output_y, y), (output_y_h, y_h), (output_y_c, y_c))
            if wanted
        ]
        if not requested:
            return y_h  # Default to returning Y_h
        if len(requested) == 1:
            return requested[0]
        return tuple(requested)

    return builder.call_function(
        _lstm_impl,
        args=(
            x,
            w,
            r,
            b,
            sequence_lens,
            initial_h,
            initial_c,
            peepholes,
            hidden_size,
            direction,
            layout,
            input_forget,
            output_y,
            output_y_h,
            output_y_c,
        ),
    )
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@register("GRU")
def gru(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """GRU (Gated Recurrent Unit) operator.

    Computes a one-layer GRU by unrolling the recurrence over time inside a
    single ``call_function`` node.

    ONNX GRU Inputs:
    - X: input tensor [seq_length, batch_size, input_size] (layout=0)
         or [batch_size, seq_length, input_size] (layout=1)
    - W: weight tensor [num_directions, 3*hidden_size, input_size]
    - R: recurrence weight [num_directions, 3*hidden_size, hidden_size]
    - B (optional): bias [num_directions, 6*hidden_size]
    - sequence_lens (optional): [batch_size]
    - initial_h (optional): [num_directions, batch_size, hidden_size] (layout=0)
         or [batch_size, num_directions, hidden_size] (layout=1)

    ONNX GRU Outputs:
    - Y (optional): [seq_length, num_directions, batch_size, hidden_size]
         or [batch_size, seq_length, num_directions, hidden_size] (layout=1)
    - Y_h (optional): [num_directions, batch_size, hidden_size]
         or [batch_size, num_directions, hidden_size] (layout=1)

    Equations (f=Sigmoid, g=Tanh):
    - zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
    - rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
    - ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)  # linear_before_reset=0
    - ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)  # linear_before_reset!=0
    - Ht = (1 - zt) (.) ht + zt (.) Ht-1

    Supported attributes: ``hidden_size``, ``direction``, ``layout``,
    ``linear_before_reset`` and ``clip`` (applied to the input of every
    activation). The ``activations`` attribute is not read; the default
    Sigmoid/Tanh functions are always used.

    When ``sequence_lens`` is given, steps at or beyond a batch entry's
    length leave its state untouched and produce zeros in Y.

    Args:
        builder: Graph builder used to resolve inputs and emit the node.
        node: The ONNX GRU node being converted.

    Returns:
        The FX node produced by ``builder.call_function``. At runtime it
        evaluates to ``(Y, Y_h)`` when both are requested, ``Y`` when only Y
        is requested, and ``Y_h`` otherwise.
    """
    # Required inputs
    x = builder.get_value(node.input[0])
    w = builder.get_value(node.input[1])
    r = builder.get_value(node.input[2])

    # Optional inputs
    b = get_optional_input(builder, node, 3)
    sequence_lens = get_optional_input(builder, node, 4)
    initial_h = get_optional_input(builder, node, 5)

    # Attributes
    hidden_size = get_attribute(node, "hidden_size")
    direction = get_attribute(node, "direction", "forward")
    layout = get_attribute(node, "layout", 0)
    linear_before_reset = get_attribute(node, "linear_before_reset", 0)
    # Captured by closure so the call_function argument list stays unchanged.
    clip = get_attribute(node, "clip", None)
    # NOTE(review): "activations" (and activation_alpha/beta) are still not
    # honoured; non-default activation functions silently fall back to
    # Sigmoid/Tanh.

    # An output is requested when its slot exists and is not the empty name.
    output_y = len(node.output) > 0 and node.output[0] != ""
    output_y_h = len(node.output) > 1 and node.output[1] != ""

    def _gru_impl(
        x,
        w,
        r,
        b,
        sequence_lens,
        initial_h,
        hidden_size,
        direction,
        layout,
        linear_before_reset,
        output_y,
        output_y_h,
    ):
        def _clip(pre_activation):
            # ONNX applies the clip threshold to the input of activations.
            if clip is None:
                return pre_activation
            return torch.clamp(pre_activation, -clip, clip)

        # Work in sequence-first form. With layout=1 the initial state is
        # batch-first as well, so it must be transposed together with X.
        if layout == 1:
            x = x.transpose(0, 1)
            if initial_h is not None:
                initial_h = initial_h.transpose(0, 1)

        seq_length, batch_size, _ = x.shape
        if hidden_size is None:
            # Fall back to the recurrence weight's last dimension.
            hidden_size = r.shape[-1]
        num_directions = 2 if direction == "bidirectional" else 1

        def _rows(tensor, k):
            # k-th block of hidden_size entries along the first dimension.
            return tensor[k * hidden_size : (k + 1) * hidden_size]

        # A missing initial state defaults to zeros.
        if initial_h is None:
            initial_h = torch.zeros(
                num_directions, batch_size, hidden_size, dtype=x.dtype, device=x.device
            )

        all_y = []
        all_y_h = []

        for dir_idx in range(num_directions):
            # ONNX gate order in W, R and each half of B: update, reset, hidden.
            w_dir = w[dir_idx]
            w_z, w_r, w_h = (_rows(w_dir, k) for k in range(3))

            r_dir = r[dir_idx]
            r_z, r_r, r_h = (_rows(r_dir, k) for k in range(3))

            # B = [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h]
            if b is not None:
                b_dir = b[dir_idx]
                wb_z, wb_r, wb_h = (_rows(b_dir, k) for k in range(3))
                rb_z, rb_r, rb_h = (_rows(b_dir, k) for k in range(3, 6))
            else:
                wb_z = wb_r = wb_h = rb_z = rb_r = rb_h = 0

            h_t = initial_h[dir_idx]  # [batch_size, hidden_size]

            is_reverse = direction == "reverse" or (
                direction == "bidirectional" and dir_idx == 1
            )
            if is_reverse:
                time_steps = range(seq_length - 1, -1, -1)
            else:
                time_steps = range(seq_length)

            outputs = []
            for t in time_steps:
                x_t = x[t]  # [batch_size, input_size]

                z_t = torch.sigmoid(_clip(x_t @ w_z.T + h_t @ r_z.T + wb_z + rb_z))
                r_t = torch.sigmoid(_clip(x_t @ w_r.T + h_t @ r_r.T + wb_r + rb_r))

                if linear_before_reset:
                    # Reset gate scales the already-projected hidden state.
                    h_tilde = torch.tanh(
                        _clip(x_t @ w_h.T + r_t * (h_t @ r_h.T + rb_h) + wb_h)
                    )
                else:
                    # Reset gate scales the hidden state before projection.
                    h_tilde = torch.tanh(
                        _clip(x_t @ w_h.T + (r_t * h_t) @ r_h.T + rb_h + wb_h)
                    )

                h_next = (1 - z_t) * h_tilde + z_t * h_t

                if sequence_lens is None:
                    h_t = h_next
                    outputs.append(h_next)
                else:
                    # Entries whose sequence already ended keep their state
                    # and emit zeros for this step.
                    valid = (sequence_lens > t).reshape(-1, 1)
                    h_t = torch.where(valid, h_next, h_t)
                    outputs.append(
                        torch.where(valid, h_next, torch.zeros_like(h_next))
                    )

            # Reverse passes were collected last-step-first; restore time order.
            if is_reverse:
                outputs = outputs[::-1]

            all_y.append(torch.stack(outputs, dim=0))  # [seq, batch, hidden]
            all_y_h.append(h_t)

        y = torch.stack(all_y, dim=1)  # [seq, num_directions, batch, hidden]
        y_h = torch.stack(all_y_h, dim=0)  # [num_directions, batch, hidden]

        if layout == 1:
            y = y.permute(2, 0, 1, 3)  # [batch, seq, num_directions, hidden]
            y_h = y_h.transpose(0, 1)  # [batch, num_directions, hidden]

        if output_y and output_y_h:
            return (y, y_h)
        if output_y:
            return y
        # Y_h is returned both when it is the only requested output and as
        # the default when nothing is requested.
        return y_h

    return builder.call_function(
        _gru_impl,
        args=(
            x,
            w,
            r,
            b,
            sequence_lens,
            initial_h,
            hidden_size,
            direction,
            layout,
            linear_before_reset,
            output_y,
            output_y_h,
        ),
    )
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@register("RNN")
def rnn(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """RNN (Simple Recurrent Neural Network) operator.

    Computes a one-layer simple RNN by unrolling the recurrence over time
    inside a single ``call_function`` node.

    ONNX RNN Inputs:
    - X: input tensor [seq_length, batch_size, input_size] (layout=0)
         or [batch_size, seq_length, input_size] (layout=1)
    - W: weight tensor [num_directions, hidden_size, input_size]
    - R: recurrence weight [num_directions, hidden_size, hidden_size]
    - B (optional): bias [num_directions, 2*hidden_size] = [Wbi, Rbi]
    - sequence_lens (optional): [batch_size]
    - initial_h (optional): [num_directions, batch_size, hidden_size] (layout=0)
         or [batch_size, num_directions, hidden_size] (layout=1)

    ONNX RNN Outputs:
    - Y (optional): [seq_length, num_directions, batch_size, hidden_size]
         or [batch_size, seq_length, num_directions, hidden_size] (layout=1)
    - Y_h (optional): [num_directions, batch_size, hidden_size]
         or [batch_size, num_directions, hidden_size] (layout=1)

    Equations (f=Tanh):
    - Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)

    Supported attributes: ``hidden_size``, ``direction``, ``layout`` and
    ``clip`` (applied to the input of the activation). The ``activations``
    attribute is not read; Tanh is always used.

    When ``sequence_lens`` is given, steps at or beyond a batch entry's
    length leave its state untouched and produce zeros in Y.

    Args:
        builder: Graph builder used to resolve inputs and emit the node.
        node: The ONNX RNN node being converted.

    Returns:
        The FX node produced by ``builder.call_function``. At runtime it
        evaluates to ``(Y, Y_h)`` when both are requested, ``Y`` when only Y
        is requested, and ``Y_h`` otherwise.
    """
    # Required inputs
    x = builder.get_value(node.input[0])
    w = builder.get_value(node.input[1])
    r = builder.get_value(node.input[2])

    # Optional inputs
    b = get_optional_input(builder, node, 3)
    sequence_lens = get_optional_input(builder, node, 4)
    initial_h = get_optional_input(builder, node, 5)

    # Attributes
    hidden_size = get_attribute(node, "hidden_size")
    direction = get_attribute(node, "direction", "forward")
    layout = get_attribute(node, "layout", 0)
    # Captured by closure so the call_function argument list stays unchanged.
    clip = get_attribute(node, "clip", None)
    # NOTE(review): "activations" (and activation_alpha/beta) are still not
    # honoured; non-default activation functions silently fall back to Tanh.

    # An output is requested when its slot exists and is not the empty name.
    output_y = len(node.output) > 0 and node.output[0] != ""
    output_y_h = len(node.output) > 1 and node.output[1] != ""

    def _rnn_impl(
        x,
        w,
        r,
        b,
        sequence_lens,
        initial_h,
        hidden_size,
        direction,
        layout,
        output_y,
        output_y_h,
    ):
        def _clip(pre_activation):
            # ONNX applies the clip threshold to the input of activations.
            if clip is None:
                return pre_activation
            return torch.clamp(pre_activation, -clip, clip)

        # Work in sequence-first form. With layout=1 the initial state is
        # batch-first as well, so it must be transposed together with X.
        if layout == 1:
            x = x.transpose(0, 1)
            if initial_h is not None:
                initial_h = initial_h.transpose(0, 1)

        seq_length, batch_size, _ = x.shape
        if hidden_size is None:
            # Fall back to the recurrence weight's last dimension.
            hidden_size = r.shape[-1]
        num_directions = 2 if direction == "bidirectional" else 1

        # A missing initial state defaults to zeros.
        if initial_h is None:
            initial_h = torch.zeros(
                num_directions, batch_size, hidden_size, dtype=x.dtype, device=x.device
            )

        all_y = []
        all_y_h = []

        for dir_idx in range(num_directions):
            w_dir = w[dir_idx]  # [hidden_size, input_size]
            r_dir = r[dir_idx]  # [hidden_size, hidden_size]

            # B = [Wbi, Rbi]
            if b is not None:
                b_dir = b[dir_idx]
                wb_i = b_dir[0:hidden_size]
                rb_i = b_dir[hidden_size : 2 * hidden_size]
            else:
                wb_i = rb_i = 0

            h_t = initial_h[dir_idx]  # [batch_size, hidden_size]

            is_reverse = direction == "reverse" or (
                direction == "bidirectional" and dir_idx == 1
            )
            if is_reverse:
                time_steps = range(seq_length - 1, -1, -1)
            else:
                time_steps = range(seq_length)

            outputs = []
            for t in time_steps:
                x_t = x[t]  # [batch_size, input_size]

                h_next = torch.tanh(
                    _clip(x_t @ w_dir.T + h_t @ r_dir.T + wb_i + rb_i)
                )

                if sequence_lens is None:
                    h_t = h_next
                    outputs.append(h_next)
                else:
                    # Entries whose sequence already ended keep their state
                    # and emit zeros for this step.
                    valid = (sequence_lens > t).reshape(-1, 1)
                    h_t = torch.where(valid, h_next, h_t)
                    outputs.append(
                        torch.where(valid, h_next, torch.zeros_like(h_next))
                    )

            # Reverse passes were collected last-step-first; restore time order.
            if is_reverse:
                outputs = outputs[::-1]

            all_y.append(torch.stack(outputs, dim=0))  # [seq, batch, hidden]
            all_y_h.append(h_t)

        y = torch.stack(all_y, dim=1)  # [seq, num_directions, batch, hidden]
        y_h = torch.stack(all_y_h, dim=0)  # [num_directions, batch, hidden]

        if layout == 1:
            y = y.permute(2, 0, 1, 3)  # [batch, seq, num_directions, hidden]
            y_h = y_h.transpose(0, 1)  # [batch, num_directions, hidden]

        if output_y and output_y_h:
            return (y, y_h)
        if output_y:
            return y
        # Y_h is returned both when it is the only requested output and as
        # the default when nothing is requested.
        return y_h

    return builder.call_function(
        _rnn_impl,
        args=(
            x,
            w,
            r,
            b,
            sequence_lens,
            initial_h,
            hidden_size,
            direction,
            layout,
            output_y,
            output_y_h,
        ),
    )
|