nshtrainer 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/callbacks/distributed_prediction_writer.py

@@ -4,15 +4,19 @@ import functools
 import logging
 from collections.abc import Iterator, Sequence
 from pathlib import Path
-from typing import Any, ClassVar, Literal, overload
+from typing import TYPE_CHECKING, ClassVar, Generic, Literal, cast, overload

 import torch
 from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.pytorch.callbacks import BasePredictionWriter
-from typing_extensions import final, override
+from typing_extensions import TypeVar, final, override

 from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry

+if TYPE_CHECKING:
+    from ..model.base import IndividualSample
+
+
 log = logging.getLogger(__name__)

@@ -130,7 +134,15 @@ class DistributedPredictionWriter(BasePredictionWriter):
             save(sample, output_dir / f"{sample['index']}.pt")


-class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
+SampleT = TypeVar(
+    "SampleT",
+    bound="IndividualSample",
+    default="IndividualSample",
+    infer_variance=True,
+)
+
+
+class DistributedPredictionReader(Sequence[SampleT], Generic[SampleT]):
     def __init__(self, output_dir: Path):
         self.output_dir = output_dir

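The reader becomes generic over the stored sample type. Note that TypeVar is imported from typing_extensions rather than typing: the default= parameter (PEP 696) makes a bare DistributedPredictionReader type-check as DistributedPredictionReader[IndividualSample], and infer_variance=True lets the type checker deduce variance (PEP 695 semantics). A minimal, self-contained sketch of the same pattern, with illustrative names not taken from nshtrainer:

    from collections.abc import Sequence
    from typing import Generic

    from typing_extensions import TypeVar

    # ItemT falls back to int when the class is used without a subscript.
    ItemT = TypeVar("ItemT", default=int, infer_variance=True)

    class Store(Sequence[ItemT], Generic[ItemT]):
        def __init__(self, items: list[ItemT]) -> None:
            self._items = items

        def __len__(self) -> int:
            return len(self._items)

        def __getitem__(self, index):
            return self._items[index]

    s: Store = Store([1, 2, 3])        # unparameterized: ItemT defaults to int
    t: Store[str] = Store(["a", "b"])  # explicit subscript overrides the default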
@@ -139,15 +151,13 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
         return len(list(self.output_dir.glob("*.pt")))

     @overload
-    def __getitem__(self, index: int) -> tuple[Any, Any]: ...
+    def __getitem__(self, index: int) -> SampleT: ...

     @overload
-    def __getitem__(self, index: slice) -> list[tuple[Any, Any]]: ...
+    def __getitem__(self, index: slice) -> list[SampleT]: ...

     @override
-    def __getitem__(
-        self, index: int | slice
-    ) -> tuple[Any, Any] | list[tuple[Any, Any]]:
+    def __getitem__(self, index: int | slice) -> SampleT | list[SampleT]:
         if isinstance(index, slice):
             # Handle slice indexing
             indices = range(*index.indices(len(self)))
@@ -157,10 +167,11 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
         path = self.output_dir / f"{index}.pt"
         if not path.exists():
             raise FileNotFoundError(f"File {path} does not exist.")
-        sample = torch.load(path)
-        return sample["batch"], sample["prediction"]
+
+        sample = cast(SampleT, torch.load(path))
+        return sample

     @override
-    def __iter__(self) -> Iterator[tuple[Any, Any]]:
+    def __iter__(self) -> Iterator[SampleT]:
         for i in range(len(self)):
             yield self[i]
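Downstream effect of the reader changes in this file: indexing and iteration now return whole saved samples instead of (batch, prediction) tuples, and the element type is only a cast over torch.load, so it is exactly as trustworthy as what the writer stored. A usage sketch; the path is a placeholder following the processed_dir/dataloader_{idx} layout used by DistributedPredictionResult below:

    from pathlib import Path

    from nshtrainer.trainer import DistributedPredictionReader

    # Placeholder path, matching the layout produced by the writer callback.
    reader = DistributedPredictionReader(Path("predictions/processed/dataloader_0"))

    first = reader[0]       # one saved sample
    head = reader[0:2]      # slices return a list of samples
    for sample in reader:   # iterates samples in index order
        ...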
nshtrainer/configs/trainer/__init__.py

@@ -22,9 +22,6 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
-from nshtrainer.trainer._config import (
-    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
-)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -126,6 +123,9 @@ from nshtrainer.trainer.plugin.precision import (
 )
 from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
+from nshtrainer.trainer.trainer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase

 from . import _config as _config
nshtrainer/configs/trainer/_config/__init__.py

@@ -18,9 +18,6 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
-from nshtrainer.trainer._config import (
-    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
-)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -73,7 +70,6 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
-    "DistributedPredictionWriterConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
nshtrainer/configs/trainer/trainer/__init__.py

@@ -3,12 +3,16 @@ from __future__ import annotations
 __codegen__ = True

 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
+from nshtrainer.trainer.trainer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig

 __all__ = [
     "AcceleratorConfigBase",
+    "DistributedPredictionWriterConfig",
     "EnvironmentConfig",
     "StrategyConfigBase",
     "TrainerConfig",
nshtrainer/trainer/__init__.py

@@ -1,7 +1,13 @@
 from __future__ import annotations

 from ..callbacks import callback_registry as callback_registry
+from ..callbacks.distributed_prediction_writer import (
+    DistributedPredictionReader as DistributedPredictionReader,
+)
 from ._config import TrainerConfig as TrainerConfig
+from ._distributed_prediction_result import (
+    DistributedPredictionResult as DistributedPredictionResult,
+)
 from .accelerator import accelerator_registry as accelerator_registry
 from .plugin import plugin_registry as plugin_registry
 from .trainer import Trainer as Trainer
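With these re-exports, both halves of the prediction round trip are importable from the nshtrainer.trainer package root:

    from nshtrainer.trainer import (
        DistributedPredictionReader,
        DistributedPredictionResult,
        Trainer,
    )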
nshtrainer/trainer/_config.py

@@ -31,7 +31,6 @@ from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
-    DistributedPredictionWriterConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
     NormLoggingCallbackConfig,
@@ -702,14 +701,6 @@ class TrainerConfig(C.Config):
     auto_validate_metrics: MetricValidationCallbackConfig | None = None
     """If enabled, will automatically validate the metrics before starting the training routine."""

-    distributed_predict: DistributedPredictionWriterConfig | None = (
-        DistributedPredictionWriterConfig()
-    )
-    """If enabled, will use a custom BasePredictionWriter callback to automatically
-    handle distributed prediction. This is useful for running prediction on multiple GPUs
-    seamlessly.
-    """
-
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
@@ -778,7 +769,6 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield self.auto_validate_metrics
-        yield self.distributed_predict
         yield from self.callbacks

     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
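The distributed_predict field (and its entry in the callback-config iterator) leaves TrainerConfig; the writer is now configured per call on Trainer.distributed_predict, added in trainer.py below. A rough migration sketch, with trainer, model, and datamodule as placeholders:

    # 1.2.1: configured (and enabled by default) on TrainerConfig
    # config = TrainerConfig(distributed_predict=DistributedPredictionWriterConfig())

    # 1.3.1: passed at call time instead
    result = trainer.distributed_predict(
        model,
        datamodule=datamodule,
        config=DistributedPredictionWriterConfig(),
    )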
nshtrainer/trainer/_distributed_prediction_result.py (new file)

@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class DistributedPredictionResult:
+    """Represents the results of a distributed prediction run.
+
+    This dataclass provides easy access to both raw and processed prediction data.
+    """
+
+    root_dir: Path
+    """Root directory where predictions are stored."""
+
+    @property
+    def raw_dir(self) -> Path:
+        """Directory containing raw prediction data."""
+        return self.root_dir / "raw"
+
+    @property
+    def processed_dir(self) -> Path:
+        """Directory containing processed prediction data."""
+        return self.root_dir / "processed"
+
+    def get_raw_predictions(self, dataloader_idx: int = 0) -> Path:
+        """Get the directory containing raw predictions for a specific dataloader.
+
+        Args:
+            dataloader_idx: Index of the dataloader
+
+        Returns:
+            Path to the raw predictions directory for the specified dataloader
+        """
+        raw_loader_dir = self.raw_dir / f"dataloader_{dataloader_idx}"
+        if not raw_loader_dir.exists():
+            log.warning(f"Raw predictions directory {raw_loader_dir} does not exist.")
+        return raw_loader_dir
+
+    def get_processed_reader(self, dataloader_idx: int = 0):
+        """Get a reader for processed predictions from a specific dataloader.
+
+        Args:
+            dataloader_idx: Index of the dataloader
+
+        Returns:
+            A DistributedPredictionReader for the processed predictions, or None if no data exists
+        """
+        from ..callbacks.distributed_prediction_writer import (
+            DistributedPredictionReader,
+        )
+
+        processed_loader_dir = self.processed_dir / f"dataloader_{dataloader_idx}"
+        if not processed_loader_dir.exists():
+            log.warning(
+                f"Processed predictions directory {processed_loader_dir} does not exist."
+            )
+            return None
+
+        return DistributedPredictionReader(processed_loader_dir)
+
+    @classmethod
+    def load(cls, path: Path | str):
+        """Load prediction results from a directory.
+
+        Args:
+            path: Path to the predictions directory
+
+        Returns:
+            A DistributedPredictionResult instance
+        """
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"Predictions directory {path} does not exist.")
+
+        return cls(root_dir=path)
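A short sketch of the result object's API as defined above ("predictions" is a placeholder path; load only verifies the directory exists, and get_processed_reader returns None with a warning when a dataloader's output is missing):

    from nshtrainer.trainer import DistributedPredictionResult

    result = DistributedPredictionResult.load("predictions")
    raw_dir = result.get_raw_predictions(dataloader_idx=0)    # Path, may not exist
    reader = result.get_processed_reader(dataloader_idx=0)    # reader or None
    if reader is not None:
        for sample in reader:
            ...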
nshtrainer/trainer/trainer.py

@@ -4,7 +4,7 @@ import logging
 import os
 from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, cast, overload

 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
@@ -24,9 +24,14 @@ from typing_extensions import Never, Unpack, assert_never, deprecated, override

 from .._checkpoint.metadata import write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
+from ..callbacks.distributed_prediction_writer import (
+    DistributedPredictionWriter,
+    DistributedPredictionWriterConfig,
+)
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import LightningTrainerKwargs, TrainerConfig
+from ._distributed_prediction_result import DistributedPredictionResult
 from ._log_hparams import patch_log_hparams_function
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .accelerator import AcceleratorConfigBase
@@ -537,13 +542,66 @@ class Trainer(LightningTrainer):
         )
         return cls(hparams)

+    @overload
     def distributed_predict(
         self,
         model: LightningModule | None = None,
         dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
         datamodule: LightningDataModule | None = None,
         ckpt_path: str | Path | None = None,
-    ):
+        *,
+        config: DistributedPredictionWriterConfig,
+    ) -> DistributedPredictionResult: ...
+
+    @overload
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+        *,
+        dirpath: Path | None = None,
+        move_to_cpu_on_save: bool = True,
+        save_raw: bool = True,
+        save_processed: bool = True,
+    ) -> DistributedPredictionResult: ...
+
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+        *,
+        config: DistributedPredictionWriterConfig | None = None,
+        dirpath: Path | None = None,
+        move_to_cpu_on_save: bool = True,
+        save_raw: bool = True,
+        save_processed: bool = True,
+    ) -> DistributedPredictionResult:
+        if config is None:
+            config = DistributedPredictionWriterConfig(
+                dirpath=dirpath,
+                move_to_cpu_on_save=move_to_cpu_on_save,
+                save_raw=save_raw,
+                save_processed=save_processed,
+            )
+
+        # Remove any DistributedPredictionWriter callbacks that are already set
+        # and add the new one.
+        callbacks = self.callbacks.copy()
+        callbacks = [
+            callback
+            for callback in callbacks
+            if not isinstance(callback, DistributedPredictionWriter)
+        ]
+        writer_callbacks = list(config.create_callbacks(self.hparams))
+        assert len(writer_callbacks) == 1
+        callback = writer_callbacks[0]
+        callbacks.append(callback)
+        self.callbacks = self._callback_connector._reorder_callbacks(callbacks)
+
         self.predict(
             model,
             dataloaders,
@@ -551,3 +609,9 @@ class Trainer(LightningTrainer):
             return_predictions=False,
             ckpt_path=ckpt_path,
         )
+
+        # Wait for all processes to finish
+        self.strategy.barrier("Trainer.distributed_predict")
+
+        # Return an object that contains information about the predictions
+        return DistributedPredictionResult(root_dir=callback.output_dir)
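End to end, distributed_predict now swaps in a fresh DistributedPredictionWriter built from the given config, runs predict(..., return_predictions=False), synchronizes all ranks, and returns a result rooted at the callback's output directory. A hedged sketch; MyModule, MyDataModule, and the path are placeholders, and constructing the trainer from a TrainerConfig is assumed:

    from pathlib import Path

    trainer = Trainer(config)  # config: a TrainerConfig
    result = trainer.distributed_predict(
        MyModule(),
        datamodule=MyDataModule(),
        dirpath=Path("predictions"),
        move_to_cpu_on_save=True,
        save_raw=True,
        save_processed=True,
    )

    # All ranks return after the barrier; read results back on rank zero only.
    if trainer.is_global_zero:
        reader = result.get_processed_reader(dataloader_idx=0)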
nshtrainer-1.2.1.dist-info/METADATA → nshtrainer-1.3.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.2.1
+Version: 1.3.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
nshtrainer-1.2.1.dist-info/RECORD → nshtrainer-1.3.1.dist-info/RECORD

@@ -16,7 +16,7 @@ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
 nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
 nshtrainer/callbacks/directory_setup.py,sha256=Ln6f0tCgoBscHeigIAWtCCoAmuWB-kPyaf7SylU7MYo,2773
-nshtrainer/callbacks/distributed_prediction_writer.py,sha256=OSh2C6XF7Nki4eFByNVhwlt69izkxnlmfPx54w4rvBo,5274
+nshtrainer/callbacks/distributed_prediction_writer.py,sha256=PvxV9E9lHT-NQ-h1ld7WugajqiFyFXECsreUt3e7pxk,5440
 nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
 nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
 nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
@@ -85,8 +85,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=PF9rYuVpk0IuhjcxS_hmBTT6A0oq7AWZDcx0Gfqi7MM,8040
-nshtrainer/configs/trainer/_config/__init__.py,sha256=5B8pjyNHfyFJ6p8dD5VSHD1tw2CcZ87Eq2C_Req3t60,3977
+nshtrainer/configs/trainer/__init__.py,sha256=YLlDOUYDp_qURHhcmhCxTcY6K5AbmoTxdzBPB9SEZII,8040
+nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -95,7 +95,7 @@ nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=AtGUuE0M16dTpX0q9NqvJiE4
 nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
 nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
 nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
-nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
+nshtrainer/configs/trainer/trainer/__init__.py,sha256=gOyfE4LlKP-pDJB_ILf79--GztnkF_QmEcexHgqGxOI,646
 nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
 nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
 nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
@@ -134,8 +134,9 @@ nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,
 nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
-nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
-nshtrainer/trainer/_config.py,sha256=EaAZCajuvDoVAn79zGebUgg0ijz-i3eLhwxYq8oTNe8,33600
+nshtrainer/trainer/__init__.py,sha256=jRaHdaFK8wxNrN1bleT9cf29iZahL_-XkWo5TWz2CmA,550
+nshtrainer/trainer/_config.py,sha256=Lt9tuzxgVzVnyEFz61xbaPudfsXbKYUphOg-qMDHO8g,33203
+nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
 nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
@@ -147,7 +148,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
 nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=smoN61iixWYDWGFvxrt8VwryZVy_NzqqjUcgOid0gRA,21696
+nshtrainer/trainer/trainer.py,sha256=6oky6E8cjGqUNzJGyyTO551pE9A6YueOv5oxg1fZVR0,24129
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -159,6 +160,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.2.1.dist-info/METADATA,sha256=MrF6xvpgRjy2HsFwCIU9YZXMYcOb9ItAXY8m66P2CoQ,960
-nshtrainer-1.2.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-nshtrainer-1.2.1.dist-info/RECORD,,
+nshtrainer-1.3.1.dist-info/METADATA,sha256=RCFzQ6YlNZmaYUMcLR4RMotPI3X3QXFwI6MWyN5nkjE,960
+nshtrainer-1.3.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.3.1.dist-info/RECORD,,