returnn-1.20260105.192646-py3-none-any.whl → returnn-1.20260119.15400-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +1 -1
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +110 -42
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +6 -5
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +19 -0
- returnn/frontend/loss.py +183 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +104 -174
- returnn/native_op.py +36 -31
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +1 -1
- returnn/tf/frontend_layers/_backend.py +3 -1
- returnn/tf/layers/basic.py +13 -2
- returnn/tf/native_op.py +16 -5
- returnn/tf/util/basic.py +7 -201
- returnn/torch/engine.py +120 -3
- returnn/torch/frontend/_backend.py +166 -22
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +3 -1
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +1 -0
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +1 -1
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +42 -36
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/torch/frontend/_backend.py
CHANGED

@@ -23,6 +23,8 @@ from returnn.frontend import _random_journal
 from returnn.frontend import _utils
 
 from . import raw_ops
+from ..util import native_op
+from ..util.assert_ import assert_
 
 _TT = Tensor[torch.Tensor]
 
@@ -44,6 +46,12 @@ class TorchBackend(Backend[torch.Tensor]):
         """
         return True
 
+    @staticmethod
+    def assert_(condition: Tensor, message: str):
+        """assert"""
+        assert condition.dims == (), "condition for assert must be a scalar"
+        assert_(condition.raw_tensor, message)
+
     @staticmethod
     def set_random_seed(seed: int):
         """
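A minimal usage sketch for the new scalar assert (hypothetical, not part of the diff): it assumes the `rf.assert_` wrapper added in returnn/frontend/assert_.py dispatches to this backend method, and that `rf.reduce_all` and the overloaded comparison operators behave as their names suggest.

import returnn.frontend as rf

def assert_non_negative(x: rf.Tensor):
    # The backend requires a scalar condition (no dims), so reduce over all dims first.
    cond = rf.reduce_all(x >= 0.0, axis=list(x.dims))
    rf.assert_(cond, "expected all entries to be non-negative")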
@@ -275,7 +283,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :return: tensor
         """
         assert len(dims) >= 2
-        first_axis = min(source.dims.index(d) for d in dims)
+        first_axis = min([source.dims.index(d) for d in dims])
         pre_dims = source.dims[:first_axis]
         post_dims = [d for d in source.dims if d not in dims and d not in pre_dims]
         source = source.copy_transpose(tuple(pre_dims) + tuple(dims) + tuple(post_dims), allow_int=False)
@@ -666,10 +674,10 @@ class TorchBackend(Backend[torch.Tensor]):
         targets_spatial_dim: Dim,
         blank_index: int,
         max_approx: bool = False,
+        use_native_op: Optional[bool] = None,
+        label_loop: bool = True,
     ) -> Tensor:
         """CTC"""
-        if max_approx:
-            raise NotImplementedError("ctc_loss: max_approx not implemented for PyTorch")
         assert targets.sparse_dim and targets.sparse_dim.dimension <= logits.feature_dim.dimension
         # PyTorch expects the logits to be of shape (T, B, C) where T is the input spatial dim.
         batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
@@ -707,18 +715,42 @@ class TorchBackend(Backend[torch.Tensor]):
         if len(batch_dims) != 1:
             targets_raw = torch.reshape(targets_raw, (batch_n_elems, targets_raw.shape[-1]))  # [B', S]
             targets_lengths = torch.reshape(targets_lengths, (batch_n_elems,))  # [B']
-        if …  [old lines 710-721: removed block, content truncated in the diff viewer]
+        if use_native_op is None:
+            if max_approx or not label_loop:
+                use_native_op = True
+            else:
+                # This was the current default.
+                # We might change the default in the future, maybe via new behavior version.
+                use_native_op = False
+        if use_native_op:
+            loss_raw = native_op.ctc_loss(
+                logits=log_probs,
+                logits_normalize=True,
+                logits_seq_lens=input_lengths,
+                logits_time_major=True,
+                targets=targets_raw,
+                targets_seq_lens=targets_lengths,
+                blank_index=blank_index,
+                max_approx=max_approx,
+                label_loop=label_loop,
+            )
+        else:  # not native_op
+            if max_approx:
+                raise NotImplementedError("ctc_loss: max_approx not implemented for PyTorch")
+            if not label_loop:
+                raise NotImplementedError("ctc_loss: label_loop=False not implemented for PyTorch")
+            if log_probs.dtype == torch.bfloat16:
+                # Currently (PyTorch 2.5), ctc_loss does not support bfloat16.
+                log_probs = log_probs.to(torch.float32)
+            loss_raw = torch.nn.functional.ctc_loss(
+                log_probs=log_probs,
+                targets=targets_raw,
+                input_lengths=input_lengths,
+                target_lengths=targets_lengths,
+                blank=blank_index,
+                zero_infinity=True,
+                reduction="none",
+            )
         if len(batch_dims) != 1:
             loss_raw = torch.reshape(loss_raw, logits_raw_shape[1:-1])
         loss = Tensor(
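A hedged call sketch (not from the diff): assuming the rf.ctc_loss wrapper in returnn/frontend/loss.py (+183 lines per the file list) forwards the new keyword arguments to this backend unchanged, the selection logic above would be exercised like this. All variable names here (logits, targets, time_dim, target_time_dim, blank_idx) are placeholders.

import returnn.frontend as rf

loss = rf.ctc_loss(
    logits=logits,                      # [B, T, C], unnormalized
    targets=targets,                    # [B, S], sparse indices
    input_spatial_dim=time_dim,
    targets_spatial_dim=target_time_dim,
    blank_index=blank_idx,
    max_approx=True,       # Viterbi-style max instead of the full sum
    use_native_op=None,    # auto: native op is required for max_approx / label_loop=False
)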
@@ -729,6 +761,103 @@ class TorchBackend(Backend[torch.Tensor]):
         )
         return loss
 
+    @staticmethod
+    def ctc_best_path(
+        *,
+        logits: Tensor,
+        logits_normalized: bool = False,
+        targets: Tensor,
+        input_spatial_dim: Dim,
+        targets_spatial_dim: Dim,
+        blank_index: int,
+        label_loop: bool = True,
+    ) -> Tensor:
+        """CTC best path"""
+        assert targets.sparse_dim and targets.sparse_dim.dimension <= logits.feature_dim.dimension
+        # PyTorch expects the logits to be of shape (T, B, C) where T is the input spatial dim.
+        batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
+        batch_dims_targets = targets.remaining_dims(targets_spatial_dim)
+        if set(batch_dims) != set(batch_dims_targets):
+            # Need to broadcast.
+            logits = rf.expand_dims(logits, [d for d in batch_dims_targets if d not in batch_dims])
+            targets = rf.expand_dims(targets, [d for d in batch_dims if d not in batch_dims_targets])
+            batch_dims = logits.remaining_dims((input_spatial_dim, logits.feature_dim))
+        batch_shape = [d.get_dim_value() for d in batch_dims]
+        batch_n_elems = prod(batch_shape)
+        logits = logits.copy_transpose([input_spatial_dim] + batch_dims + [logits.feature_dim])
+        logits_raw: torch.Tensor = logits.raw_tensor
+        input_lengths: torch.Tensor = input_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(batch_dims)
+        if input_lengths.numel() != batch_n_elems:
+            input_lengths = input_lengths.expand(batch_shape)
+        if len(batch_dims) != 1:
+            logits_raw = torch.reshape(
+                logits_raw, logits_raw.shape[:1] + (batch_n_elems,) + logits_raw.shape[-1:]
+            )  # [T, B', C]
+            input_lengths = torch.reshape(input_lengths, (batch_n_elems,))  # [B']
+        if logits_normalized:
+            log_probs = logits_raw
+        else:
+            log_probs = torch.nn.functional.log_softmax(logits_raw, dim=-1)
+        # PyTorch expects the targets to be of shape (B, S) where S is the targets spatial dim.
+        targets_raw = targets.copy_compatible_to_dims_raw(batch_dims + [targets_spatial_dim])  # [B..., S]
+        targets_raw_shape = batch_shape + [targets_spatial_dim.get_dim_value()]
+        if targets_raw.numel() != prod(targets_raw_shape):
+            targets_raw = targets_raw.expand(targets_raw_shape)
+        targets_lengths = targets_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(batch_dims)
+        if targets_lengths.numel() != batch_n_elems:
+            targets_lengths = targets_lengths.expand(batch_shape)
+        if len(batch_dims) != 1:
+            targets_raw = torch.reshape(targets_raw, (batch_n_elems, targets_raw.shape[-1]))  # [B', S]
+            targets_lengths = torch.reshape(targets_lengths, (batch_n_elems,))  # [B']
+        alignment_raw = native_op.ctc_best_path(
+            logits=log_probs,
+            logits_normalize=True,
+            logits_seq_lens=input_lengths,
+            logits_time_major=True,
+            targets=targets_raw,
+            targets_seq_lens=targets_lengths,
+            blank_index=blank_index,
+            label_loop=label_loop,
+        )  # (time,batch)
+        if len(batch_dims) != 1:
+            alignment_raw = torch.reshape(alignment_raw, log_probs.shape[:-1])
+        alignment = Tensor(
+            name="ctc_best_path",
+            dims=[input_spatial_dim] + batch_dims,
+            sparse_dim=logits.feature_dim,
+            raw_tensor=alignment_raw,
+            dtype=TorchBackend.get_dtype_name_raw(alignment_raw),
+        )
+        return alignment
+
+    @staticmethod
+    def have_edit_distance() -> bool:
+        """whether edit distance is available"""
+        return True
+
+    @staticmethod
+    def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim) -> Tensor:
+        """edit distance"""
+        a_batch_dims = a.remaining_dims(a_spatial_dim)
+        b_batch_dims = b.remaining_dims(b_spatial_dim)
+        assert set(a_batch_dims) == set(b_batch_dims), "edit_distance: batch dims must match"
+        a_raw = a.copy_compatible_to_dims_raw(a_batch_dims + [a_spatial_dim])
+        b_raw = b.copy_compatible_to_dims_raw(a_batch_dims + [b_spatial_dim])
+        a_seq_len = a_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(a_batch_dims)
+        b_seq_len = b_spatial_dim.dyn_size_ext.copy_compatible_to_dims_raw(a_batch_dims)
+        batch_shape = None
+        if len(a_batch_dims) != 1:
+            batch_shape = [d.get_dim_value() for d in a_batch_dims]
+            batch_n_elems = prod(batch_shape)
+            a_raw = torch.reshape(a_raw, (batch_n_elems, a_spatial_dim.get_dim_value()))
+            b_raw = torch.reshape(b_raw, (batch_n_elems, b_spatial_dim.get_dim_value()))
+            a_seq_len = torch.reshape(a_seq_len, (batch_n_elems,))
+            b_seq_len = torch.reshape(b_seq_len, (batch_n_elems,))
+        dist_raw = native_op.edit_distance(a_raw, a_seq_len, b_raw, b_seq_len)
+        if len(a_batch_dims) != 1:
+            dist_raw = torch.reshape(dist_raw, batch_shape)
+        return rf.convert_to_tensor(dist_raw, name="edit_distance", dims=a_batch_dims)
+
     @staticmethod
     def create_parameter_raw(tensor: rf.Parameter, *, device: Optional[str] = None) -> torch.nn.Parameter:
         """
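For orientation, a small worked example of the raw util call that edit_distance wraps (hypothetical values; the call signature is taken from the diff, while dtypes and the padding convention are assumptions):

import torch
from returnn.torch.util import native_op  # the op is JIT-compiled on first use

a = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)  # [B, Ta]
b = torch.tensor([[1, 3, 4, 0]], dtype=torch.int32)  # [B, Tb], padded
a_len = torch.tensor([4], dtype=torch.int32)
b_len = torch.tensor([3], dtype=torch.int32)
dist = native_op.edit_distance(a, a_len, b, b_len)
# [1, 2, 3, 4] -> [1, 3, 4] is one deletion, so dist == tensor([1])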
@@ -884,7 +1013,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :param perm: e.g. [0, 2, 1]
         :return: permuted (transposed) raw tensor; wraps torch.permute
         """
-        if all(p == i for i, p in enumerate(perm)):
+        if all([p == i for i, p in enumerate(perm)]):
             return raw_tensor
         return torch.permute(raw_tensor, tuple(perm))
 
@@ -1361,12 +1490,24 @@ class TorchBackend(Backend[torch.Tensor]):
         a_dims = a.dims
         b_dims = b.dims
 
-        …  [old lines 1364-1369: removed block, content truncated in the diff viewer]
+        if not all(dim in a_dims for dim in reduce) or not all(dim in b_dims for dim in reduce):
+            # revert to the generic einsum implementation
+            assert all(dim in a_dims + b_dims for dim in reduce), "Some reduce Dims not in a or b."
+            result_dims = [dim for dim in a_dims if dim not in reduce] + [
+                dim for dim in b_dims if dim not in reduce and dim not in a_dims
+            ]
+            map_to_letter = {}
+            for dim in a_dims + b_dims:
+                if dim not in map_to_letter:
+                    map_to_letter[dim] = chr(97 + len(map_to_letter))  # 'a', 'b', 'c', ...
+            a_subscript = "".join(map_to_letter[dim] for dim in a_dims)
+            b_subscript = "".join(map_to_letter[dim] for dim in b_dims)
+            out_subscript = "".join(map_to_letter[dim] for dim in result_dims)
+            raw_result = torch.einsum(f"{a_subscript},{b_subscript}->{out_subscript}", a.raw_tensor, b.raw_tensor)
+            result_tensor = Tensor(
+                "einsum", dims=result_dims, raw_tensor=raw_result, dtype=TorchBackend.get_dtype_name_raw(raw_result)
+            )
+            return result_tensor
 
         if len(reduce) > 1:
             reduce = list(reduce)
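The subscript construction above is the standard letter-per-dim einsum trick; a standalone plain-torch illustration (no RF types involved):

import torch

a = torch.randn(2, 3, 4)  # dims (b, t, f)
b = torch.randn(4, 5)     # dims (f, h)
# Letters assigned in order of first appearance: b->'a', t->'b', f->'c', h->'d'.
out = torch.einsum("abc,cd->abd", a, b)  # reduce over f
print(out.shape)  # torch.Size([2, 3, 5])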
@@ -1776,6 +1917,9 @@ class TorchBackend(Backend[torch.Tensor]):
         remaining_dims = [d for d in tensor.dims if d not in mask.dims]
         tensor_templ_dims = tuple(dims) + tuple(remaining_dims)
         in_raw = tensor.copy_compatible_to_dims_raw(tensor_templ_dims)
+        if any([in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)]):
+            # unbroadcast
+            in_raw = in_raw.expand([d.get_dim_value() for d in tensor_templ_dims])
         if mask.raw_tensor.device.type == "meta":
            # This is not supported, but also, we would anyway not know the out shape.
            # However, instead of erroring, just assume some dummy mask.
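The new expand call mirrors plain-torch behavior: boolean masking needs real (unbroadcast) sizes, otherwise the indexing shapes do not line up. A minimal illustration:

import torch

x = torch.randn(1, 5)          # size-1 broadcast dim
mask = torch.rand(3, 5) > 0.5  # the mask has the real size 3
x = x.expand(3, 5)             # "unbroadcast" so x[mask] lines up with the mask
selected = x[mask]             # without the expand, this would raise a shape error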
returnn/torch/frontend/bridge.py
CHANGED
@@ -136,6 +136,15 @@ class RFModuleAsPTModule(torch.nn.Module):
     def _get_name(self):
         return self._rf_module.__class__.__name__ + "[RF→PT]"
 
+    def __repr__(self) -> str:
+        """
+        Return a custom repr for Sequential/ModuleList that compresses repeated module representations if possible,
+        otherwise fall back to the default behavior.
+        """
+        if _can_use_compact_repr(self):
+            return _repr_compact(self)
+        return super().__repr__()
+
     @property
     def rf_module(self) -> rf.Module:
         """RF module"""
@@ -193,3 +202,55 @@ class RFModuleAsPTModule(torch.nn.Module):
             # See similar logic in torch.nn.Module._apply.
             pt_param = torch.nn.Parameter(tensor, tensor.requires_grad)
             rf_param.raw_tensor = pt_param
+
+
+def _can_use_compact_repr(self: RFModuleAsPTModule) -> bool:
+    return list(self._modules.keys()) == [str(i) for i in range(len(self._modules))]
+
+
+def _repr_compact(self: RFModuleAsPTModule) -> str:
+    """
+    Return a custom repr for Sequential/ModuleList that compresses repeated module representations.
+    Code copied and adapted from torch.nn.ModuleList.__repr__.
+    """
+    list_of_reprs = [repr(item) for item in self._modules.values()]
+    if len(list_of_reprs) == 0:
+        return self._get_name() + "()"
+
+    start_end_indices = [[0, 0]]
+    repeated_blocks = [list_of_reprs[0]]
+    for i, r in enumerate(list_of_reprs[1:], 1):
+        if r == repeated_blocks[-1]:
+            start_end_indices[-1][1] += 1
+            continue
+
+        start_end_indices.append([i, i])
+        repeated_blocks.append(r)
+
+    lines = []
+    main_str = self._get_name() + "("
+    for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+        local_repr = f"({start_id}): {b}"  # default repr
+
+        if start_id != end_id:
+            n = end_id - start_id + 1
+            local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+        local_repr = _add_indent(local_repr, 2)
+        lines.append(local_repr)
+
+    main_str += "\n  " + "\n  ".join(lines) + "\n"
+    main_str += ")"
+    return main_str
+
+
+def _add_indent(s_: str, num_spaces: int) -> str:
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(num_spaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
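Since the helpers are adapted from torch.nn.ModuleList.__repr__ (as the docstring says), the output format matches what recent PyTorch prints for repeated blocks, e.g.:

import torch

mods = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
print(mods)
# ModuleList(
#   (0-3): 4 x Linear(in_features=8, out_features=8, bias=True)
# )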
returnn/torch/frontend/compile_helper.py
ADDED

@@ -0,0 +1,106 @@
+"""
+Helpers to improve torch.compile on RF code.
+"""
+
+from __future__ import annotations
+from typing import Any, Iterable, List, Tuple
+
+import os
+from returnn.tensor import Tensor, Dim
+
+# noinspection PyProtectedMember
+from returnn.frontend import _native
+
+_is_set_up = False
+
+
+def setup():
+    """
+    Set up the torch.compile helpers for RF code, also including :class:`Tensor` and :class:`Dim`.
+    """
+    global _is_set_up
+    if _is_set_up:
+        return
+    _is_set_up = True  # only try once
+
+    assert not _native.is_set_up(), "Call this setup() as early as possible."
+    _native.set_enabled(False)
+
+    # We have lots of dynamic shapes.
+    os.environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
+
+    # noinspection PyProtectedMember
+    from torch.utils._pytree import register_pytree_node
+
+    register_pytree_node(Tensor, _tensor_flatten, _tensor_unflatten)
+    register_pytree_node(Dim, _dim_flatten, _dim_unflatten)
+
+    Dim.get_dim_value = _dim_get_dim_value
+
+
+def _tensor_flatten(t: Tensor) -> Tuple[List[Any], Any]:
+    """
+    Flatten the tensor for PyTree.
+    """
+    return [t.raw_tensor, t.dims, t.sparse_dim], [
+        t.name,
+        t.dtype,
+        t.version,
+        t.feature_dim_axis_or_unspecified,
+        t.time_dim_axis_or_unspecified,
+    ]
+
+
+def _tensor_unflatten(values: Iterable[Any], metadata: Any) -> Tensor:
+    """
+    Unflatten the tensor from PyTree.
+    """
+    raw_tensor, dims, sparse_dim = values
+    name, dtype, version, feature_dim_axis, time_dim_axis = metadata
+    return Tensor(
+        name=name,
+        dims=dims,
+        dtype=dtype,
+        sparse_dim=sparse_dim,
+        feature_dim_axis=feature_dim_axis,
+        time_dim_axis=time_dim_axis,
+        raw_tensor=raw_tensor,
+        version=version,
+    )
+
+
+def _dim_flatten(d: Dim) -> Tuple[List[Any], Any]:
+    """
+    Flatten the dim for PyTree.
+    """
+    return [d.dyn_size_ext], [d.name, d.dimension, d.size]
+
+
+def _dim_unflatten(values: Iterable[Any], metadata: Any) -> Dim:
+    """
+    Unflatten the dim from PyTree.
+    """
+    (dyn_size_ext,) = values
+    name, dimension, size = metadata
+    # TODO this creates a new instance... this is maybe wrong?
+    return Dim(name=name, dimension=dimension, size=size, dyn_size_ext=dyn_size_ext)
+
+
+def _dim_get_dim_value(self: Dim) -> int:
+    """
+    Infers the dim this axis should have if unbroadcasted.
+    If `self.src_data` has a placeholder, will use the shape from there.
+    Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
+    :return: max(size or dyn_size)
+    """
+    res = self.get_dim_value_tensor()
+    if isinstance(res, Tensor):
+        assert res.dims == ()
+        assert res.raw_tensor is not None
+        # Specifically PyTorch would then treat it as a SymInt in torch.compile,
+        # which is important to have for some torch functions (e.g. torch.tile and others).
+        return int(res.raw_tensor)
+    assert isinstance(res, int)
+    return res
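A hypothetical usage sketch (not part of the diff): per the assert above, setup() must run before returnn.frontend._native initializes; the torch.compile settings and the wrapped function are assumptions for illustration.

from returnn.torch.frontend import compile_helper

compile_helper.setup()  # before any RF code runs, so _native is still disabled

import torch
import returnn.frontend as rf

@torch.compile(dynamic=True)  # dynamic shapes are the common case in RF code
def forward_step(model: rf.Module, x: rf.Tensor) -> rf.Tensor:
    return model(x)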
returnn/torch/util/array_.py
CHANGED
@@ -60,3 +60,33 @@ def nonzero(mask: torch.Tensor, *, out_len: Union[int, torch.Tensor]) -> torch.Tensor:
     idx = torch.argsort(mask.to(torch.int8), stable=True, descending=True)  # [in_len]
     idx = idx[:out_len]  # [out_len]
     return idx
+
+
+def sequence_mask(lengths: torch.Tensor, *, maxlen: Optional[int] = None) -> torch.Tensor:
+    """
+    Creates a boolean mask from sequence lengths.
+
+    :param lengths: Tensor of shape [batch_size...] containing sequence lengths
+    :param maxlen: Maximum length of the sequences. If None, uses the maximum value in lengths.
+    :return: A boolean mask tensor of shape [batch_size..., maxlen]
+    """
+    if maxlen is None:
+        maxlen = lengths.max()
+    indices = torch.arange(0, maxlen, dtype=lengths.dtype, device=lengths.device)
+    mask = indices < lengths[..., None]
+    return mask
+
+
+def sequence_mask_time_major(lengths: torch.Tensor, *, maxlen: Optional[int] = None) -> torch.Tensor:
+    """
+    Creates a boolean mask from sequence lengths.
+
+    :param lengths: Tensor of shape [batch_size...] containing sequence lengths
+    :param maxlen: Maximum length of the sequences. If None, uses the maximum value in lengths.
+    :return: A boolean mask tensor of shape [maxlen, batch_size...]
+    """
+    if maxlen is None:
+        maxlen = lengths.max()
+    indices = torch.arange(0, maxlen, dtype=lengths.dtype, device=lengths.device)
+    mask = indices[(slice(None),) + (None,) * lengths.ndim] < lengths[None]
+    return mask
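A quick worked example of the two masks (batch-major vs. time-major):

import torch
from returnn.torch.util.array_ import sequence_mask, sequence_mask_time_major

lengths = torch.tensor([3, 1, 2])
print(sequence_mask(lengths))
# tensor([[ True,  True,  True],
#         [ True, False, False],
#         [ True,  True, False]])
print(sequence_mask_time_major(lengths).shape)  # torch.Size([3, 3]), time axis first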
returnn/torch/util/assert_.py
ADDED

@@ -0,0 +1,122 @@
+"""
+Async device assertion utility.
+"""
+
+from __future__ import annotations
+
+import threading
+from textwrap import dedent
+from queue import Queue
+import torch
+
+
+def assert_(cond: torch.Tensor, message: str):
+    """
+    Does a device-side assertion.
+    For CPU, this will directly check the condition and raise an error if false.
+    For CUDA devices, this runs asynchronously on a separate thread (to avoid pin_memory in the current thread),
+    and non-blocking (does not trigger a CUDA sync).
+    """
+    if cond.device.type == "cpu":
+        if not cond.item():
+            raise AssertionError(message)
+        return
+    elif cond.device.type == "cuda":
+        # This triggers the lazy initialization on first call
+        _CudaAsyncWorker().push(cond, message)
+    else:
+        raise NotImplementedError(f"assert_ not implemented for device type: {cond.device.type}")
+
+
+def _get_ext():
+    global _ext
+    if _ext:
+        return _ext
+
+    from .native_op_code_compiler import OpCodeCompiler
+
+    compiler = OpCodeCompiler(
+        "async_assert_ext", use_cuda_if_available=True, code=_cpp_source + _cuda_source, is_python_module=True
+    )
+    _ext = compiler.load_module()
+    return _ext
+
+
+_ext = None
+
+_cpp_source = dedent("""\
+    #include <torch/extension.h>
+
+    void async_assert_cuda(const at::Tensor& cond, const at::Tensor& msg_tensor);
+
+    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+        m.def("async_assert_cuda", torch::wrap_pybind_function(async_assert_cuda), "Asynchronous CUDA assert");
+    }
+    """)
+
+_cuda_source = dedent("""\
+    #include <torch/types.h>
+    #include <cuda.h>
+    #include <cuda_runtime.h>
+    #include <torch/extension.h>
+    #include <ATen/cuda/CUDAContext.h>
+    #include <c10/cuda/CUDACachingAllocator.h>
+    #include <assert.h>
+
+    __global__ void assert_kernel(const bool* cond, const char* msg) {
+        if (blockIdx.x == 0 && threadIdx.x == 0) {
+            if (!(*cond)) {
+                printf("\\n[GPU ASSERT FAILED]: %s\\n", msg);
+                assert(false);
+            }
+        }
+    }
+
+    void async_assert_cuda(const at::Tensor& cond, const at::Tensor& msg_tensor) {
+        auto stream = at::cuda::getCurrentCUDAStream();
+
+        // Safety: Protect memory from GC while the kernel is in flight
+        c10::cuda::CUDACachingAllocator::recordStream(cond.storage().data_ptr(), stream);
+        c10::cuda::CUDACachingAllocator::recordStream(msg_tensor.storage().data_ptr(), stream);
+
+        assert_kernel<<<1, 1, 0, stream>>>(
+            cond.data_ptr<bool>(),
+            (const char*)msg_tensor.data_ptr<uint8_t>()
+        );
+    }
+    """)
+
+
+class _CudaAsyncWorker:
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super(_CudaAsyncWorker, cls).__new__(cls)
+                cls._instance._init_worker()
+            return cls._instance
+
+    def _init_worker(self):
+        self.queue = Queue()
+        self.thread = threading.Thread(target=self._loop, daemon=True)
+        self.thread.start()
+
+    def _loop(self):
+        while True:
+            cond, message_str, stream = self.queue.get()
+
+            # Use the actual Stream object context
+            with torch.cuda.stream(stream):
+                # Convert string to pinned tensor (Avoiding read-only NP view)
+                msg_bytes = list(message_str.encode("utf-8"))
+                msg_cpu = torch.tensor(msg_bytes, dtype=torch.uint8, pin_memory=True)
+                msg_gpu = msg_cpu.to("cuda", non_blocking=True)
+
+                # Call JIT-compiled function
+                _get_ext().async_assert_cuda(cond, msg_gpu)
+
+    def push(self, cond: torch.Tensor, message: str):
+        """push to queue"""
+        self.queue.put((cond, message, torch.cuda.current_stream()))
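A hypothetical usage sketch: the CUDA path JIT-compiles the extension on first call (so it needs a working nvcc toolchain) and never synchronizes the host; a failing condition aborts via a device-side assert.

import torch
from returnn.torch.util.assert_ import assert_

x = torch.randn(16, device="cuda")
# (x == x) is False exactly at NaN entries, so .all() gives a scalar bool tensor on the device.
assert_((x == x).all(), "NaN detected in x")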