returnn-1.20251027.232712-py3-none-any.whl → returnn-1.20260119.15400-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +130 -42
- returnn/datasets/meta.py +93 -43
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +48 -2
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +20 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +222 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +182 -172
- returnn/native_op.py +36 -31
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +8 -5
- returnn/tf/frontend_layers/_backend.py +7 -3
- returnn/tf/layers/basic.py +27 -40
- returnn/tf/native_op.py +27 -63
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +22 -197
- returnn/torch/engine.py +157 -6
- returnn/torch/frontend/_backend.py +280 -29
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/exception_helper.py +7 -1
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +12 -2
- returnn/util/file_cache.py +15 -1
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- returnn/util/task_system.py +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/torch/frontend/_backend.py
CHANGED

@@ -23,6 +23,8 @@ from returnn.frontend import _random_journal
 from returnn.frontend import _utils
 
 from . import raw_ops
+from ..util import native_op
+from ..util.assert_ import assert_
 
 _TT = Tensor[torch.Tensor]
 
@@ -44,6 +46,12 @@ class TorchBackend(Backend[torch.Tensor]):
         """
         return True
 
+    @staticmethod
+    def assert_(condition: Tensor, message: str):
+        """assert"""
+        assert condition.dims == (), "condition for assert must be a scalar"
+        assert_(condition.raw_tensor, message)
+
     @staticmethod
     def set_random_seed(seed: int):
         """
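(The `assert_` helper delegated to here lives in the new returnn/torch/util/assert_.py (+122 lines, not shown in this diff). A minimal stand-in illustrating the contract — a 0-dim bool tensor that raises with a message when false — might look like the following; the body is an assumption, not the actual implementation:

import torch

def assert_(condition: torch.Tensor, message: str):
    # hypothetical stand-in for returnn.torch.util.assert_.assert_:
    # condition is a 0-dim bool tensor; .item() forces a device sync in eager mode
    assert condition.ndim == 0
    if not bool(condition.item()):
        raise AssertionError(message)
)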
@@ -275,7 +283,7 @@
         :return: tensor
         """
         assert len(dims) >= 2
-        first_axis = min(source.dims.index(d) for d in dims)
+        first_axis = min([source.dims.index(d) for d in dims])
         pre_dims = source.dims[:first_axis]
         post_dims = [d for d in source.dims if d not in dims and d not in pre_dims]
         source = source.copy_transpose(tuple(pre_dims) + tuple(dims) + tuple(post_dims), allow_int=False)
@@ -666,10 +674,10 @@
         targets_spatial_dim: Dim,
         blank_index: int,
         max_approx: bool = False,
+        use_native_op: Optional[bool] = None,
+        label_loop: bool = True,
     ) -> Tensor:
         """CTC"""
-        if max_approx:
-            raise NotImplementedError("ctc_loss: max_approx not implemented for PyTorch")
         assert targets.sparse_dim and targets.sparse_dim.dimension <= logits.feature_dim.dimension
         # PyTorch expects the logits to be of shape (T, B, C) where T is the input spatial dim.
         batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
@@ -707,18 +715,42 @@
         if len(batch_dims) != 1:
             targets_raw = torch.reshape(targets_raw, (batch_n_elems, targets_raw.shape[-1]))  # [B', S]
             targets_lengths = torch.reshape(targets_lengths, (batch_n_elems,))  # [B']
-        if log_probs.dtype == torch.bfloat16:
-            # Currently (PyTorch 2.5), ctc_loss does not support bfloat16.
-            log_probs = log_probs.to(torch.float32)
-        loss_raw = torch.nn.functional.ctc_loss(
-            log_probs=log_probs,
-            targets=targets_raw,
-            input_lengths=input_lengths,
-            target_lengths=targets_lengths,
-            blank=blank_index,
-            zero_infinity=True,
-            reduction="none",
-        )
+        if use_native_op is None:
+            if max_approx or not label_loop:
+                use_native_op = True
+            else:
+                # This was the current default.
+                # We might change the default in the future, maybe via new behavior version.
+                use_native_op = False
+        if use_native_op:
+            loss_raw = native_op.ctc_loss(
+                logits=log_probs,
+                logits_normalize=True,
+                logits_seq_lens=input_lengths,
+                logits_time_major=True,
+                targets=targets_raw,
+                targets_seq_lens=targets_lengths,
+                blank_index=blank_index,
+                max_approx=max_approx,
+                label_loop=label_loop,
+            )
+        else:  # not native_op
+            if max_approx:
+                raise NotImplementedError("ctc_loss: max_approx not implemented for PyTorch")
+            if not label_loop:
+                raise NotImplementedError("ctc_loss: label_loop=False not implemented for PyTorch")
+            if log_probs.dtype == torch.bfloat16:
+                # Currently (PyTorch 2.5), ctc_loss does not support bfloat16.
+                log_probs = log_probs.to(torch.float32)
+            loss_raw = torch.nn.functional.ctc_loss(
+                log_probs=log_probs,
+                targets=targets_raw,
+                input_lengths=input_lengths,
+                target_lengths=targets_lengths,
+                blank=blank_index,
+                zero_infinity=True,
+                reduction="none",
+            )
         if len(batch_dims) != 1:
             loss_raw = torch.reshape(loss_raw, logits_raw_shape[1:-1])
         loss = Tensor(
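(For reference, a minimal sketch of driving the new native path directly, mirroring the keyword signature visible in this hunk; the integer dtypes of the length tensors are assumptions, and the native op is compiled on first use, so a working C++/CUDA toolchain is required:

import torch
from returnn.torch.util import native_op

T, B, C, S = 50, 3, 10, 20
logits = torch.randn(T, B, C)  # unnormalized; logits_normalize=True applies log_softmax
input_lengths = torch.full((B,), T, dtype=torch.int32)
targets = torch.randint(1, C, (B, S), dtype=torch.int32)  # blank_index=0 excluded
target_lengths = torch.full((B,), S, dtype=torch.int32)

loss = native_op.ctc_loss(
    logits=logits,
    logits_normalize=True,
    logits_seq_lens=input_lengths,
    logits_time_major=True,  # logits are [T, B, C]
    targets=targets,
    targets_seq_lens=target_lengths,
    blank_index=0,
    max_approx=False,  # True: Viterbi-style max instead of the full sum
    label_loop=True,  # standard CTC topology; False selects a topology without label loops
)  # [B]
)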
@@ -729,6 +761,103 @@
         )
         return loss
 
+    @staticmethod
+    def ctc_best_path(
+        *,
+        logits: Tensor,
+        logits_normalized: bool = False,
+        targets: Tensor,
+        input_spatial_dim: Dim,
+        targets_spatial_dim: Dim,
+        blank_index: int,
+        label_loop: bool = True,
+    ) -> Tensor:
+        """CTC best path"""
+        assert targets.sparse_dim and targets.sparse_dim.dimension <= logits.feature_dim.dimension
+        # PyTorch expects the logits to be of shape (T, B, C) where T is the input spatial dim.
+        batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
+        batch_dims_targets = targets.remaining_dims(targets_spatial_dim)
+        if set(batch_dims) != set(batch_dims_targets):
+            # Need to broadcast.
+            logits = rf.expand_dims(logits, [d for d in batch_dims_targets if d not in batch_dims])
+            targets = rf.expand_dims(targets, [d for d in batch_dims if d not in batch_dims_targets])
+            batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
+        batch_shape = [d.get_dim_value() for d in batch_dims]
+        batch_n_elems = prod(batch_shape)
+        logits = logits.copy_transpose([input_spatial_dim] + batch_dims + [logits.feature_dim])
+        logits_raw: torch.Tensor = logits.raw_tensor
+        input_lengths: torch.Tensor = input_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(batch_dims)
+        if input_lengths.numel() != batch_n_elems:
+            input_lengths = input_lengths.expand(batch_shape)
+        if len(batch_dims) != 1:
+            logits_raw = torch.reshape(
+                logits_raw, logits_raw.shape[:1] + (batch_n_elems,) + logits_raw.shape[-1:]
+            )  # [T, B', C]
+            input_lengths = torch.reshape(input_lengths, (batch_n_elems,))  # [B']
+        if logits_normalized:
+            log_probs = logits_raw
+        else:
+            log_probs = torch.nn.functional.log_softmax(logits_raw, dim=-1)
+        # PyTorch expects the targets to be of shape (B, S) where S is the targets spatial dim.
+        targets_raw = targets.copy_compatible_to_dims_raw(batch_dims + [targets_spatial_dim])  # [B..., S]
+        targets_raw_shape = batch_shape + [targets_spatial_dim.get_dim_value()]
+        if targets_raw.numel() != prod(targets_raw_shape):
+            targets_raw = targets_raw.expand(targets_raw_shape)
+        targets_lengths = targets_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(batch_dims)
+        if targets_lengths.numel() != batch_n_elems:
+            targets_lengths = targets_lengths.expand(batch_shape)
+        if len(batch_dims) != 1:
+            targets_raw = torch.reshape(targets_raw, (batch_n_elems, targets_raw.shape[-1]))  # [B', S]
+            targets_lengths = torch.reshape(targets_lengths, (batch_n_elems,))  # [B']
+        alignment_raw = native_op.ctc_best_path(
+            logits=log_probs,
+            logits_normalize=True,
+            logits_seq_lens=input_lengths,
+            logits_time_major=True,
+            targets=targets_raw,
+            targets_seq_lens=targets_lengths,
+            blank_index=blank_index,
+            label_loop=label_loop,
+        )  # (time,batch)
+        if len(batch_dims) != 1:
+            alignment_raw = torch.reshape(alignment_raw, log_probs.shape[:-1])
+        alignment = Tensor(
+            name="ctc_best_path",
+            dims=[input_spatial_dim] + batch_dims,
+            sparse_dim=logits.feature_dim,
+            raw_tensor=alignment_raw,
+            dtype=TorchBackend.get_dtype_name_raw(alignment_raw),
+        )
+        return alignment
+
+    @staticmethod
+    def have_edit_distance() -> bool:
+        """whether edit distance is available"""
+        return True
+
+    @staticmethod
+    def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim) -> Tensor:
+        """edit distance"""
+        a_batch_dims = a.remaining_dims(a_spatial_dim)
+        b_batch_dims = b.remaining_dims(b_spatial_dim)
+        assert set(a_batch_dims) == set(b_batch_dims), "edit_distance: batch dims must match"
+        a_raw = a.copy_compatible_to_dims_raw(a_batch_dims + [a_spatial_dim])
+        b_raw = b.copy_compatible_to_dims_raw(a_batch_dims + [b_spatial_dim])
+        a_seq_len = a_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(a_batch_dims)
+        b_seq_len = b_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(a_batch_dims)
+        batch_shape = None
+        if len(a_batch_dims) != 1:
+            batch_shape = [d.get_dim_value() for d in a_batch_dims]
+            batch_n_elems = prod(batch_shape)
+            a_raw = torch.reshape(a_raw.raw_tensor, (batch_n_elems, a_spatial_dim.get_dim_value()))
+            b_raw = torch.reshape(b_raw.raw_tensor, (batch_n_elems, b_spatial_dim.get_dim_value()))
+            a_seq_len = torch.reshape(a_seq_len.raw_tensor, (batch_n_elems,))
+            b_seq_len = torch.reshape(b_seq_len.raw_tensor, (batch_n_elems,))
+        dist_raw = native_op.edit_distance(a_raw, a_seq_len, b_raw, b_seq_len)
+        if len(a_batch_dims) != 1:
+            dist_raw = torch.reshape(dist_raw, batch_shape)
+        return rf.convert_to_tensor(dist_raw, name="edit_distance", dims=a_batch_dims)
+
     @staticmethod
     def create_parameter_raw(tensor: rf.Parameter, *, device: Optional[str] = None) -> torch.nn.Parameter:
         """
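(The batched native edit distance can be exercised directly with raw tensors, following the positional call above; the int32 dtypes are an assumption:

import torch
from returnn.torch.util import native_op

a = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
a_len = torch.tensor([4], dtype=torch.int32)
b = torch.tensor([[1, 3, 4, 0]], dtype=torch.int32)  # padded to length 4
b_len = torch.tensor([3], dtype=torch.int32)

dist = native_op.edit_distance(a, a_len, b, b_len)
print(dist)  # tensor([1]): delete the 2 to turn [1, 2, 3, 4] into [1, 3, 4]
)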
@@ -884,7 +1013,7 @@
         :param perm: e.g. [0, 2, 1]
         :return: permuted (transposed) raw tensor; wraps torch.permute
         """
-        if all(p == i for i, p in enumerate(perm)):
+        if all([p == i for i, p in enumerate(perm)]):
             return raw_tensor
         return torch.permute(raw_tensor, tuple(perm))
 
@@ -1166,20 +1295,29 @@
         if start is None:
             start = 0
         if isinstance(size, Dim):
+            assert end is None
             size = size.get_dim_value()
         elif isinstance(size, Tensor):
+            assert end is None
             assert size.dims == ()  # scalar
             size = size.raw_tensor
-
-
-
-        else:
+        elif isinstance(size, int):
+            pass
+        elif size is None:
             if isinstance(end, Tensor):
                 assert end.dims == ()
                 end = end.raw_tensor
-
+            elif isinstance(end, int):
+                if end < 0:
+                    end += axis.get_dim_value()
+            elif end is None:
                 end = axis.get_dim_value()
-
+            else:
+                raise TypeError(f"slice: unsupported type for end: {type(end)}")
+            size = end - start
+        else:
+            raise TypeError(f"slice: unsupported type for size: {type(size)}")
+        out.raw_tensor = torch.narrow(source.raw_tensor, dim=axis_int, start=start, length=size)
         return out
 
     @staticmethod
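(The net effect: `end` may now be a negative int, wrapped against the axis length, and every case funnels into a single torch.narrow call. In plain torch terms:

import torch

x = torch.arange(10)
# rf slice with start=2, end=-1 now resolves to: end = -1 + 10 = 9, size = 9 - 2 = 7
out = torch.narrow(x, dim=0, start=2, length=7)
print(out)  # tensor([2, 3, 4, 5, 6, 7, 8])
)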
@@ -1352,12 +1490,24 @@
         a_dims = a.dims
         b_dims = b.dims
 
-
-
-
-
-
-
+        if not all(dim in a_dims for dim in reduce) or not all(dim in b_dims for dim in reduce):
+            # revert to the generic einsum implementation
+            assert all(dim in a_dims + b_dims for dim in reduce), "Some reduce Dims not in a or b."
+            result_dims = [dim for dim in a_dims if dim not in reduce] + [
+                dim for dim in b_dims if dim not in reduce and dim not in a_dims
+            ]
+            map_to_letter = {}
+            for dim in a_dims + b_dims:
+                if dim not in map_to_letter:
+                    map_to_letter[dim] = chr(97 + len(map_to_letter))  # 'a', 'b', 'c', ...
+            a_subscript = "".join(map_to_letter[dim] for dim in a_dims)
+            b_subscript = "".join(map_to_letter[dim] for dim in b_dims)
+            out_subscript = "".join(map_to_letter[dim] for dim in result_dims)
+            raw_result = torch.einsum(f"{a_subscript},{b_subscript}->{out_subscript}", a.raw_tensor, b.raw_tensor)
+            result_tensor = Tensor(
+                "einsum", dims=result_dims, raw_tensor=raw_result, dtype=TorchBackend.get_dtype_name_raw(raw_result)
+            )
+            return result_tensor
 
         if len(reduce) > 1:
             reduce = list(reduce)
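(The fallback builds einsum subscripts by assigning one letter per distinct dim. A self-contained paraphrase, with string names standing in for Dim objects:

import torch

def einsum_over(a, a_dims, b, b_dims, reduce):
    # assign one letter per distinct dim, as the fallback above does
    result_dims = [d for d in a_dims if d not in reduce] + [
        d for d in b_dims if d not in reduce and d not in a_dims
    ]
    letters = {}
    for d in a_dims + b_dims:
        letters.setdefault(d, chr(97 + len(letters)))
    eq = "%s,%s->%s" % (
        "".join(letters[d] for d in a_dims),
        "".join(letters[d] for d in b_dims),
        "".join(letters[d] for d in result_dims),
    )
    return torch.einsum(eq, a, b)

out = einsum_over(torch.randn(2, 3, 4), ["b", "t", "f"], torch.randn(4, 5), ["f", "o"], reduce=["f"])
print(out.shape)  # torch.Size([2, 3, 5]), via "abc,cd->abd"
)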
@@ -1767,6 +1917,9 @@
         remaining_dims = [d for d in tensor.dims if d not in mask.dims]
         tensor_templ_dims = tuple(dims) + tuple(remaining_dims)
         in_raw = tensor.copy_compatible_to_dims_raw(tensor_templ_dims)
+        if any([in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)]):
+            # unbroadcast
+            in_raw = in_raw.expand([d.get_dim_value() for d in tensor_templ_dims])
         if mask.raw_tensor.device.type == "meta":
             # This is not supported, but also, we would anyway not know the out shape.
             # However, instead of erroring, just assume some dummy mask.
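(The guard materializes broadcast, i.e. size-1, axes before the boolean mask is applied, since the raw masking op needs real sizes. In isolation:

import torch

x = torch.randn(1, 5)  # size-1 batch axis, broadcast against a [3, ...] mask
x = x.expand([3, 5])  # "unbroadcast": materialize to the mask's sizes (no copy)
)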
@@ -1920,7 +2073,7 @@
         if not out_spatial_dims:
             out_spatial_dims = rf.make_conv_out_spatial_dims(
                 in_spatial_dims=in_spatial_dims,
-                filter_size=
+                filter_size=filter_size,
                 strides=strides or 1,
                 dilation_rate=dilation_rate or 1,
                 padding=padding,
@@ -2033,6 +2186,104 @@
         out.feature_dim = out_dim
         return out, out_spatial_dims
 
+    # noinspection PyShadowingBuiltins
+    @staticmethod
+    def transposed_conv(
+        source: Tensor,
+        *,
+        in_dim: Dim,
+        out_dim: Dim,
+        in_spatial_dims: Sequence[Dim],
+        out_spatial_dims: Optional[Sequence[Dim]] = None,
+        filter: Tensor,
+        filter_size: Sequence[Dim],
+        padding: str,
+        remove_padding: Union[Sequence[int], int] = 0,
+        output_padding: Optional[Union[Sequence[Optional[int]], int]] = None,
+        strides: Optional[Sequence[int]] = None,
+        bias: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Sequence[Dim]]:
+        """transposed convolution"""
+        if not out_spatial_dims:
+            out_spatial_dims = rf.make_transposed_conv_out_spatial_dims(
+                in_spatial_dims=in_spatial_dims,
+                filter_size=filter_size,
+                strides=strides,
+                padding=padding,
+                output_padding=output_padding,
+            )
+        assert remove_padding == 0  # not implemented yet otherwise...
+        if strides is None:
+            strides = [fs.dimension for fs in filter_size]
+        filter_dims = (in_dim, out_dim) + tuple(filter_size)
+        filter = filter.copy_transpose(filter_dims)
+        batch_dims = [d for d in source.dims if d not in (in_dim,) + tuple(in_spatial_dims)]
+        # Torch conv expects (N,C,<spatial dims>) as shape.
+        source = source.copy_transpose(batch_dims + [in_dim] + list(in_spatial_dims))
+        if len(batch_dims) == 1:
+            src_raw = source.raw_tensor
+        else:
+            src_raw = torch.reshape(
+                source.raw_tensor,
+                # potentially merge batch dims all together
+                [-1, in_dim.get_dim_value()] + [d.get_dim_value() for d in in_spatial_dims],
+            )
+        if padding == "same":
+            raise NotImplementedError("transposed_conv with padding='same' not implemented")
+        if padding == "valid":
+            padding_val = 0
+        else:
+            raise ValueError(f"invalid padding {padding!r}, expected 'same' or 'valid'")
+        if len(filter_size) == 1:
+            out_raw = torch.nn.functional.conv_transpose1d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 2:
+            out_raw = torch.nn.functional.conv_transpose2d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        elif len(filter_size) == 3:
+            out_raw = torch.nn.functional.conv_transpose3d(
+                src_raw,
+                weight=filter.raw_tensor,
+                bias=bias.raw_tensor if bias is not None else None,
+                stride=strides,
+                padding=padding_val,
+                output_padding=output_padding or 0,
+            )
+        else:
+            raise ValueError(f"invalid number of filter dims {filter_size}, expected 1, 2, or 3")
+        if remove_padding:
+            if isinstance(remove_padding, int):
+                remove_padding = [remove_padding] * len(out_spatial_dims)
+            assert len(remove_padding) == len(out_spatial_dims)
+            slices = [slice(None)] * out_raw.ndim
+            for i, pad in enumerate(remove_padding):
+                if pad > 0:
+                    slices[2 + i] = slice(0, -pad)
+            out_raw = out_raw[tuple(slices)]
+        out = Tensor(
+            "transposed_conv",
+            dims=batch_dims + [out_dim] + list(out_spatial_dims),
+            dtype=TorchBackend.get_dtype_name_raw(out_raw),
+        )
+        if len(batch_dims) == 1:
+            out.raw_tensor = out_raw
+        else:
+            out.raw_tensor = torch.reshape(out_raw, [d.get_dim_value() for d in out.dims])
+        out.feature_dim = out_dim
+        return out, out_spatial_dims
+
     @staticmethod
     def pool(
         source: Tensor,
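(For the only padding mode implemented here, "valid" (zero padding), the spatial output size follows the usual transposed-conv arithmetic, L_out = (L_in - 1) * stride + kernel + output_padding. A plain-torch check:

import torch

x = torch.randn(8, 4, 10)  # (N, C_in, L_in)
w = torch.randn(4, 6, 3)  # (C_in, C_out, kernel), matching filter_dims = (in, out, *size)
y = torch.nn.functional.conv_transpose1d(x, w, stride=2, padding=0)
print(y.shape)  # torch.Size([8, 6, 21]): (10 - 1) * 2 + 3
)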
returnn/torch/frontend/bridge.py
CHANGED

@@ -136,6 +136,15 @@ class RFModuleAsPTModule(torch.nn.Module):
     def _get_name(self):
         return self._rf_module.__class__.__name__ + "[RF→PT]"
 
+    def __repr__(self) -> str:
+        """
+        Return a custom repr for Sequential/ModuleList that compresses repeated module representations if possible,
+        otherwise fallback to default behavior.
+        """
+        if _can_use_compact_repr(self):
+            return _repr_compact(self)
+        return super().__repr__()
+
     @property
     def rf_module(self) -> rf.Module:
         """RF module"""
@@ -193,3 +202,55 @@
         # See similar logic in torch.nn.Module._apply.
         pt_param = torch.nn.Parameter(tensor, tensor.requires_grad)
         rf_param.raw_tensor = pt_param
+
+
+def _can_use_compact_repr(self: RFModuleAsPTModule) -> bool:
+    return list(self._modules.keys()) == [str(i) for i in range(len(self._modules))]
+
+
+def _repr_compact(self: RFModuleAsPTModule) -> str:
+    """
+    Return a custom repr for Sequential/ModuleList that compresses repeated module representations.
+    Code copied and adapted from torch.nn.ModuleList.__repr__.
+    """
+    list_of_reprs = [repr(item) for item in self._modules.values()]
+    if len(list_of_reprs) == 0:
+        return self._get_name() + "()"
+
+    start_end_indices = [[0, 0]]
+    repeated_blocks = [list_of_reprs[0]]
+    for i, r in enumerate(list_of_reprs[1:], 1):
+        if r == repeated_blocks[-1]:
+            start_end_indices[-1][1] += 1
+            continue
+
+        start_end_indices.append([i, i])
+        repeated_blocks.append(r)
+
+    lines = []
+    main_str = self._get_name() + "("
+    for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+        local_repr = f"({start_id}): {b}"  # default repr
+
+        if start_id != end_id:
+            n = end_id - start_id + 1
+            local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+        local_repr = _add_indent(local_repr, 2)
+        lines.append(local_repr)
+
+    main_str += "\n  " + "\n  ".join(lines) + "\n"
+    main_str += ")"
+    return main_str
+
+
+def _add_indent(s_: str, num_spaces: int) -> str:
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(num_spaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
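(The compact form matches what torch.nn.ModuleList itself prints in recent PyTorch, e.g.:

import torch

mods = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(12)])
print(mods)
# ModuleList(
#   (0-11): 12 x Linear(in_features=4, out_features=4, bias=True)
# )
)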
returnn/torch/frontend/compile_helper.py
ADDED

@@ -0,0 +1,106 @@
+"""
+Helpers to improve torch.compile on RF code.
+"""
+
+from __future__ import annotations
+from typing import Any, Iterable, List, Tuple
+
+import os
+from returnn.tensor import Tensor, Dim
+
+# noinspection PyProtectedMember
+from returnn.frontend import _native
+
+_is_set_up = False
+
+
+def setup():
+    """
+    Set up the torch.compile helpers for RF code, also including :class:`Tensor` and :class:`Dim`.
+    """
+
+    global _is_set_up
+    if _is_set_up:
+        return
+    _is_set_up = True  # only try once
+
+    assert not _native.is_set_up(), "Call this setup() as early as possible."
+    _native.set_enabled(False)
+
+    # We have lots of dynamic shapes.
+    os.environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
+
+    # noinspection PyProtectedMember
+    from torch.utils._pytree import register_pytree_node
+
+    register_pytree_node(Tensor, _tensor_flatten, _tensor_unflatten)
+    register_pytree_node(Dim, _dim_flatten, _dim_unflatten)
+
+    Dim.get_dim_value = _dim_get_dim_value
+
+
+def _tensor_flatten(t: Tensor) -> Tuple[List[Any], Any]:
+    """
+    Flatten the tensor for PyTree.
+    """
+    return [t.raw_tensor, t.dims, t.sparse_dim], [
+        t.name,
+        t.dtype,
+        t.version,
+        t.feature_dim_axis_or_unspecified,
+        t.time_dim_axis_or_unspecified,
+    ]
+
+
+def _tensor_unflatten(values: Iterable[Any], metadata: Any) -> Tensor:
+    """
+    Unflatten the tensor from PyTree.
+    """
+    raw_tensor, dims, sparse_dim = values
+    name, dtype, version, feature_dim_axis, time_dim_axis = metadata
+    return Tensor(
+        name=name,
+        dims=dims,
+        dtype=dtype,
+        sparse_dim=sparse_dim,
+        feature_dim_axis=feature_dim_axis,
+        time_dim_axis=time_dim_axis,
+        raw_tensor=raw_tensor,
+        version=version,
+    )
+
+
+def _dim_flatten(d: Dim) -> Tuple[List[Any], Any]:
+    """
+    Flatten the dim for PyTree.
+    """
+    return [d.dyn_size_ext], [d.name, d.dimension, d.size]
+
+
+def _dim_unflatten(values: Iterable[Any], metadata: Any) -> Dim:
+    """
+    Unflatten the dim from PyTree.
+    """
+    (dyn_size_ext,) = values
+    name, dimension, size = metadata
+    # TODO this creates a new instance... this is maybe wrong?
+    return Dim(name=name, dimension=dimension, size=size, dyn_size_ext=dyn_size_ext)
+
+
+def _dim_get_dim_value(self: Dim) -> int:
+    """
+    Infers the dim this axis should have if unbroadcasted.
+    If `self.src_data` has a placeholder, will use the shape from there.
+    Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
+    :return: max(size or dyn_size)
+    """
+    res = self.get_dim_value_tensor()
+    if isinstance(res, Tensor):
+        assert res.dims == ()
+        assert res.raw_tensor is not None
+        # Specifically PyTorch would then treat it as a SymInt in torch.compile,
+        # which is important to have for some torch functions (e.g. torch.tile and others).
+        return int(res.raw_tensor)
+    assert isinstance(res, int)
+    return res
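(A minimal usage sketch, assuming the module path above; setup() must run before any RF native code is initialized, after which Tensor/Dim arguments flatten into pytree leaves that torch.compile can trace through. Whether a given RF op compiles cleanly is not guaranteed by this diff:

import torch
import returnn.frontend as rf
from returnn.tensor import Tensor
from returnn.torch.frontend import compile_helper

compile_helper.setup()  # must run early: disables RF native code, registers pytree nodes

@torch.compile
def step(x: Tensor) -> Tensor:
    return rf.relu(x)
)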
returnn/torch/util/array_.py
CHANGED

@@ -60,3 +60,33 @@ def nonzero(mask: torch.Tensor, *, out_len: Union[int, torch.Tensor]) -> torch.Tensor:
     idx = torch.argsort(mask.to(torch.int8), stable=True, descending=True)  # [in_len]
     idx = idx[:out_len]  # [out_len]
     return idx
+
+
+def sequence_mask(lengths: torch.Tensor, *, maxlen: Optional[int] = None) -> torch.Tensor:
+    """
+    Creates a boolean mask from sequence lengths.
+
+    :param lengths: Tensor of shape [batch_size...] containing sequence lengths
+    :param maxlen: Maximum length of the sequences. If None, uses the maximum value in lengths.
+    :return: A boolean mask tensor of shape [batch_size..., maxlen]
+    """
+    if maxlen is None:
+        maxlen = lengths.max()
+    indices = torch.arange(0, maxlen, dtype=lengths.dtype, device=lengths.device)
+    mask = indices < lengths[..., None]
+    return mask
+
+
+def sequence_mask_time_major(lengths: torch.Tensor, *, maxlen: Optional[int] = None) -> torch.Tensor:
+    """
+    Creates a boolean mask from sequence lengths.
+
+    :param lengths: Tensor of shape [batch_size...] containing sequence lengths
+    :param maxlen: Maximum length of the sequences. If None, uses the maximum value in lengths.
+    :return: A boolean mask tensor of shape [maxlen, batch_size...]
+    """
+    if maxlen is None:
+        maxlen = lengths.max()
+    indices = torch.arange(0, maxlen, dtype=lengths.dtype, device=lengths.device)
+    mask = indices[(slice(None),) + (None,) * lengths.ndim] < lengths[None]
+    return mask
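(Both helpers are straightforward to check interactively:

import torch
from returnn.torch.util.array_ import sequence_mask, sequence_mask_time_major

lengths = torch.tensor([3, 1])
print(sequence_mask(lengths))
# tensor([[ True,  True,  True],
#         [ True, False, False]])
print(sequence_mask_time_major(lengths).shape)  # torch.Size([3, 2])
)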