onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,647 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Recurrent neural network operators."""
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ import onnx
7
+ import torch
8
+
9
+ from ..op_registry import register
10
+ from ..utils.attributes import get_attribute
11
+ from ..utils.op_helpers import get_optional_input
12
+
13
+ if TYPE_CHECKING:
14
+ from ..graph_builder import GraphBuilder
15
+
16
+
17
+ # =============================================================================
18
+ # Recurrent neural network operators
19
+ # =============================================================================
20
+
21
+
22
# -----------------------------------------------------------------------------
# Shared helpers for LSTM / GRU / RNN
# -----------------------------------------------------------------------------

# Activation functions accepted by the ``activations`` attribute of the
# recurrent operators, keyed by lower-cased ONNX name.  Each entry is
# ``(uses_alpha, uses_beta, default_alpha, default_beta)``.  Defaults follow the
# standalone ONNX operators of the same name; ``None`` means the value has to be
# supplied through ``activation_alpha`` / ``activation_beta``.
_RNN_ACTIVATION_PARAMS = {
    "sigmoid": (False, False, 0.0, 0.0),
    "tanh": (False, False, 0.0, 0.0),
    "relu": (False, False, 0.0, 0.0),
    "softsign": (False, False, 0.0, 0.0),
    "softplus": (False, False, 0.0, 0.0),
    "affine": (True, True, 1.0, 0.0),
    "leakyrelu": (True, False, 0.01, 0.0),
    "thresholdedrelu": (True, False, 1.0, 0.0),
    "scaledtanh": (True, True, None, None),
    "hardsigmoid": (True, True, 0.2, 0.5),
    "elu": (True, False, 1.0, 0.0),
}


def _attr_str(value) -> str:
    """Return a string attribute value as ``str``, decoding ``bytes`` payloads."""
    return value.decode("utf-8") if isinstance(value, bytes) else str(value)


def _rnn_direction(node: onnx.NodeProto) -> str:
    """Read the ``direction`` attribute and reject values ONNX does not define."""
    direction = _attr_str(get_attribute(node, "direction", "forward"))
    if direction not in ("forward", "reverse", "bidirectional"):
        raise ValueError(f"{node.op_type}: unsupported direction {direction!r}")
    return direction


def _rnn_wanted_outputs(node: onnx.NodeProto, count: int) -> tuple:
    """Return one flag per ONNX output slot: True when the output is requested."""
    return tuple(i < len(node.output) and node.output[i] != "" for i in range(count))


def _rnn_activations(
    node: onnx.NodeProto, defaults: tuple, num_directions: int
) -> tuple:
    """Resolve the activation attributes into per-direction activation specs.

    Args:
        node: The recurrent ONNX node.
        defaults: Default activation names for ONE direction, in the operator's
            order (e.g. ``("Sigmoid", "Tanh", "Tanh")`` for f, g, h of LSTM).
        num_directions: 1, or 2 for bidirectional.

    Returns:
        A tuple with one entry per direction. Each entry is a tuple of
        ``(lower_case_name, alpha, beta)`` specs in the same order as
        ``defaults``. Only strings and floats are used, so the result can be
        passed as an FX call argument.

    Raises:
        NotImplementedError: for an activation name not in
            ``_RNN_ACTIVATION_PARAMS``.
        ValueError: when a required alpha/beta is missing or the number of
            activations does not fit ``num_directions``.
    """
    names = get_attribute(node, "activations", None) or defaults
    # alpha/beta values are consumed in order, only by the activations that
    # take them (the convention used by onnxruntime).
    alpha_iter = iter(get_attribute(node, "activation_alpha", None) or ())
    beta_iter = iter(get_attribute(node, "activation_beta", None) or ())

    specs = []
    for raw in names:
        display = _attr_str(raw)
        name = display.lower()
        try:
            uses_alpha, uses_beta, alpha, beta = _RNN_ACTIVATION_PARAMS[name]
        except KeyError:
            raise NotImplementedError(
                f"{node.op_type}: unsupported activation {display!r}"
            ) from None
        if uses_alpha:
            alpha = next(alpha_iter, alpha)
        if uses_beta:
            beta = next(beta_iter, beta)
        if alpha is None or beta is None:
            raise ValueError(
                f"{node.op_type}: activation {display!r} requires explicit "
                "activation_alpha/activation_beta values"
            )
        specs.append((name, float(alpha), float(beta)))

    per_dir = len(defaults)
    if len(specs) == per_dir:
        # A single set was given: use it for every direction.
        specs = specs * num_directions
    elif len(specs) > per_dir and num_directions == 1:
        # Some exporters write the bidirectional-sized schema default
        # (e.g. ["Tanh", "Tanh"] for RNN) on unidirectional nodes.
        specs = specs[:per_dir]
    elif len(specs) != per_dir * num_directions:
        raise ValueError(
            f"{node.op_type}: expected {per_dir * num_directions} activations, "
            f"got {len(specs)}"
        )
    return tuple(
        tuple(specs[d * per_dir : (d + 1) * per_dir]) for d in range(num_directions)
    )


def _rnn_activate(spec: tuple, v: torch.Tensor) -> torch.Tensor:
    """Apply an activation spec ``(name, alpha, beta)`` to ``v``."""
    name, alpha, beta = spec
    if name == "sigmoid":
        return torch.sigmoid(v)
    if name == "tanh":
        return torch.tanh(v)
    if name == "relu":
        return torch.relu(v)
    if name == "softsign":
        return v / (1 + v.abs())
    if name == "softplus":
        return torch.nn.functional.softplus(v)
    if name == "affine":
        return alpha * v + beta
    if name == "leakyrelu":
        return torch.nn.functional.leaky_relu(v, alpha)
    if name == "thresholdedrelu":
        return torch.where(v > alpha, v, torch.zeros_like(v))
    if name == "scaledtanh":
        return alpha * torch.tanh(beta * v)
    if name == "hardsigmoid":
        return torch.clamp(alpha * v + beta, 0.0, 1.0)
    if name == "elu":
        return torch.nn.functional.elu(v, alpha)
    raise NotImplementedError(f"Unsupported recurrent activation: {name!r}")


def _rnn_clip(v: torch.Tensor, clip) -> torch.Tensor:
    """Clamp a gate pre-activation to ``[-clip, clip]`` when ``clip`` is set."""
    return v if clip is None else torch.clamp(v, -clip, clip)


def _rnn_initial_state(state, layout, num_directions, batch_size, hidden_size, like):
    """Return an initial state as ``[num_directions, batch_size, hidden_size]``.

    A missing state is zero-filled with ``like``'s dtype and device. With
    ``layout == 1``, ONNX passes initial states as
    ``[batch_size, num_directions, hidden_size]``, so they are transposed here.
    """
    if state is None:
        return like.new_zeros((num_directions, batch_size, hidden_size))
    return state.transpose(0, 1) if layout == 1 else state


def _rnn_lengths(sequence_lens, like):
    """Return ``sequence_lens`` as an int64 ``[batch_size, 1]`` tensor, or None."""
    if sequence_lens is None:
        return None
    return sequence_lens.to(device=like.device, dtype=torch.int64).reshape(-1, 1)


def _rnn_time_steps(seq_length: int, direction: str, dir_idx: int) -> range:
    """Return the time indices visited by direction ``dir_idx``, in processing order."""
    backwards = direction == "reverse" or (
        direction == "bidirectional" and dir_idx == 1
    )
    return range(seq_length - 1, -1, -1) if backwards else range(seq_length)


def _rnn_keep(valid, new, old):
    """Take ``new`` for batch entries still inside their sequence, else ``old``.

    ``valid`` is a ``[batch_size, 1]`` bool mask, or None when no
    ``sequence_lens`` was given (every entry is valid).
    """
    return new if valid is None else torch.where(valid, new, old)


def _rnn_step_output(valid, h_new):
    """Return the Y entry for one step: ``h_new``, zeroed past a sequence's end."""
    if valid is None:
        return h_new
    return torch.where(valid, h_new, torch.zeros_like(h_new))


def _rnn_stack_steps(steps, like, batch_size, hidden_size):
    """Stack per-step outputs into ``[seq_length, batch_size, hidden_size]``."""
    if not steps:
        # Zero-length sequence: torch.stack() rejects an empty list.
        return like.new_zeros((0, batch_size, hidden_size))
    return torch.stack(steps, dim=0)


def _rnn_outputs(dir_ys, dir_states, layout, wanted):
    """Assemble the requested outputs of a recurrent operator.

    Args:
        dir_ys: Per-direction ``[seq_length, batch_size, hidden_size]`` outputs.
        dir_states: One list per state output (Y_h, then Y_c for LSTM). Each
            list holds the per-direction final ``[batch_size, hidden_size]``
            state.
        layout: The ONNX ``layout`` attribute.
        wanted: Flags for ``(Y, *state outputs)`` in ONNX output order.

    Returns:
        The single requested output, a tuple of the requested outputs in ONNX
        order, or Y_h when nothing is requested.
    """
    # Y: [seq_length, num_directions, batch_size, hidden_size]
    y = torch.stack(dir_ys, dim=1)
    # Y_h / Y_c: [num_directions, batch_size, hidden_size]
    states = [torch.stack(per_dir, dim=0) for per_dir in dir_states]
    if layout == 1:
        # Y -> [batch_size, seq_length, num_directions, hidden_size]
        y = y.permute(2, 0, 1, 3)
        # states -> [batch_size, num_directions, hidden_size]
        states = [s.transpose(0, 1) for s in states]
    selected = [out for out, keep in zip([y, *states], wanted) if keep]
    if not selected:
        return states[0]  # Default to returning Y_h
    return selected[0] if len(selected) == 1 else tuple(selected)


# -----------------------------------------------------------------------------
# Operators
# -----------------------------------------------------------------------------


@register("LSTM")
def lstm(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """LSTM (Long Short-Term Memory) operator.

    Computes a one-layer LSTM.

    ONNX LSTM Inputs:
        - X: input tensor [seq_length, batch_size, input_size] (layout=0)
          or [batch_size, seq_length, input_size] (layout=1)
        - W: weight tensor [num_directions, 4*hidden_size, input_size]
        - R: recurrence weight [num_directions, 4*hidden_size, hidden_size]
        - B (optional): bias [num_directions, 8*hidden_size]
        - sequence_lens (optional): [batch_size]
        - initial_h (optional): [num_directions, batch_size, hidden_size]
          (layout=0) or [batch_size, num_directions, hidden_size] (layout=1)
        - initial_c (optional): same shape as initial_h
        - P (optional): peephole weights [num_directions, 3*hidden_size]

    ONNX LSTM Outputs:
        - Y (optional): [seq_length, num_directions, batch_size, hidden_size]
          (layout=0) or [batch_size, seq_length, num_directions, hidden_size]
          (layout=1)
        - Y_h (optional): same shape as initial_h
        - Y_c (optional): same shape as initial_h

    Equations (Default: f=Sigmoid, g=Tanh, h=Tanh):
        - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
        - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
        - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
        - Ct = ft (.) Ct-1 + it (.) ct
        - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
        - Ht = ot (.) h(Ct)

    Attribute handling:
        - ``activations`` / ``activation_alpha`` / ``activation_beta`` choose
          f, g and h per direction (names listed in ``_RNN_ACTIVATION_PARAMS``).
        - ``clip`` clamps the pre-activations of the i, f, c and o gates.
        - ``input_forget`` couples the gates as ft = 1 - it.
        - ``hidden_size`` is read from R when the attribute is absent.
        - With ``sequence_lens``, steps past a batch entry's length leave its
          state unchanged and write zeros to Y.
    """
    # Required inputs
    x = builder.get_value(node.input[0])
    w = builder.get_value(node.input[1])
    r = builder.get_value(node.input[2])

    # Optional inputs
    b = get_optional_input(builder, node, 3)
    sequence_lens = get_optional_input(builder, node, 4)
    initial_h = get_optional_input(builder, node, 5)
    initial_c = get_optional_input(builder, node, 6)
    peepholes = get_optional_input(builder, node, 7)

    # Attributes
    hidden_size = get_attribute(node, "hidden_size")
    direction = _rnn_direction(node)
    layout = get_attribute(node, "layout", 0)
    input_forget = get_attribute(node, "input_forget", 0)
    clip = get_attribute(node, "clip", None)
    num_directions = 2 if direction == "bidirectional" else 1
    activations = _rnn_activations(node, ("Sigmoid", "Tanh", "Tanh"), num_directions)

    output_y, output_y_h, output_y_c = _rnn_wanted_outputs(node, 3)

    def _lstm_impl(
        x,
        w,
        r,
        b,
        sequence_lens,
        initial_h,
        initial_c,
        peepholes,
        hidden_size,
        direction,
        layout,
        input_forget,
        output_y,
        output_y_h,
        output_y_c,
        activations,
        clip,
    ):
        # Work in seq-first form: [seq_length, batch_size, input_size]
        if layout == 1:
            x = x.transpose(0, 1)
        seq_length, batch_size = x.shape[0], x.shape[1]
        num_directions = 2 if direction == "bidirectional" else 1
        if hidden_size is None:
            hidden_size = r.shape[-1]  # R: [num_directions, 4*hidden, hidden]

        h0 = _rnn_initial_state(
            initial_h, layout, num_directions, batch_size, hidden_size, x
        )
        c0 = _rnn_initial_state(
            initial_c, layout, num_directions, batch_size, hidden_size, x
        )
        lens = _rnn_lengths(sequence_lens, x)

        all_y, all_h, all_c = [], [], []
        for d in range(num_directions):
            f_act, g_act, h_act = activations[d]

            # Wb and Rb are always summed, so fold them into one [4*hidden] bias.
            if b is None:
                bias = 0
            else:
                wb, rb = b[d].split(4 * hidden_size)
                bias = wb + rb

            # P = [Pi, Po, Pf]
            if peepholes is None:
                p_i = p_o = p_f = 0
            else:
                p_i, p_o, p_f = peepholes[d].split(hidden_size)

            # Input projection for every step in one matmul:
            # [seq_length, batch_size, 4*hidden], gate order [i, o, f, c].
            x_proj = x @ w[d].T + bias
            r_mat = r[d].T

            h_t, c_t = h0[d], c0[d]
            steps = [None] * seq_length
            for t in _rnn_time_steps(seq_length, direction, d):
                gates = x_proj[t] + h_t @ r_mat
                g_i, g_o, g_f, g_c = gates.split(hidden_size, dim=-1)

                i_t = _rnn_activate(f_act, _rnn_clip(g_i + p_i * c_t, clip))
                if input_forget:
                    f_t = 1 - i_t  # coupled input-forget gate
                else:
                    f_t = _rnn_activate(f_act, _rnn_clip(g_f + p_f * c_t, clip))
                c_tilde = _rnn_activate(g_act, _rnn_clip(g_c, clip))
                c_new = f_t * c_t + i_t * c_tilde
                # The output-gate peephole sees the *updated* cell state.
                o_t = _rnn_activate(f_act, _rnn_clip(g_o + p_o * c_new, clip))
                h_new = o_t * _rnn_activate(h_act, c_new)

                valid = None if lens is None else lens > t
                h_t = _rnn_keep(valid, h_new, h_t)
                c_t = _rnn_keep(valid, c_new, c_t)
                steps[t] = _rnn_step_output(valid, h_new)

            all_y.append(_rnn_stack_steps(steps, x, batch_size, hidden_size))
            all_h.append(h_t)
            all_c.append(c_t)

        return _rnn_outputs(
            all_y, (all_h, all_c), layout, (output_y, output_y_h, output_y_c)
        )

    return builder.call_function(
        _lstm_impl,
        args=(
            x,
            w,
            r,
            b,
            sequence_lens,
            initial_h,
            initial_c,
            peepholes,
            hidden_size,
            direction,
            layout,
            input_forget,
            output_y,
            output_y_h,
            output_y_c,
            activations,
            clip,
        ),
    )


@register("GRU")
def gru(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """GRU (Gated Recurrent Unit) operator.

    Computes a one-layer GRU.

    ONNX GRU Inputs:
        - X: input tensor [seq_length, batch_size, input_size] (layout=0)
          or [batch_size, seq_length, input_size] (layout=1)
        - W: weight tensor [num_directions, 3*hidden_size, input_size]
        - R: recurrence weight [num_directions, 3*hidden_size, hidden_size]
        - B (optional): bias [num_directions, 6*hidden_size]
        - sequence_lens (optional): [batch_size]
        - initial_h (optional): [num_directions, batch_size, hidden_size]
          (layout=0) or [batch_size, num_directions, hidden_size] (layout=1)

    ONNX GRU Outputs:
        - Y (optional): [seq_length, num_directions, batch_size, hidden_size]
          (layout=0) or [batch_size, seq_length, num_directions, hidden_size]
          (layout=1)
        - Y_h (optional): same shape as initial_h

    Equations (Default: f=Sigmoid, g=Tanh):
        - zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
        - rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
        - ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)  # linear_before_reset=0
        - ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)  # linear_before_reset!=0
        - Ht = (1 - zt) (.) ht + zt (.) Ht-1

    Attribute handling:
        - ``activations`` / ``activation_alpha`` / ``activation_beta`` choose
          f and g per direction (names listed in ``_RNN_ACTIVATION_PARAMS``).
        - ``clip`` clamps the pre-activations of the z, r and h gates.
        - ``hidden_size`` is read from R when the attribute is absent.
        - With ``sequence_lens``, steps past a batch entry's length leave its
          state unchanged and write zeros to Y.
    """
    # Required inputs
    x = builder.get_value(node.input[0])
    w = builder.get_value(node.input[1])
    r = builder.get_value(node.input[2])

    # Optional inputs
    b = get_optional_input(builder, node, 3)
    sequence_lens = get_optional_input(builder, node, 4)
    initial_h = get_optional_input(builder, node, 5)

    # Attributes
    hidden_size = get_attribute(node, "hidden_size")
    direction = _rnn_direction(node)
    layout = get_attribute(node, "layout", 0)
    linear_before_reset = get_attribute(node, "linear_before_reset", 0)
    clip = get_attribute(node, "clip", None)
    num_directions = 2 if direction == "bidirectional" else 1
    activations = _rnn_activations(node, ("Sigmoid", "Tanh"), num_directions)

    output_y, output_y_h = _rnn_wanted_outputs(node, 2)

    def _gru_impl(
        x,
        w,
        r,
        b,
        sequence_lens,
        initial_h,
        hidden_size,
        direction,
        layout,
        linear_before_reset,
        output_y,
        output_y_h,
        activations,
        clip,
    ):
        # Work in seq-first form: [seq_length, batch_size, input_size]
        if layout == 1:
            x = x.transpose(0, 1)
        seq_length, batch_size = x.shape[0], x.shape[1]
        num_directions = 2 if direction == "bidirectional" else 1
        if hidden_size is None:
            hidden_size = r.shape[-1]  # R: [num_directions, 3*hidden, hidden]
        hs = hidden_size

        h0 = _rnn_initial_state(initial_h, layout, num_directions, batch_size, hs, x)
        lens = _rnn_lengths(sequence_lens, x)

        all_y, all_h = [], []
        for d in range(num_directions):
            f_act, g_act = activations[d]

            # B = [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h]. Rb_h is kept apart
            # because linear_before_reset places it inside the reset product.
            if b is None:
                wb = rb_zr = rb_h = 0
            else:
                b_d = b[d]
                wb = b_d[: 3 * hs]
                rb_zr = b_d[3 * hs : 5 * hs]
                rb_h = b_d[5 * hs : 6 * hs]

            # Input projection for every step: [seq_length, batch_size, 3*hidden]
            x_proj = x @ w[d].T + wb
            r_zr = r[d][: 2 * hs].T  # [hidden, 2*hidden] for z and r
            r_h = r[d][2 * hs : 3 * hs].T  # [hidden, hidden]

            h_t = h0[d]
            steps = [None] * seq_length
            for t in _rnn_time_steps(seq_length, direction, d):
                x_z, x_r, x_h = x_proj[t].split(hs, dim=-1)
                h_z, h_r = (h_t @ r_zr + rb_zr).split(hs, dim=-1)

                z_t = _rnn_activate(f_act, _rnn_clip(x_z + h_z, clip))
                r_t = _rnn_activate(f_act, _rnn_clip(x_r + h_r, clip))
                if linear_before_reset:
                    h_pre = x_h + r_t * (h_t @ r_h + rb_h)
                else:
                    h_pre = x_h + (r_t * h_t) @ r_h + rb_h
                h_tilde = _rnn_activate(g_act, _rnn_clip(h_pre, clip))
                h_new = (1 - z_t) * h_tilde + z_t * h_t

                valid = None if lens is None else lens > t
                h_t = _rnn_keep(valid, h_new, h_t)
                steps[t] = _rnn_step_output(valid, h_new)

            all_y.append(_rnn_stack_steps(steps, x, batch_size, hs))
            all_h.append(h_t)

        return _rnn_outputs(all_y, (all_h,), layout, (output_y, output_y_h))

    return builder.call_function(
        _gru_impl,
        args=(
            x,
            w,
            r,
            b,
            sequence_lens,
            initial_h,
            hidden_size,
            direction,
            layout,
            linear_before_reset,
            output_y,
            output_y_h,
            activations,
            clip,
        ),
    )


@register("RNN")
def rnn(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """RNN (Simple Recurrent Neural Network) operator.

    Computes a one-layer simple RNN.

    ONNX RNN Inputs:
        - X: input tensor [seq_length, batch_size, input_size] (layout=0)
          or [batch_size, seq_length, input_size] (layout=1)
        - W: weight tensor [num_directions, hidden_size, input_size]
        - R: recurrence weight [num_directions, hidden_size, hidden_size]
        - B (optional): bias [num_directions, 2*hidden_size] = [Wbi, Rbi]
        - sequence_lens (optional): [batch_size]
        - initial_h (optional): [num_directions, batch_size, hidden_size]
          (layout=0) or [batch_size, num_directions, hidden_size] (layout=1)

    ONNX RNN Outputs:
        - Y (optional): [seq_length, num_directions, batch_size, hidden_size]
          (layout=0) or [batch_size, seq_length, num_directions, hidden_size]
          (layout=1)
        - Y_h (optional): same shape as initial_h

    Equations (Default: f=Tanh):
        - Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)

    Attribute handling:
        - ``activations`` / ``activation_alpha`` / ``activation_beta`` choose
          f per direction (names listed in ``_RNN_ACTIVATION_PARAMS``).
        - ``clip`` clamps the pre-activation of f.
        - ``hidden_size`` is read from R when the attribute is absent.
        - With ``sequence_lens``, steps past a batch entry's length leave its
          state unchanged and write zeros to Y.
    """
    # Required inputs
    x = builder.get_value(node.input[0])
    w = builder.get_value(node.input[1])
    r = builder.get_value(node.input[2])

    # Optional inputs
    b = get_optional_input(builder, node, 3)
    sequence_lens = get_optional_input(builder, node, 4)
    initial_h = get_optional_input(builder, node, 5)

    # Attributes
    hidden_size = get_attribute(node, "hidden_size")
    direction = _rnn_direction(node)
    layout = get_attribute(node, "layout", 0)
    clip = get_attribute(node, "clip", None)
    num_directions = 2 if direction == "bidirectional" else 1
    activations = _rnn_activations(node, ("Tanh",), num_directions)

    output_y, output_y_h = _rnn_wanted_outputs(node, 2)

    def _rnn_impl(
        x,
        w,
        r,
        b,
        sequence_lens,
        initial_h,
        hidden_size,
        direction,
        layout,
        output_y,
        output_y_h,
        activations,
        clip,
    ):
        # Work in seq-first form: [seq_length, batch_size, input_size]
        if layout == 1:
            x = x.transpose(0, 1)
        seq_length, batch_size = x.shape[0], x.shape[1]
        num_directions = 2 if direction == "bidirectional" else 1
        if hidden_size is None:
            hidden_size = r.shape[-1]  # R: [num_directions, hidden, hidden]

        h0 = _rnn_initial_state(
            initial_h, layout, num_directions, batch_size, hidden_size, x
        )
        lens = _rnn_lengths(sequence_lens, x)

        all_y, all_h = [], []
        for d in range(num_directions):
            (f_act,) = activations[d]

            # B = [Wbi, Rbi]; both are always summed.
            if b is None:
                bias = 0
            else:
                wb, rb = b[d].split(hidden_size)
                bias = wb + rb

            # Input projection for every step: [seq_length, batch_size, hidden]
            x_proj = x @ w[d].T + bias
            r_mat = r[d].T

            h_t = h0[d]
            steps = [None] * seq_length
            for t in _rnn_time_steps(seq_length, direction, d):
                h_new = _rnn_activate(f_act, _rnn_clip(x_proj[t] + h_t @ r_mat, clip))

                valid = None if lens is None else lens > t
                h_t = _rnn_keep(valid, h_new, h_t)
                steps[t] = _rnn_step_output(valid, h_new)

            all_y.append(_rnn_stack_steps(steps, x, batch_size, hidden_size))
            all_h.append(h_t)

        return _rnn_outputs(all_y, (all_h,), layout, (output_y, output_y_h))

    return builder.call_function(
        _rnn_impl,
        args=(
            x,
            w,
            r,
            b,
            sequence_lens,
            initial_h,
            hidden_size,
            direction,
            layout,
            output_y,
            output_y_h,
            activations,
            clip,
        ),
    )