returnn 1.20250508.93313__py3-none-any.whl → 1.20250513.145447__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of returnn has been flagged as potentially problematic.

Files changed (67)
  1. returnn/PKG-INFO +1 -1
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/datasets/basic.py +24 -25
  4. returnn/datasets/cached.py +4 -3
  5. returnn/datasets/distrib_files.py +1 -2
  6. returnn/datasets/generating.py +20 -20
  7. returnn/datasets/hdf.py +9 -9
  8. returnn/datasets/lm.py +25 -13
  9. returnn/datasets/meta.py +39 -38
  10. returnn/datasets/normalization_data.py +1 -1
  11. returnn/datasets/postprocessing.py +20 -13
  12. returnn/datasets/sprint.py +8 -7
  13. returnn/datasets/util/strings.py +0 -1
  14. returnn/datasets/util/vocabulary.py +3 -3
  15. returnn/extern/graph_editor/subgraph.py +1 -2
  16. returnn/extern/graph_editor/transform.py +1 -2
  17. returnn/extern/graph_editor/util.py +1 -2
  18. returnn/frontend/_backend.py +4 -3
  19. returnn/frontend/_utils.py +1 -1
  20. returnn/frontend/audio/mel.py +0 -1
  21. returnn/frontend/const.py +3 -3
  22. returnn/frontend/device.py +0 -1
  23. returnn/frontend/dropout.py +1 -1
  24. returnn/frontend/encoder/e_branchformer.py +1 -1
  25. returnn/frontend/loop.py +3 -3
  26. returnn/frontend/loss.py +0 -1
  27. returnn/frontend/matmul.py +0 -1
  28. returnn/frontend/run_ctx.py +9 -9
  29. returnn/frontend/signal.py +0 -1
  30. returnn/frontend/types.py +2 -4
  31. returnn/native_op.py +13 -0
  32. returnn/sprint/cache.py +2 -4
  33. returnn/sprint/interface.py +3 -4
  34. returnn/tensor/_dim_extra.py +9 -9
  35. returnn/tensor/_tensor_extra.py +20 -19
  36. returnn/tensor/_tensor_op_overloads.py +0 -1
  37. returnn/tensor/tensor.py +1 -1
  38. returnn/tensor/tensor_dict.py +9 -9
  39. returnn/tf/engine.py +60 -65
  40. returnn/tf/frontend_layers/_backend.py +3 -3
  41. returnn/tf/frontend_layers/cond.py +6 -6
  42. returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
  43. returnn/tf/frontend_layers/layer.py +12 -12
  44. returnn/tf/frontend_layers/loop.py +3 -3
  45. returnn/tf/frontend_layers/make_layer.py +0 -1
  46. returnn/tf/layers/base.py +56 -49
  47. returnn/tf/layers/basic.py +60 -65
  48. returnn/tf/layers/rec.py +74 -74
  49. returnn/tf/native_op.py +1 -3
  50. returnn/tf/network.py +60 -57
  51. returnn/tf/updater.py +3 -3
  52. returnn/tf/util/basic.py +24 -23
  53. returnn/torch/data/extern_data.py +4 -5
  54. returnn/torch/data/pipeline.py +3 -4
  55. returnn/torch/engine.py +16 -16
  56. returnn/torch/frontend/_backend.py +15 -15
  57. returnn/torch/frontend/bridge.py +3 -3
  58. returnn/torch/updater.py +8 -9
  59. returnn/torch/util/debug_inf_nan.py +0 -2
  60. returnn/torch/util/exception_helper.py +1 -1
  61. returnn/torch/util/scaled_gradient.py +0 -1
  62. returnn/util/basic.py +1 -2
  63. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/METADATA +1 -1
  64. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/RECORD +67 -67
  65. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/LICENSE +0 -0
  66. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/WHEEL +0 -0
  67. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/top_level.txt +0 -0
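
Most of the hunks below are formatting-only changes: assert statements are rewrapped so the condition stays on one line and only the message is parenthesized, single-argument signatures gain a trailing comma, adjacent string literals are merged, and blank lines around docstrings are adjusted, a pattern consistent with a code-formatter style update. A minimal sketch of the recurring assert rewrapping (illustrative only, not taken from the RETURNN sources):

```python
dims = (2, 3)

# Old wrapping: the condition is split across lines so the statement fits the line limit.
assert (
    len(dims) == 2
), f"expected 2 dims, got {dims}"

# New wrapping: the condition stays on one line and only the message is parenthesized.
assert len(dims) == 2, (
    f"expected 2 dims, got {dims}"
)
```

Both forms are equivalent at runtime; only the line wrapping differs.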
@@ -263,7 +263,7 @@ def _slice_find_sparse_dim(v: Union[Tensor, slice, Any]) -> Optional[Dim]:
 
 
 def _map_slice_value_raw(
-    v: Union[None, slice, int, numpy.number, numpy.ndarray, Tensor[T]]
+    v: Union[None, slice, int, numpy.number, numpy.ndarray, Tensor[T]],
 ) -> Union[None, slice, int, numpy.number, T]:
     if v is None:
         return None
returnn/frontend/audio/mel.py CHANGED
@@ -2,7 +2,6 @@
 Mel filterbank.
 """
 
-
 from __future__ import annotations
 from typing import Optional, Union, Tuple
 import functools
returnn/frontend/const.py CHANGED
@@ -54,9 +54,9 @@ def full(
             "Use rf.convert_to_tensor to convert an arbitrary array to a tensor."
         )
     if isinstance(fill_value, Tensor):
-        assert (
-            fill_value.dims == ()
-        ), f"full/fill/constant: expect scalar fill_value, got tensor with shape {fill_value.dims}."
+        assert fill_value.dims == (), (
+            f"full/fill/constant: expect scalar fill_value, got tensor with shape {fill_value.dims}."
+        )
     return global_backend.full(
         dims, fill_value, dtype=dtype, device=device, sparse_dim=sparse_dim, feature_dim=feature_dim
     )
returnn/frontend/device.py CHANGED
@@ -2,7 +2,6 @@
 Device handling.
 """
 
-
 from __future__ import annotations
 from typing import Optional
 from contextlib import contextmanager
returnn/frontend/dropout.py CHANGED
@@ -50,7 +50,7 @@ def dropout(
             raise ValueError(f"dropout axis {axis} not in source {source}")
 
     if isinstance(keep_prob, (float, int)) and not 0 < keep_prob <= 1:
-        raise ValueError("keep_prob must be a scalar tensor or a float in the " "range (0, 1], got %g" % keep_prob)
+        raise ValueError("keep_prob must be a scalar tensor or a float in the range (0, 1], got %g" % keep_prob)
 
     # Do nothing if we know keep_prob == 1
     if isinstance(keep_prob, (float, int)) and keep_prob == 1:
returnn/frontend/encoder/e_branchformer.py CHANGED
@@ -268,7 +268,7 @@ class Merge(rf.Module):
 
 
 def _make_activation(
-    activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module]
+    activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module],
 ) -> Union[Callable[[Tensor], Tensor], rf.Module]:
     if isinstance(activation, dict):
         activation = rf.build_from_dict(activation)
returnn/frontend/loop.py CHANGED
@@ -273,9 +273,9 @@ def _check_matching_loop_var_templates(loop_var_templates: S, loop_vars: S):
             x._push_back_delayed_check()
 
         else:  # other cases: just check same type
-            assert type(template) is type(
-                x
-            ), f"loop var {path} template type {type(template)} does not match var type {type(x)}"
+            assert type(template) is type(x), (
+                f"loop var {path} template type {type(template)} does not match var type {type(x)}"
+            )
             assert not isinstance(x, Tensor), f"loop var {path} is a Tensor but should not be"
 
     tree.map_structure_with_path(_check, loop_var_templates, loop_vars)
returnn/frontend/loss.py CHANGED
@@ -137,7 +137,6 @@ def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim,
     # We are going diagonal over (Ta+1) and (Tb+1). (Similar as RETURNN native EditDistanceOp.)
     # You need to draw the grid on paper to understand all the index math...
     for u in range(1, n_a_max_len + n_b_max_len + 1):
-
         prev2_dist, _ = rf.slice(
             buffer, axis=buffer_dim, start=buffer_offsets[u % 3], size=b_spatial_dim1, out_dim=b_spatial_dim1
         )  # [Tb+1,B]
returnn/frontend/matmul.py CHANGED
@@ -2,7 +2,6 @@
 Dot / matmul
 """
 
-
 from __future__ import annotations
 from typing import Sequence, Union, TypeVar
 from returnn.tensor import Tensor, Dim
returnn/frontend/run_ctx.py CHANGED
@@ -306,19 +306,19 @@ class RunCtx:
         assert self.stage == "forward_step"
 
         if self.expected_outputs is not None:
-            assert (
-                name in self.expected_outputs.data
-            ), f"mark_as_output: unexpected output {name!r}, we expect outputs: {self.expected_outputs}"
+            assert name in self.expected_outputs.data, (
+                f"mark_as_output: unexpected output {name!r}, we expect outputs: {self.expected_outputs}"
+            )
         expected_output = self.expected_outputs.data[name] if self.expected_outputs else None
-        assert dims is None or (
-            isinstance(dims, (list, tuple)) and all(isinstance(dim, Dim) for dim in dims)
-        ), f"dims should be a tuple of Dims, got {dims}"
+        assert dims is None or (isinstance(dims, (list, tuple)) and all(isinstance(dim, Dim) for dim in dims)), (
+            f"dims should be a tuple of Dims, got {dims}"
+        )
         if dims is None and expected_output is not None:
             dims = expected_output.dims
         if dims is not None and expected_output is not None:
-            assert expected_output.dims == tuple(
-                dims
-            ), f"mark_as_output: {name!r} dims mismatch from expected output, given {dims}, expected {expected_output}"
+            assert expected_output.dims == tuple(dims), (
+                f"mark_as_output: {name!r} dims mismatch from expected output, given {dims}, expected {expected_output}"
+            )
 
         if not isinstance(tensor, Tensor):
             assert isinstance(tensor, _backend.global_backend.RawTensorType)
returnn/frontend/signal.py CHANGED
@@ -2,7 +2,6 @@
 stft etc
 """
 
-
 from __future__ import annotations
 from typing import Optional, Tuple
 from returnn.tensor import Tensor, Dim
returnn/frontend/types.py CHANGED
@@ -19,15 +19,13 @@ ItemKeyType = Union[RawTensorTypes, Tensor, slice, Sequence[Union[RawTensorTypes
 class GetModelFunc(Protocol):
     """get model func"""
 
-    def __call__(self, *, epoch: int, step: int) -> rf.Module:
-        ...
+    def __call__(self, *, epoch: int, step: int) -> rf.Module: ...
 
 
 class StepFunc(Protocol):
     """step func"""
 
-    def __call__(self, *, model: rf.Module, extern_data: TensorDict) -> None:
-        ...
+    def __call__(self, *, model: rf.Module, extern_data: TensorDict) -> None: ...
 
 
 def get_raw_tensor_type() -> Type:
returnn/native_op.py CHANGED
@@ -291,6 +291,7 @@ class LstmGenericBase(NativeOpGenBase):
     :param H: gates and cell state. 3d (time,batch,dim*4)
     :param d: final cell state. 2d (batch,dim)
     """
+
     in_info = (
         {
             "name": "Z",
@@ -542,6 +543,7 @@ class LstmLowMem(NativeOpGenBase):
     :param C: cell states. 3d (time,batch,dim). gradient ignored!
     :param d: final cell state. 2d (batch,dim)
     """
+
     in_info = (
         {"name": "X", "ndim": 3, "shape": (None, None, None), "need_contiguous": True},
         {"name": "W", "ndim": 2, "shape": (None, None), "need_contiguous": True},
@@ -994,6 +996,7 @@ class NativeLstm2(NativeOpGenBase):
     :param H: cell-in + gates. 3d (time,batch,dim*4). gradient ignored!
     :param d: final cell state. 2d (batch,dim)
     """
+
     in_info = (
         {"name": "X", "ndim": 3, "shape": (None, None, None), "need_contiguous": True},
         {"name": "W", "ndim": 2, "shape": (None, None), "need_contiguous": True},
@@ -1423,6 +1426,7 @@ class TwoDLSTM(NativeOpGenBase):
     :param H: gates and cell state. 4d (timeS,timeT,batch,dim*5) ?
     :param d: final cell state. 3d (timeT,batch,dim)
     """
+
     in_info = (
         {
             "name": "X",
@@ -3198,6 +3202,7 @@ class FastBaumWelchOp(NativeOpGenBase):
     outputs:
     :param output: Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores
     """
+
     in_info = (
         {
             "name": "am_scores",
@@ -3620,6 +3625,7 @@ class MultiEndFastBaumWelchOp(NativeOpGenBase):
     outputs:
     :param output: Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores
     """
+
     in_info = (
         {
             "name": "am_scores",
@@ -4497,6 +4503,7 @@ class FastViterbiOp(NativeOpGenBase):
     :param output: Viterbi (hard) alignment, scores in +log space. 2d (time,batch)
     :param scores: (batch,)
     """
+
     in_info = (
         {
             "name": "am_scores",
@@ -4865,6 +4872,7 @@ class GetCtcFsaFastBwOp(NativeOpGenBase):
     `num_edges` should be `n_batch * (5 * (n_time - 1) + 10)`
     (see construction in kernel why that number).
     """
+
     in_info = (
         {
             "name": "targets",
@@ -5229,6 +5237,7 @@ class EditDistanceOp(NativeOpGenBase):
     outputs:
     :param output: 1d (batch,), int32, unnormalized edit distance
     """
+
     in_info = (
         {
             "name": "a",
@@ -5414,6 +5423,7 @@ class OptimalCompletionEditDistanceOp(NativeOpGenBase):
     outputs:
     :param output: 1d (batch,), int32, unnormalized edit distance
     """
+
     in_info = (
         {
             "name": "a",
@@ -5610,6 +5620,7 @@ class OptimalCompletionEditDistancePerSuccessorOp(NativeOpGenBase):
     outputs:
     :param output: 2d (batch,num_labels), int32, unnormalized edit distance
     """
+
     in_info = (
         {
             "name": "a",
@@ -5880,6 +5891,7 @@ class NextEditDistanceRowOp(NativeOpGenBase):
     outputs:
     :param output: 2d (batch,b_time + 1), int32, next (unnormalized) edit distance row
     """
+
     in_info = (
         {
             "name": "last_row",
@@ -6039,6 +6051,7 @@ class NextEditDistanceReduceOp(NativeOpGenBase):
     outputs:
     :param output: 2d (batch,n_labels), int32, next (unnormalized) (maybe optional) edit distance
     """
+
     in_info = (
         {
             "name": "last_row",
returnn/sprint/cache.py CHANGED
@@ -7,7 +7,7 @@ This module is about reading (maybe later also writing) the Sprint archive forma
 """
 
 from __future__ import annotations
-from typing import List
+from typing import List, Optional, Tuple
 import sys
 import os
 import typing
@@ -904,9 +904,7 @@ class MixtureSet:
             self.densities[n, 1] = cov_idx
 
         self.num_mixtures = self.read_u32()
-        self.mixtures = [
-            None
-        ] * self.num_mixtures  # type: typing.List[typing.Optional[typing.Tuple[typing.List[int],typing.List[float]]]]  # nopep8
+        self.mixtures: List[Optional[Tuple[List[int], List[float]]]] = [None] * self.num_mixtures
        for n in range(self.num_mixtures):
             num_densities = self.read_u32()
             dns_idx = []
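
Besides the import change, the MixtureSet hunk above replaces a `# type:` comment with an inline variable annotation. A minimal sketch of the equivalence, using a hypothetical `num_mixtures` value (not from the source):

```python
from typing import List, Optional, Tuple

num_mixtures = 3  # hypothetical value, for illustration only

# Before: comment-style annotation (PEP 484 type comment)
mixtures_old = [None] * num_mixtures  # type: List[Optional[Tuple[List[int], List[float]]]]

# After: inline variable annotation (PEP 526); same meaning for type checkers
mixtures_new: List[Optional[Tuple[List[int], List[float]]]] = [None] * num_mixtures
```

The two spellings declare the same type; the inline form avoids the over-long comment that previously needed the `# nopep8` marker.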
returnn/sprint/interface.py CHANGED
@@ -820,9 +820,9 @@ def _prepare_forwarding():
     assert engine
     assert config
     # Should already be set via setTargetMode().
-    assert config.list("extract") == [
-        "posteriors"
-    ], "You need to have extract = posteriors in your RETURNN config. You have: %s" % config.list("extract")
+    assert config.list("extract") == ["posteriors"], (
+        "You need to have extract = posteriors in your RETURNN config. You have: %s" % config.list("extract")
+    )
 
     # Load network.
     engine.init_network_from_config(config)
@@ -870,7 +870,6 @@ def _train(segment_name, features, targets=None):
     # The CRNN train thread started via start() will do the actual training.
 
     if TargetMode == "criterion-by-sprint":
-
         # TODO...
         make_criterion_class()
 
returnn/tensor/_dim_extra.py CHANGED
@@ -1067,13 +1067,14 @@ class _DimMixin:
                 )
             )
         if batch and getattr(x, "_RETURNN_dyn_size_beam", None):
-            assert batch.beam == getattr(
-                x, "_RETURNN_dyn_size_beam"
-            ), "%s: dyn size %s has unexpected batch %s, expected %s" % (
-                self,
-                x,
-                batch,
-                getattr(x, "_RETURNN_dyn_size_beam"),
+            assert batch.beam == getattr(x, "_RETURNN_dyn_size_beam"), (
+                "%s: dyn size %s has unexpected batch %s, expected %s"
+                % (
+                    self,
+                    x,
+                    batch,
+                    getattr(x, "_RETURNN_dyn_size_beam"),
+                )
             )
         if self.batch and batch:
             assert self.batch == batch
@@ -1359,8 +1360,7 @@ class _DimMixin:
             # Only auto-generated dim tags are allowed to be treated as broadcastable.
             # This was another suggestion from here: https://github.com/rwth-i6/returnn/issues/666
             # It was not implemented like this because the auto_generated flag was only introduced later.
-            (self.dimension == 1 and self.auto_generated)
-            or (other.dimension == 1 and other.auto_generated)
+            (self.dimension == 1 and self.auto_generated) or (other.dimension == 1 and other.auto_generated)
         ):
             pass  # pass on
         else:
returnn/tensor/_tensor_extra.py CHANGED
@@ -335,9 +335,9 @@ class _TensorMixin(_TensorMixinBase):
             if tag.dyn_size_ext.placeholder is None:
                 tag.complete_dyn_size()
             if self.placeholder is not None:
-                assert (
-                    tag.dyn_size_ext.placeholder is not None
-                ), "%s sanity_check: dynamic dim %s value unknown" % (self, tag)
+                assert tag.dyn_size_ext.placeholder is not None, (
+                    "%s sanity_check: dynamic dim %s value unknown" % (self, tag)
+                )
             assert tag.is_dim_known()
 
     def get_runtime_sanity_check_op(self: Tensor):
@@ -2494,8 +2494,7 @@ class _TensorMixin(_TensorMixinBase):
                 if res_tag.match_priority > tag.match_priority:
                     continue
                 raise Exception(
-                    f"{self}: get_axis_from_description({axis}) not unique."
-                    f" use match_priority to resolve ambiguity"
+                    f"{self}: get_axis_from_description({axis}) not unique. use match_priority to resolve ambiguity"
                 )
             if res_idx is None:
                 raise Exception(f"{self}: get_axis_from_description({axis}) not found")
@@ -2646,12 +2645,13 @@ class _TensorMixin(_TensorMixinBase):
             return self.batch_shape[self.time_dim_axis_excluding_batch] is None
         if self.time_dim_axis_excluding_batch in self.size_placeholder:
             return True
-        assert isinstance(
-            self.shape[self.time_dim_axis_excluding_batch], int
-        ), "%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information" % (
-            self,
-            self.time_dim_axis,
-            self.size_placeholder,
+        assert isinstance(self.shape[self.time_dim_axis_excluding_batch], int), (
+            "%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information"
+            % (
+                self,
+                self.time_dim_axis,
+                self.size_placeholder,
+            )
         )
         return False
 
@@ -3307,14 +3307,15 @@ class _TensorMixin(_TensorMixinBase):
                 if self_axis not in taken_self_axes
             ]
             if opt == "unknown_spatial_matches":
-                assert (
-                    len(matching) <= 1
-                ), "cannot match axes %s from %s to %s, failed at other %s, not unique after %s" % (
-                    other_axes,
-                    other,
-                    self,
-                    other_axis,
-                    opt,
+                assert len(matching) <= 1, (
+                    "cannot match axes %s from %s to %s, failed at other %s, not unique after %s"
+                    % (
+                        other_axes,
+                        other,
+                        self,
+                        other_axis,
+                        opt,
+                    )
                 )
             if matching:
                 break
returnn/tensor/_tensor_op_overloads.py CHANGED
@@ -13,7 +13,6 @@ from ._tensor_mixin_base import _TensorMixinBase
 
 
 class _TensorOpOverloadsMixin(_TensorMixinBase):
-
     # Note that all those ops have native implementations as well,
     # so keep the logic in sync.
 
returnn/tensor/tensor.py CHANGED
@@ -187,7 +187,7 @@ class Tensor(_TensorMixin, _TensorOpOverloadsMixin, Generic[RawTensorType]):
         if not backend.executing_eagerly():
             backend.set_known_shape_raw(value, self.batch_shape)
         assert backend.get_dtype_name_raw(value) == self.dtype, (
-            f"{self} dtype {self.dtype} does not match " f"raw tensor dtype {backend.get_dtype_name_raw(value)}"
+            f"{self} dtype {self.dtype} does not match raw tensor dtype {backend.get_dtype_name_raw(value)}"
         )
         self._raw_tensor = value
 
returnn/tensor/tensor_dict.py CHANGED
@@ -91,9 +91,9 @@ class TensorDict:
         out = {}
         for key, value in self.data.items():
             assert key not in out
-            assert isinstance(
-                value.raw_tensor, expected_value_type
-            ), f"key {key} {value}: unexpected {type(value.raw_tensor)}, expected {expected_value_type}"
+            assert isinstance(value.raw_tensor, expected_value_type), (
+                f"key {key} {value}: unexpected {type(value.raw_tensor)}, expected {expected_value_type}"
+            )
             out[key] = value.raw_tensor
             for i, dim in enumerate(value.dims):
                 if exclude_duplicate_dims and dim in visited_dims:
@@ -103,9 +103,9 @@ class TensorDict:
                 if dim.is_batch_dim() and (dim.dyn_size_ext is None or dim.dyn_size_ext.raw_tensor is None):
                     if include_scalar_dyn_sizes:
                         dim_value = dim.get_dim_value()
-                        assert isinstance(
-                            dim_value, expected_value_type
-                        ), f"key {key_} {dim}: unexpected {type(dim_value)}, expected {expected_value_type}"
+                        assert isinstance(dim_value, expected_value_type), (
+                            f"key {key_} {dim}: unexpected {type(dim_value)}, expected {expected_value_type}"
+                        )
                         out[key_] = dim_value
                 elif dim.dyn_size_ext is not None:
                     if include_scalar_dyn_sizes or dim.dyn_size_ext.dims:
@@ -116,9 +116,9 @@ class TensorDict:
                         out[key_] = dim.dyn_size_ext.raw_tensor
                 elif dim.size is not None:
                     if include_scalar_dyn_sizes and include_const_sizes:
-                        assert isinstance(
-                            dim.size, expected_value_type
-                        ), f"key {key_} {dim}: unexpected {type(dim.size)}, expected {expected_value_type}"
+                        assert isinstance(dim.size, expected_value_type), (
+                            f"key {key_} {dim}: unexpected {type(dim.size)}, expected {expected_value_type}"
+                        )
                         out[key_] = dim.size
                 else:
                     raise Exception(f"cannot handle dim: {dim}")