nshtrainer 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_directory.py CHANGED
@@ -19,20 +19,8 @@ class DirectoryConfig(C.Config):
     This isn't specific to the run; it is the parent directory of all runs.
     """
 
-    log: Path | None = None
-    """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
-
-    stdio: Path | None = None
-    """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
-
-    checkpoint: Path | None = None
-    """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
-
-    activation: Path | None = None
-    """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
-
-    profile: Path | None = None
-    """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
+    logdir_basename: str = "nshtrainer"
+    """Base name for the log directory."""
 
     setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
     """Configuration for the directory setup PyTorch Lightning callback."""
@@ -41,11 +29,11 @@ class DirectoryConfig(C.Config):
         if (project_root_dir := self.project_root) is None:
             project_root_dir = Path.cwd()
 
-        # The default base dir is $CWD/nshtrainer/{id}/
-        base_dir = project_root_dir / "nshtrainer"
+        # The default base dir is $CWD/{logdir_basename}/{id}/
+        base_dir = project_root_dir / self.logdir_basename
         base_dir.mkdir(exist_ok=True)
 
-        # Add a .gitignore file to the nshtrainer directory
+        # Add a .gitignore file to the {logdir_basename} directory
         # which will ignore all files except for the .gitignore file itself
         gitignore_path = base_dir / ".gitignore"
         if not gitignore_path.exists():
@@ -57,13 +45,8 @@ class DirectoryConfig(C.Config):
 
         return base_dir
 
-    def resolve_subdirectory(
-        self,
-        run_id: str,
-        # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
-        subdirectory: str,
-    ) -> Path:
-        # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
+    def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
+        # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
         if (subdir := getattr(self, subdirectory, None)) is not None:
             assert isinstance(subdir, Path), (
                 f"Expected a Path for {subdirectory}, got {type(subdir)}"
@@ -79,7 +62,7 @@ class DirectoryConfig(C.Config):
         if (log_dir := logger.log_dir) is not None:
             return log_dir
 
-        # Save to nshtrainer/{id}/log/{logger name}
+        # Save to {logdir_basename}/{id}/log/{logger name}
         log_dir = self.resolve_subdirectory(run_id, "log")
         log_dir = log_dir / logger.resolve_logger_dirname()
         # ^ NOTE: Logger must have a `name` attribute, as this is
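
The five per-subdirectory `Path` overrides are collapsed into a single `logdir_basename` string. A minimal usage sketch of the new knob (assuming direct construction of `DirectoryConfig`; the resolved layout follows the comments in the diff above):

    from nshtrainer._directory import DirectoryConfig

    config = DirectoryConfig(logdir_basename="my_experiments")
    # Subdirectories now resolve under $CWD/my_experiments/{run_id}/...
    # rather than the old per-field Path overrides (log, stdio, checkpoint, ...).
    ckpt_dir = config.resolve_subdirectory("run-123", "checkpoint")
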
nshtrainer/callbacks/directory_setup.py CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import logging
-import os
 from pathlib import Path
 from typing import Literal
 
@@ -27,6 +26,7 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
     def __bool__(self):
         return self.enabled
 
+    @override
     def create_callbacks(self, trainer_config):
         if not self:
             return
@@ -35,21 +35,28 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
 
 
 def _create_symlink_to_nshrunner(base_dir: Path):
-    # Resolve the current nshrunner session directory
-    if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
-        log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
+    try:
+        import nshrunner as nr
+    except ImportError:
+        log.info("nshrunner is not installed. Skipping symlink creation to nshrunner.")
+        return
+
+    # Check if we are in a nshrunner session
+    if (session := nr.Session.from_current_session()) is None:
+        log.info("No current nshrunner session found. Skipping symlink creation.")
         return
-    session_dir = Path(session_dir)
+
+    session_dir = session.session_dir
     if not session_dir.exists() or not session_dir.is_dir():
         log.warning(
-            f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
+            f"nshrunner's session_dir is not a valid directory: {session_dir}. "
            "Skipping symlink creation."
         )
         return
 
     # Create the symlink
     symlink_path = base_dir / "nshrunner"
-    if symlink_path.exists():
+    if symlink_path.exists(follow_symlinks=False):
         # If it already points to the correct directory, we're done
         if symlink_path.resolve() == session_dir.resolve():
             return
@@ -61,7 +68,7 @@ def _create_symlink_to_nshrunner(base_dir: Path):
         )
         symlink_path.unlink()
 
-    symlink_path.symlink_to(session_dir)
+    symlink_path.symlink_to(session_dir, target_is_directory=True)
 
 
 class DirectorySetupCallback(NTCallbackBase):
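
Note the switch to `symlink_path.exists(follow_symlinks=False)`: a dangling symlink (one whose target session directory was deleted) makes plain `exists()` return False, so the stale link would never be replaced. A standalone sketch of the difference (`Path.exists(follow_symlinks=...)` requires Python 3.12+):

    from pathlib import Path

    link = Path("nshtrainer/nshrunner")
    # exists() follows the link, so it is False once the target is gone;
    # exists(follow_symlinks=False) is True as long as the link itself exists.
    if link.exists(follow_symlinks=False) and not link.exists():
        link.unlink()  # drop the dangling symlink so it can be recreated
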
nshtrainer/callbacks/distributed_prediction_writer.py CHANGED
@@ -4,15 +4,19 @@ import functools
 import logging
 from collections.abc import Iterator, Sequence
 from pathlib import Path
-from typing import Any, ClassVar, Literal, overload
+from typing import TYPE_CHECKING, ClassVar, Generic, Literal, cast, overload
 
 import torch
 from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.pytorch.callbacks import BasePredictionWriter
-from typing_extensions import final, override
+from typing_extensions import TypeVar, final, override
 
 from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry
 
+if TYPE_CHECKING:
+    from ..model.base import IndividualSample
+
+
 log = logging.getLogger(__name__)
 
 
@@ -130,7 +134,15 @@ class DistributedPredictionWriter(BasePredictionWriter):
             save(sample, output_dir / f"{sample['index']}.pt")
 
 
-class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
+SampleT = TypeVar(
+    "SampleT",
+    bound="IndividualSample",
+    default="IndividualSample",
+    infer_variance=True,
+)
+
+
+class DistributedPredictionReader(Sequence[SampleT], Generic[SampleT]):
     def __init__(self, output_dir: Path):
         self.output_dir = output_dir
 
@@ -139,15 +151,13 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
         return len(list(self.output_dir.glob("*.pt")))
 
     @overload
-    def __getitem__(self, index: int) -> tuple[Any, Any]: ...
+    def __getitem__(self, index: int) -> SampleT: ...
 
     @overload
-    def __getitem__(self, index: slice) -> list[tuple[Any, Any]]: ...
+    def __getitem__(self, index: slice) -> list[SampleT]: ...
 
     @override
-    def __getitem__(
-        self, index: int | slice
-    ) -> tuple[Any, Any] | list[tuple[Any, Any]]:
+    def __getitem__(self, index: int | slice) -> SampleT | list[SampleT]:
         if isinstance(index, slice):
             # Handle slice indexing
             indices = range(*index.indices(len(self)))
@@ -157,10 +167,11 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
         path = self.output_dir / f"{index}.pt"
         if not path.exists():
             raise FileNotFoundError(f"File {path} does not exist.")
-        sample = torch.load(path)
-        return sample["batch"], sample["prediction"]
+
+        sample = cast(SampleT, torch.load(path))
+        return sample
 
     @override
-    def __iter__(self) -> Iterator[tuple[Any, Any]]:
+    def __iter__(self) -> Iterator[SampleT]:
         for i in range(len(self)):
             yield self[i]
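
The reader now yields whole samples typed by `SampleT` instead of `(batch, prediction)` tuples. A usage sketch (assuming, per the removed tuple-unpacking code above, that each stored sample dict carries `batch` and `prediction` keys; the directory path is illustrative):

    from pathlib import Path

    from nshtrainer.callbacks.distributed_prediction_writer import (
        DistributedPredictionReader,
    )

    reader = DistributedPredictionReader(Path("predictions/processed/dataloader_0"))
    for sample in reader:
        # Callers unpack the fields themselves now:
        batch, prediction = sample["batch"], sample["prediction"]
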
nshtrainer/configs/trainer/__init__.py CHANGED
@@ -22,9 +22,6 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
-from nshtrainer.trainer._config import (
-    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
-)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -126,6 +123,9 @@ from nshtrainer.trainer.plugin.precision import (
 )
 from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
+from nshtrainer.trainer.trainer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 
 from . import _config as _config
nshtrainer/configs/trainer/_config/__init__.py CHANGED
@@ -18,9 +18,6 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
-from nshtrainer.trainer._config import (
-    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
-)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -73,7 +70,6 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
-    "DistributedPredictionWriterConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
nshtrainer/configs/trainer/trainer/__init__.py CHANGED
@@ -3,12 +3,16 @@ from __future__ import annotations
 __codegen__ = True
 
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
+from nshtrainer.trainer.trainer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
 
 __all__ = [
     "AcceleratorConfigBase",
+    "DistributedPredictionWriterConfig",
     "EnvironmentConfig",
     "StrategyConfigBase",
     "TrainerConfig",
nshtrainer/trainer/_config.py CHANGED
@@ -31,7 +31,6 @@ from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
-    DistributedPredictionWriterConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
     NormLoggingCallbackConfig,
@@ -702,14 +701,6 @@ class TrainerConfig(C.Config):
     auto_validate_metrics: MetricValidationCallbackConfig | None = None
     """If enabled, will automatically validate the metrics before starting the training routine."""
 
-    distributed_predict: DistributedPredictionWriterConfig | None = (
-        DistributedPredictionWriterConfig()
-    )
-    """If enabled, will use a custom BasePredictionWriter callback to automatically
-    handle distributed prediction. This is useful for running prediction on multiple GPUs
-    seamlessly.
-    """
-
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
@@ -761,6 +752,7 @@ class TrainerConfig(C.Config):
     )
 
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
+        yield self.directory.setup_callback
         yield self.early_stopping
         yield self.checkpoint_saving
         yield self.lr_monitor
@@ -777,7 +769,6 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield self.auto_validate_metrics
-        yield self.distributed_predict
         yield from self.callbacks
 
     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
nshtrainer/trainer/_distributed_prediction_result.py ADDED
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class DistributedPredictionResult:
+    """Represents the results of a distributed prediction run.
+
+    This dataclass provides easy access to both raw and processed prediction data.
+    """
+
+    root_dir: Path
+    """Root directory where predictions are stored."""
+
+    @property
+    def raw_dir(self) -> Path:
+        """Directory containing raw prediction data."""
+        return self.root_dir / "raw"
+
+    @property
+    def processed_dir(self) -> Path:
+        """Directory containing processed prediction data."""
+        return self.root_dir / "processed"
+
+    def get_raw_predictions(self, dataloader_idx: int = 0) -> Path:
+        """Get the directory containing raw predictions for a specific dataloader.
+
+        Args:
+            dataloader_idx: Index of the dataloader
+
+        Returns:
+            Path to the raw predictions directory for the specified dataloader
+        """
+        raw_loader_dir = self.raw_dir / f"dataloader_{dataloader_idx}"
+        if not raw_loader_dir.exists():
+            log.warning(f"Raw predictions directory {raw_loader_dir} does not exist.")
+        return raw_loader_dir
+
+    def get_processed_reader(self, dataloader_idx: int = 0):
+        """Get a reader for processed predictions from a specific dataloader.
+
+        Args:
+            dataloader_idx: Index of the dataloader
+
+        Returns:
+            A DistributedPredictionReader for the processed predictions, or None if no data exists
+        """
+        from ..callbacks.distributed_prediction_writer import (
+            DistributedPredictionReader,
+        )
+
+        processed_loader_dir = self.processed_dir / f"dataloader_{dataloader_idx}"
+        if not processed_loader_dir.exists():
+            log.warning(
+                f"Processed predictions directory {processed_loader_dir} does not exist."
+            )
+            return None
+
+        return DistributedPredictionReader(processed_loader_dir)
+
+    @classmethod
+    def load(cls, path: Path | str):
+        """Load prediction results from a directory.
+
+        Args:
+            path: Path to the predictions directory
+
+        Returns:
+            A DistributedPredictionResult instance
+        """
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"Predictions directory {path} does not exist.")
+
+        return cls(root_dir=path)
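
A short sketch of consuming these results after a run (the output directory name is hypothetical; the raw/processed layout follows the properties defined above):

    from nshtrainer.trainer._distributed_prediction_result import (
        DistributedPredictionResult,
    )

    result = DistributedPredictionResult.load("predictions")  # hypothetical output dir
    if (reader := result.get_processed_reader(dataloader_idx=0)) is not None:
        print(f"{len(reader)} processed samples on disk")
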
nshtrainer/trainer/trainer.py CHANGED
@@ -4,7 +4,7 @@ import logging
 import os
 from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, cast, overload
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
@@ -24,9 +24,14 @@ from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
 from .._checkpoint.metadata import write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
+from ..callbacks.distributed_prediction_writer import (
+    DistributedPredictionWriter,
+    DistributedPredictionWriterConfig,
+)
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import LightningTrainerKwargs, TrainerConfig
+from ._distributed_prediction_result import DistributedPredictionResult
 from ._log_hparams import patch_log_hparams_function
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .accelerator import AcceleratorConfigBase
@@ -537,13 +542,66 @@ class Trainer(LightningTrainer):
         )
         return cls(hparams)
 
+    @overload
     def distributed_predict(
         self,
         model: LightningModule | None = None,
         dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
         datamodule: LightningDataModule | None = None,
         ckpt_path: str | Path | None = None,
-    ):
+        *,
+        config: DistributedPredictionWriterConfig,
+    ) -> DistributedPredictionResult: ...
+
+    @overload
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+        *,
+        dirpath: Path | None = None,
+        move_to_cpu_on_save: bool = True,
+        save_raw: bool = True,
+        save_processed: bool = True,
+    ) -> DistributedPredictionResult: ...
+
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+        *,
+        config: DistributedPredictionWriterConfig | None = None,
+        dirpath: Path | None = None,
+        move_to_cpu_on_save: bool = True,
+        save_raw: bool = True,
+        save_processed: bool = True,
+    ) -> DistributedPredictionResult:
+        if config is None:
+            config = DistributedPredictionWriterConfig(
+                dirpath=dirpath,
+                move_to_cpu_on_save=move_to_cpu_on_save,
+                save_raw=save_raw,
+                save_processed=save_processed,
+            )
+
+        # Remove any DistributedPredictionWriter callbacks that are already set
+        # and add the new one.
+        callbacks = self.callbacks.copy()
+        callbacks = [
+            callback
+            for callback in callbacks
+            if not isinstance(callback, DistributedPredictionWriter)
+        ]
+        writer_callbacks = list(config.create_callbacks(self.hparams))
+        assert len(writer_callbacks) == 1
+        callback = writer_callbacks[0]
+        callbacks.append(callback)
+        self.callbacks = self._callback_connector._reorder_callbacks(callbacks)
+
         self.predict(
             model,
             dataloaders,
@@ -551,3 +609,9 @@ class Trainer(LightningTrainer):
             return_predictions=False,
             ckpt_path=ckpt_path,
         )
+
+        # Wait for all processes to finish
+        self.strategy.barrier("Trainer.distributed_predict")
+
+        # Return an object that contains information about the predictions
+        return DistributedPredictionResult(root_dir=callback.output_dir)
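
Taken together, distributed prediction is now configured per call rather than through the removed `TrainerConfig.distributed_predict` field. A calling sketch (trainer, model, and datamodule come from an existing nshtrainer setup and are elided; the checkpoint path is hypothetical):

    from pathlib import Path

    result = trainer.distributed_predict(
        model,
        datamodule=datamodule,
        ckpt_path="checkpoints/best.ckpt",  # hypothetical
        dirpath=Path("predictions"),
        save_raw=True,
        save_processed=True,
    )
    # Every rank writes its own shard; the barrier above guarantees all files
    # exist before the result object is returned.
    reader = result.get_processed_reader(dataloader_idx=0)
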
{nshtrainer-1.2.0.dist-info → nshtrainer-1.3.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.2.0
+Version: 1.3.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-1.2.0.dist-info → nshtrainer-1.3.0.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
 nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
 nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
 nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
-nshtrainer/_directory.py,sha256=SuXJe9xJXZkDXWWfeOS9rEDz6vZUA6mpnEdkAW0ZQnY,3193
+nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
 nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
@@ -15,8 +15,8 @@ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
 nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
-nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
-nshtrainer/callbacks/distributed_prediction_writer.py,sha256=OSh2C6XF7Nki4eFByNVhwlt69izkxnlmfPx54w4rvBo,5274
+nshtrainer/callbacks/directory_setup.py,sha256=Ln6f0tCgoBscHeigIAWtCCoAmuWB-kPyaf7SylU7MYo,2773
+nshtrainer/callbacks/distributed_prediction_writer.py,sha256=PvxV9E9lHT-NQ-h1ld7WugajqiFyFXECsreUt3e7pxk,5440
 nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
 nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
 nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
@@ -85,8 +85,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=PF9rYuVpk0IuhjcxS_hmBTT6A0oq7AWZDcx0Gfqi7MM,8040
-nshtrainer/configs/trainer/_config/__init__.py,sha256=5B8pjyNHfyFJ6p8dD5VSHD1tw2CcZ87Eq2C_Req3t60,3977
+nshtrainer/configs/trainer/__init__.py,sha256=YLlDOUYDp_qURHhcmhCxTcY6K5AbmoTxdzBPB9SEZII,8040
+nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -95,7 +95,7 @@ nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=AtGUuE0M16dTpX0q9NqvJiE4
 nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
 nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
 nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
-nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
+nshtrainer/configs/trainer/trainer/__init__.py,sha256=gOyfE4LlKP-pDJB_ILf79--GztnkF_QmEcexHgqGxOI,646
 nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
 nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
 nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
@@ -135,7 +135,8 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
-nshtrainer/trainer/_config.py,sha256=tdWAYh-KGXBpgdY8fwvOejjRZN-AS2Ze0f_9s2VEuZ0,33556
+nshtrainer/trainer/_config.py,sha256=Lt9tuzxgVzVnyEFz61xbaPudfsXbKYUphOg-qMDHO8g,33203
+nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
 nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
@@ -147,7 +148,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
 nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=smoN61iixWYDWGFvxrt8VwryZVy_NzqqjUcgOid0gRA,21696
+nshtrainer/trainer/trainer.py,sha256=6oky6E8cjGqUNzJGyyTO551pE9A6YueOv5oxg1fZVR0,24129
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -159,6 +160,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.2.0.dist-info/METADATA,sha256=HkNLruaJJuf3ijnGe7NqNd9emBR6QHMRh2-taC5wTrU,960
-nshtrainer-1.2.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-nshtrainer-1.2.0.dist-info/RECORD,,
+nshtrainer-1.3.0.dist-info/METADATA,sha256=M84AwXCuoJp21_m2IQKYDC-SFWDAdhOy-2fDL1jk9Lw,960
+nshtrainer-1.3.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.3.0.dist-info/RECORD,,