returnn 1.20250206.144022__py3-none-any.whl → 1.20250207.143045__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of returnn has been flagged as possibly problematic. See the package's page on its registry for details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/torch/engine.py +53 -17
- returnn/torch/frontend/_backend.py +1 -1
- {returnn-1.20250206.144022.dist-info → returnn-1.20250207.143045.dist-info}/METADATA +1 -1
- {returnn-1.20250206.144022.dist-info → returnn-1.20250207.143045.dist-info}/RECORD +9 -9
- {returnn-1.20250206.144022.dist-info → returnn-1.20250207.143045.dist-info}/LICENSE +0 -0
- {returnn-1.20250206.144022.dist-info → returnn-1.20250207.143045.dist-info}/WHEEL +0 -0
- {returnn-1.20250206.144022.dist-info → returnn-1.20250207.143045.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250207.143045'
|
|
2
|
+
long_version = '1.20250207.143045+git.b994e87'
|
returnn/torch/engine.py
CHANGED
|
@@ -3,7 +3,7 @@ Main engine for PyTorch
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
|
-
from typing import Optional, Any, Union, Callable, Dict, Set
|
|
6
|
+
from typing import Optional, Any, Union, Callable, Dict, Set, Tuple
|
|
7
7
|
from contextlib import nullcontext, ExitStack, contextmanager
|
|
8
8
|
|
|
9
9
|
import gc
|
|
@@ -371,6 +371,7 @@ class Engine(EngineBase):
|
|
|
371
371
|
total_data_size_packed = NumbersDict()
|
|
372
372
|
total_data_size_padded = NumbersDict()
|
|
373
373
|
|
|
374
|
+
report_prefix = f"ep {self.epoch} train"
|
|
374
375
|
try:
|
|
375
376
|
while True:
|
|
376
377
|
with torch.no_grad():
|
|
@@ -398,21 +399,13 @@ class Engine(EngineBase):
|
|
|
398
399
|
{k: int(util.prod(extern_data_raw[k].shape[:2])) for k in keys_w_seq_len},
|
|
399
400
|
)
|
|
400
401
|
|
|
401
|
-
|
|
402
|
-
|
|
402
|
+
num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
|
|
403
|
+
report_prefix=report_prefix,
|
|
404
|
+
extern_data_raw=extern_data_raw,
|
|
405
|
+
step_idx=step_idx,
|
|
406
|
+
prev_num_seqs=num_seqs,
|
|
407
|
+
prev_last_seq_idx=last_seq_idx,
|
|
403
408
|
)
|
|
404
|
-
# Note: The batches might have been shuffled,
|
|
405
|
-
# thus we cannot really assert that the seq_idx is always increasing.
|
|
406
|
-
last_seq_idx = max(int(extern_data_raw["seq_idx"].max()), last_seq_idx)
|
|
407
|
-
if step_idx == 0:
|
|
408
|
-
if num_seqs_ >= 0:
|
|
409
|
-
print(f"Epoch {self.epoch} num_seqs: {num_seqs_}", file=log.v5)
|
|
410
|
-
num_seqs = num_seqs_
|
|
411
|
-
elif num_seqs_ >= 0:
|
|
412
|
-
assert num_seqs_ == num_seqs
|
|
413
|
-
del num_seqs_
|
|
414
|
-
if num_seqs is not None:
|
|
415
|
-
assert last_seq_idx < num_seqs
|
|
416
409
|
epoch_continuous = (self.epoch - 1 + (last_seq_idx + 1) / num_seqs) if num_seqs is not None else None
|
|
417
410
|
|
|
418
411
|
# clear the gradients when every gradient accumulation loop starts
|
|
@@ -485,7 +478,7 @@ class Engine(EngineBase):
|
|
|
485
478
|
accumulated_inv_norm_factors_dict += inv_norm_factors_dict
|
|
486
479
|
eval_info = self._maybe_extend_losses_info(losses_dict / inv_norm_factors_dict)
|
|
487
480
|
_print_process(
|
|
488
|
-
|
|
481
|
+
report_prefix,
|
|
489
482
|
step=step_idx,
|
|
490
483
|
eval_info=dict(eval_info),
|
|
491
484
|
step_duration=step_duration,
|
|
@@ -1276,6 +1269,8 @@ class Engine(EngineBase):
|
|
|
1276
1269
|
new_dim.dyn_size_ext = _get_tensor_wo_batch_numpy(dim.dyn_size_ext)
|
|
1277
1270
|
return new_dim
|
|
1278
1271
|
|
|
1272
|
+
num_seqs = None
|
|
1273
|
+
last_seq_idx = 0
|
|
1279
1274
|
report_prefix = f"ep {self.epoch} {dataset.name} forward"
|
|
1280
1275
|
with torch.no_grad():
|
|
1281
1276
|
callback.init(model=self._orig_model)
|
|
@@ -1283,6 +1278,15 @@ class Engine(EngineBase):
|
|
|
1283
1278
|
step_idx = 0
|
|
1284
1279
|
for extern_data_raw in data_loader:
|
|
1285
1280
|
step_begin_time = time.monotonic()
|
|
1281
|
+
|
|
1282
|
+
num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
|
|
1283
|
+
report_prefix=report_prefix,
|
|
1284
|
+
extern_data_raw=extern_data_raw,
|
|
1285
|
+
step_idx=step_idx,
|
|
1286
|
+
prev_num_seqs=num_seqs,
|
|
1287
|
+
prev_last_seq_idx=last_seq_idx,
|
|
1288
|
+
)
|
|
1289
|
+
|
|
1286
1290
|
if self._forward_step_expected_outputs:
|
|
1287
1291
|
# Also resets any dyn dims, which might have been set in the prev step.
|
|
1288
1292
|
self._forward_step_expected_outputs.reset_content()
|
|
@@ -1319,11 +1323,19 @@ class Engine(EngineBase):
|
|
|
1319
1323
|
model_outputs_per_batch.data[k] = _get_tensor_wo_batch_numpy(v)
|
|
1320
1324
|
callback.process_seq(seq_tag=seq_tag, outputs=model_outputs_per_batch)
|
|
1321
1325
|
|
|
1322
|
-
|
|
1326
|
+
step_end_time = time.monotonic()
|
|
1327
|
+
step_duration = step_end_time - step_begin_time
|
|
1328
|
+
elapsed_computation_time += step_duration
|
|
1329
|
+
|
|
1323
1330
|
_print_process(
|
|
1324
1331
|
report_prefix,
|
|
1325
1332
|
step=step_idx,
|
|
1326
1333
|
eval_info=None,
|
|
1334
|
+
step_duration=step_duration,
|
|
1335
|
+
start_elapsed=step_end_time - epoch_start_time,
|
|
1336
|
+
seq_idx=last_seq_idx,
|
|
1337
|
+
num_seqs=num_seqs,
|
|
1338
|
+
batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
|
|
1327
1339
|
log_memory_usage_device=self._device if self._log_memory_usage else None,
|
|
1328
1340
|
)
|
|
1329
1341
|
step_idx += 1
|
|
@@ -1601,3 +1613,27 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
|
|
|
1601
1613
|
p=p,
|
|
1602
1614
|
).item()
|
|
1603
1615
|
)
|
|
1616
|
+
|
|
1617
|
+
|
|
1618
|
+
def _get_num_seqs_last_seq_idx(
|
|
1619
|
+
*,
|
|
1620
|
+
report_prefix: str,
|
|
1621
|
+
extern_data_raw: Dict[str, Any],
|
|
1622
|
+
step_idx: int,
|
|
1623
|
+
prev_num_seqs: Optional[int],
|
|
1624
|
+
prev_last_seq_idx: int,
|
|
1625
|
+
) -> Tuple[Optional[int], int]:
|
|
1626
|
+
num_seqs = prev_num_seqs
|
|
1627
|
+
num_seqs_ = int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1
|
|
1628
|
+
# Note: The batches might have been shuffled,
|
|
1629
|
+
# thus we cannot really assert that the seq_idx is always increasing.
|
|
1630
|
+
last_seq_idx = max(int(extern_data_raw["seq_idx"].max()), prev_last_seq_idx)
|
|
1631
|
+
if step_idx == 0:
|
|
1632
|
+
if num_seqs_ >= 0:
|
|
1633
|
+
print(f"{report_prefix} num_seqs: {num_seqs_}", file=log.v5)
|
|
1634
|
+
num_seqs = num_seqs_
|
|
1635
|
+
elif num_seqs_ >= 0:
|
|
1636
|
+
assert num_seqs_ == num_seqs
|
|
1637
|
+
if num_seqs is not None:
|
|
1638
|
+
assert last_seq_idx < num_seqs
|
|
1639
|
+
return num_seqs, last_seq_idx
|
|
@@ -1624,7 +1624,7 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
1624
1624
|
assert isinstance(static, bool)
|
|
1625
1625
|
if static:
|
|
1626
1626
|
assert seed is not None
|
|
1627
|
-
generator = torch.Generator()
|
|
1627
|
+
generator = torch.Generator(device=out.raw_tensor.device)
|
|
1628
1628
|
generator.manual_seed(seed)
|
|
1629
1629
|
else:
|
|
1630
1630
|
assert seed is None
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=JxwNmuittLMoytS17vVYTCUstvu63egjofVHNtxAWoI,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=PiumrsjdDM8B1EXRYPtxkhE50REGsTfjClC6U1cNh58,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -207,7 +207,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
|
|
|
207
207
|
returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
|
|
208
208
|
returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
|
|
209
209
|
returnn/torch/distributed.py,sha256=i13cUVjI7GxpO0TAresrNyCM0ZBAaf-cXNr09Fmg_2k,6266
|
|
210
|
-
returnn/torch/engine.py,sha256=
|
|
210
|
+
returnn/torch/engine.py,sha256=neM-AL7XQLpZ3V1K4ziqVmij19ey1k2MpLCaFXATOpg,76301
|
|
211
211
|
returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
|
|
212
212
|
returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
|
|
213
213
|
returnn/torch/data/extern_data.py,sha256=_uT_9_gd5HIh1IoRsrebVG-nufSnb7fgC5jyU05GxJg,7580
|
|
@@ -216,7 +216,7 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
|
|
|
216
216
|
returnn/torch/data/returnn_dataset_wrapper.py,sha256=1Bw82-Ge_8m_DSDXZNqQ3zGDic2HQlp6jysELL0NVK0,7369
|
|
217
217
|
returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
|
|
218
218
|
returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
|
|
219
|
-
returnn/torch/frontend/_backend.py,sha256=
|
|
219
|
+
returnn/torch/frontend/_backend.py,sha256=h_rUhBPxLRgpZSqX4C8vX8q4dHWMhZpwPmGbKN6MsZo,99995
|
|
220
220
|
returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
|
|
221
221
|
returnn/torch/frontend/bridge.py,sha256=bAzOVlL-3hD6af9ir8EOyBSXy6O3KtnCRD7SaZTF2yU,7538
|
|
222
222
|
returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.
|
|
259
|
-
returnn-1.
|
|
260
|
-
returnn-1.
|
|
256
|
+
returnn-1.20250207.143045.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250207.143045.dist-info/METADATA,sha256=JxwNmuittLMoytS17vVYTCUstvu63egjofVHNtxAWoI,5215
|
|
258
|
+
returnn-1.20250207.143045.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
259
|
+
returnn-1.20250207.143045.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250207.143045.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|