nshtrainer 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_directory.py +8 -25
- nshtrainer/callbacks/directory_setup.py +15 -8
- nshtrainer/callbacks/distributed_prediction_writer.py +22 -11
- nshtrainer/configs/trainer/__init__.py +3 -3
- nshtrainer/configs/trainer/_config/__init__.py +0 -4
- nshtrainer/configs/trainer/trainer/__init__.py +4 -0
- nshtrainer/trainer/_config.py +1 -10
- nshtrainer/trainer/_distributed_prediction_result.py +80 -0
- nshtrainer/trainer/trainer.py +66 -2
- {nshtrainer-1.2.0.dist-info → nshtrainer-1.3.0.dist-info}/METADATA +1 -1
- {nshtrainer-1.2.0.dist-info → nshtrainer-1.3.0.dist-info}/RECORD +12 -11
- {nshtrainer-1.2.0.dist-info → nshtrainer-1.3.0.dist-info}/WHEEL +0 -0
nshtrainer/_directory.py
CHANGED
@@ -19,20 +19,8 @@ class DirectoryConfig(C.Config):
     This isn't specific to the run; it is the parent directory of all runs.
     """
 
-    log: Path | None = None
-    """Base log directory to use for the trainer. If None, will use nshtrainer/{id}/log/."""
-
-    stdio: Path | None = None
-    """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
-
-    checkpoint: Path | None = None
-    """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
-
-    activation: Path | None = None
-    """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
-
-    profile: Path | None = None
-    """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
+    logdir_basename: str = "nshtrainer"
+    """Base name for the log directory."""
 
     setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
     """Configuration for the directory setup PyTorch Lightning callback."""
@@ -41,11 +29,11 @@ class DirectoryConfig(C.Config):
        if (project_root_dir := self.project_root) is None:
            project_root_dir = Path.cwd()
 
-        # The default base dir is $CWD/nshtrainer/{id}/
-        base_dir = project_root_dir / "nshtrainer"
+        # The default base dir is $CWD/{logdir_basename}/{id}/
+        base_dir = project_root_dir / self.logdir_basename
         base_dir.mkdir(exist_ok=True)
 
-        # Add a .gitignore file to the nshtrainer directory
+        # Add a .gitignore file to the {logdir_basename} directory
         # which will ignore all files except for the .gitignore file itself
         gitignore_path = base_dir / ".gitignore"
         if not gitignore_path.exists():
@@ -57,13 +45,8 @@ class DirectoryConfig(C.Config):
 
         return base_dir
 
-    def resolve_subdirectory(
-        self,
-        run_id: str,
-        # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
-        subdirectory: str,
-    ) -> Path:
-        # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
+    def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
+        # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
         if (subdir := getattr(self, subdirectory, None)) is not None:
             assert isinstance(subdir, Path), (
                 f"Expected a Path for {subdirectory}, got {type(subdir)}"
@@ -79,7 +62,7 @@ class DirectoryConfig(C.Config):
         if (log_dir := logger.log_dir) is not None:
             return log_dir
 
-        # Save to nshtrainer/{id}/log/{logger name}
+        # Save to {logdir_basename}/{id}/log/{logger name}
         log_dir = self.resolve_subdirectory(run_id, "log")
         log_dir = log_dir / logger.resolve_logger_dirname()
         # ^ NOTE: Logger must have a `name` attribute, as this is
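The effect of this change: the previously hardcoded `nshtrainer/` output root becomes a config field, and the per-subdirectory `Path` overrides are dropped. A minimal sketch of the new behavior, using only the `DirectoryConfig` API visible in the hunks above (the run id is made up):

    from nshtrainer._directory import DirectoryConfig

    # Default is unchanged: outputs land under $CWD/nshtrainer/{id}/
    directory = DirectoryConfig()

    # The new field renames the root, e.g. $CWD/experiments/{id}/
    directory = DirectoryConfig(logdir_basename="experiments")
    ckpt_dir = directory.resolve_subdirectory("run-abc123", "checkpoint")

nshtrainer/callbacks/directory_setup.py
CHANGED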
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import logging
-import os
 from pathlib import Path
 from typing import Literal
 
@@ -27,6 +26,7 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
     def __bool__(self):
         return self.enabled
 
+    @override
     def create_callbacks(self, trainer_config):
         if not self:
             return
@@ -35,21 +35,28 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
 
 
 def _create_symlink_to_nshrunner(base_dir: Path):
-    # Check if we are in a nshrunner session
-    if (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")) is None:
-        log.info("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
+    try:
+        import nshrunner as nr
+    except ImportError:
+        log.info("nshrunner is not installed. Skipping symlink creation to nshrunner.")
+        return
+
+    # Check if we are in a nshrunner session
+    if (session := nr.Session.from_current_session()) is None:
+        log.info("No current nshrunner session found. Skipping symlink creation.")
         return
-    session_dir = Path(session_dir)
+
+    session_dir = session.session_dir
     if not session_dir.exists() or not session_dir.is_dir():
         log.warning(
-            f"Session directory is not a valid directory: {session_dir}. "
+            f"nshrunner's session_dir is not a valid directory: {session_dir}. "
             "Skipping symlink creation."
         )
         return
 
     # Create the symlink
     symlink_path = base_dir / "nshrunner"
-    if symlink_path.exists():
+    if symlink_path.exists(follow_symlinks=False):
         # If it already points to the correct directory, we're done
         if symlink_path.resolve() == session_dir.resolve():
             return
@@ -61,7 +68,7 @@ def _create_symlink_to_nshrunner(base_dir: Path):
     )
     symlink_path.unlink()
 
-    symlink_path.symlink_to(session_dir)
+    symlink_path.symlink_to(session_dir, target_is_directory=True)
 
 
 class DirectorySetupCallback(NTCallbackBase):
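Two of these changes are subtle. `Path.exists(follow_symlinks=False)` (new in Python 3.12) makes a dangling symlink visible, and `target_is_directory=True` keeps the link type correct on Windows. A small illustration of the first point, independent of nshtrainer:

    from pathlib import Path

    link = Path("demo-link")
    link.symlink_to(Path("does-not-exist"), target_is_directory=True)

    print(link.exists())                       # False: follows the link to a missing target
    print(link.exists(follow_symlinks=False))  # True: the link itself exists

Previously, a stale link left behind by a deleted session directory would look absent to `exists()`, and the subsequent `symlink_to` would raise FileExistsError.

nshtrainer/callbacks/distributed_prediction_writer.py
CHANGED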
@@ -4,15 +4,19 @@ import functools
 import logging
 from collections.abc import Iterator, Sequence
 from pathlib import Path
-from typing import Any, ClassVar, Literal, overload
+from typing import TYPE_CHECKING, ClassVar, Generic, Literal, cast, overload
 
 import torch
 from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.pytorch.callbacks import BasePredictionWriter
-from typing_extensions import final, override
+from typing_extensions import TypeVar, final, override
 
 from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry
 
+if TYPE_CHECKING:
+    from ..model.base import IndividualSample
+
+
 log = logging.getLogger(__name__)
 
 
@@ -130,7 +134,15 @@ class DistributedPredictionWriter(BasePredictionWriter):
             save(sample, output_dir / f"{sample['index']}.pt")
 
 
-class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
+SampleT = TypeVar(
+    "SampleT",
+    bound="IndividualSample",
+    default="IndividualSample",
+    infer_variance=True,
+)
+
+
+class DistributedPredictionReader(Sequence[SampleT], Generic[SampleT]):
     def __init__(self, output_dir: Path):
         self.output_dir = output_dir
 
@@ -139,15 +151,13 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
         return len(list(self.output_dir.glob("*.pt")))
 
     @overload
-    def __getitem__(self, index: int) -> tuple[Any, Any]: ...
+    def __getitem__(self, index: int) -> SampleT: ...
 
     @overload
-    def __getitem__(self, index: slice) -> list[tuple[Any, Any]]: ...
+    def __getitem__(self, index: slice) -> list[SampleT]: ...
 
     @override
-    def __getitem__(
-        self, index: int | slice
-    ) -> tuple[Any, Any] | list[tuple[Any, Any]]:
+    def __getitem__(self, index: int | slice) -> SampleT | list[SampleT]:
         if isinstance(index, slice):
             # Handle slice indexing
             indices = range(*index.indices(len(self)))
@@ -157,10 +167,11 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
         path = self.output_dir / f"{index}.pt"
         if not path.exists():
             raise FileNotFoundError(f"File {path} does not exist.")
-
-        return torch.load(path)
+
+        sample = cast(SampleT, torch.load(path))
+        return sample
 
     @override
-    def __iter__(self) -> Iterator[tuple[Any, Any]]:
+    def __iter__(self) -> Iterator[SampleT]:
         for i in range(len(self)):
             yield self[i]
@@ -22,9 +22,6 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
-from nshtrainer.trainer._config import (
-    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
-)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -126,6 +123,9 @@ from nshtrainer.trainer.plugin.precision import (
 )
 from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
+from nshtrainer.trainer.trainer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 
 from . import _config as _config
nshtrainer/configs/trainer/_config/__init__.py
CHANGED
@@ -18,9 +18,6 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
-from nshtrainer.trainer._config import (
-    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
-)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -73,7 +70,6 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
-    "DistributedPredictionWriterConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
nshtrainer/configs/trainer/trainer/__init__.py
CHANGED
@@ -3,12 +3,16 @@ from __future__ import annotations
 __codegen__ = True
 
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
+from nshtrainer.trainer.trainer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
 
 __all__ = [
     "AcceleratorConfigBase",
+    "DistributedPredictionWriterConfig",
     "EnvironmentConfig",
     "StrategyConfigBase",
     "TrainerConfig",
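Net effect of the three codegen updates: the canonical re-export of `DistributedPredictionWriterConfig` moves from the `_config` module to the `trainer` module, tracking the relocation in `nshtrainer.trainer` itself. A minimal before/after, based only on the imports shown above:

    # 1.2.0
    from nshtrainer.configs.trainer._config import DistributedPredictionWriterConfig

    # 1.3.0
    from nshtrainer.configs.trainer.trainer import DistributedPredictionWriterConfig

    # The package-level re-export works in both versions:
    from nshtrainer.configs.trainer import DistributedPredictionWriterConfig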
nshtrainer/trainer/_config.py
CHANGED
@@ -31,7 +31,6 @@ from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
-    DistributedPredictionWriterConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
     NormLoggingCallbackConfig,
@@ -702,14 +701,6 @@ class TrainerConfig(C.Config):
     auto_validate_metrics: MetricValidationCallbackConfig | None = None
     """If enabled, will automatically validate the metrics before starting the training routine."""
 
-    distributed_predict: DistributedPredictionWriterConfig | None = (
-        DistributedPredictionWriterConfig()
-    )
-    """If enabled, will use a custom BasePredictionWriter callback to automatically
-    handle distributed prediction. This is useful for running prediction on multiple GPUs
-    seamlessly.
-    """
-
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
@@ -761,6 +752,7 @@ class TrainerConfig(C.Config):
     )
 
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
+        yield self.directory.setup_callback
         yield self.early_stopping
         yield self.checkpoint_saving
         yield self.lr_monitor
@@ -777,7 +769,6 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield self.auto_validate_metrics
-        yield self.distributed_predict
         yield from self.callbacks
 
     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
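With the `distributed_predict` field gone from `TrainerConfig`, the writer is configured per call on `Trainer.distributed_predict` instead (see the trainer.py diff below). A before/after sketch, with hypothetical `trainer`, `model`, and `dm` objects:

    # 1.2.0: writer configured once, on the trainer config
    # config = TrainerConfig(distributed_predict=DistributedPredictionWriterConfig())

    # 1.3.0: writer configured at the call site
    result = trainer.distributed_predict(model, datamodule=dm, dirpath=Path("predictions"))

The other behavioral change here is that `directory.setup_callback` is now yielded from `_nshtrainer_all_callback_configs`, so the directory-setup callback is registered through the same path as every other callback config.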
nshtrainer/trainer/_distributed_prediction_result.py
ADDED
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class DistributedPredictionResult:
+    """Represents the results of a distributed prediction run.
+
+    This dataclass provides easy access to both raw and processed prediction data.
+    """
+
+    root_dir: Path
+    """Root directory where predictions are stored."""
+
+    @property
+    def raw_dir(self) -> Path:
+        """Directory containing raw prediction data."""
+        return self.root_dir / "raw"
+
+    @property
+    def processed_dir(self) -> Path:
+        """Directory containing processed prediction data."""
+        return self.root_dir / "processed"
+
+    def get_raw_predictions(self, dataloader_idx: int = 0) -> Path:
+        """Get the directory containing raw predictions for a specific dataloader.
+
+        Args:
+            dataloader_idx: Index of the dataloader
+
+        Returns:
+            Path to the raw predictions directory for the specified dataloader
+        """
+        raw_loader_dir = self.raw_dir / f"dataloader_{dataloader_idx}"
+        if not raw_loader_dir.exists():
+            log.warning(f"Raw predictions directory {raw_loader_dir} does not exist.")
+        return raw_loader_dir
+
+    def get_processed_reader(self, dataloader_idx: int = 0):
+        """Get a reader for processed predictions from a specific dataloader.
+
+        Args:
+            dataloader_idx: Index of the dataloader
+
+        Returns:
+            A DistributedPredictionReader for the processed predictions, or None if no data exists
+        """
+        from ..callbacks.distributed_prediction_writer import (
+            DistributedPredictionReader,
+        )
+
+        processed_loader_dir = self.processed_dir / f"dataloader_{dataloader_idx}"
+        if not processed_loader_dir.exists():
+            log.warning(
+                f"Processed predictions directory {processed_loader_dir} does not exist."
+            )
+            return None
+
+        return DistributedPredictionReader(processed_loader_dir)
+
+    @classmethod
+    def load(cls, path: Path | str):
+        """Load prediction results from a directory.
+
+        Args:
+            path: Path to the predictions directory
+
+        Returns:
+            A DistributedPredictionResult instance
+        """
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"Predictions directory {path} does not exist.")
+
+        return cls(root_dir=path)
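Reading results back is a two-step affair: `load` validates the root directory, and the accessors fan out into the `raw/` and `processed/` subtrees. A short sketch, assuming a `predictions/` directory produced by an earlier `distributed_predict` call (the path is hypothetical):

    from nshtrainer.trainer._distributed_prediction_result import DistributedPredictionResult

    result = DistributedPredictionResult.load("predictions")
    raw_dir = result.get_raw_predictions(dataloader_idx=0)  # Path to the raw shards
    if (reader := result.get_processed_reader(dataloader_idx=0)) is not None:
        for sample in reader:  # iterates the saved *.pt samples in index order
            ...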
nshtrainer/trainer/trainer.py
CHANGED
@@ -4,7 +4,7 @@ import logging
 import os
 from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, cast, overload
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
@@ -24,9 +24,14 @@ from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
 from .._checkpoint.metadata import write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
+from ..callbacks.distributed_prediction_writer import (
+    DistributedPredictionWriter,
+    DistributedPredictionWriterConfig,
+)
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import LightningTrainerKwargs, TrainerConfig
+from ._distributed_prediction_result import DistributedPredictionResult
 from ._log_hparams import patch_log_hparams_function
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .accelerator import AcceleratorConfigBase
@@ -537,13 +542,66 @@ class Trainer(LightningTrainer):
         )
         return cls(hparams)
 
+    @overload
     def distributed_predict(
         self,
         model: LightningModule | None = None,
         dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
         datamodule: LightningDataModule | None = None,
         ckpt_path: str | Path | None = None,
-    ) -> None:
+        *,
+        config: DistributedPredictionWriterConfig,
+    ) -> DistributedPredictionResult: ...
+
+    @overload
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+        *,
+        dirpath: Path | None = None,
+        move_to_cpu_on_save: bool = True,
+        save_raw: bool = True,
+        save_processed: bool = True,
+    ) -> DistributedPredictionResult: ...
+
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+        *,
+        config: DistributedPredictionWriterConfig | None = None,
+        dirpath: Path | None = None,
+        move_to_cpu_on_save: bool = True,
+        save_raw: bool = True,
+        save_processed: bool = True,
+    ) -> DistributedPredictionResult:
+        if config is None:
+            config = DistributedPredictionWriterConfig(
+                dirpath=dirpath,
+                move_to_cpu_on_save=move_to_cpu_on_save,
+                save_raw=save_raw,
+                save_processed=save_processed,
+            )
+
+        # Remove any DistributedPredictionWriter callbacks that are already set
+        # and add the new one.
+        callbacks = self.callbacks.copy()
+        callbacks = [
+            callback
+            for callback in callbacks
+            if not isinstance(callback, DistributedPredictionWriter)
+        ]
+        writer_callbacks = list(config.create_callbacks(self.hparams))
+        assert len(writer_callbacks) == 1
+        callback = writer_callbacks[0]
+        callbacks.append(callback)
+        self.callbacks = self._callback_connector._reorder_callbacks(callbacks)
+
         self.predict(
             model,
             dataloaders,
@@ -551,3 +609,9 @@ class Trainer(LightningTrainer):
             return_predictions=False,
             ckpt_path=ckpt_path,
         )
+
+        # Wait for all processes to finish
+        self.strategy.barrier("Trainer.distributed_predict")
+
+        # Return an object that contains information about the predictions
+        return DistributedPredictionResult(root_dir=callback.output_dir)
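End to end, the method now installs the writer callback, runs `predict` on every rank, synchronizes on a barrier, and returns the output location on all ranks. A sketch of the keyword-only `config=` overload, with hypothetical `trainer`, `model`, and `dm` objects:

    from pathlib import Path
    from nshtrainer.callbacks.distributed_prediction_writer import DistributedPredictionWriterConfig

    result = trainer.distributed_predict(
        model,
        datamodule=dm,
        ckpt_path="last",
        config=DistributedPredictionWriterConfig(dirpath=Path("predictions")),
    )
    # After the barrier, every rank holds the same result object:
    reader = result.get_processed_reader(dataloader_idx=0)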
{nshtrainer-1.2.0.dist-info → nshtrainer-1.3.0.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
 nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
 nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
 nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
-nshtrainer/_directory.py,sha256=
+nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
 nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
@@ -15,8 +15,8 @@ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
 nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
-nshtrainer/callbacks/directory_setup.py,sha256=
-nshtrainer/callbacks/distributed_prediction_writer.py,sha256=
+nshtrainer/callbacks/directory_setup.py,sha256=Ln6f0tCgoBscHeigIAWtCCoAmuWB-kPyaf7SylU7MYo,2773
+nshtrainer/callbacks/distributed_prediction_writer.py,sha256=PvxV9E9lHT-NQ-h1ld7WugajqiFyFXECsreUt3e7pxk,5440
 nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
 nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
 nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
@@ -85,8 +85,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=
-nshtrainer/configs/trainer/_config/__init__.py,sha256=
+nshtrainer/configs/trainer/__init__.py,sha256=YLlDOUYDp_qURHhcmhCxTcY6K5AbmoTxdzBPB9SEZII,8040
+nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -95,7 +95,7 @@ nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=AtGUuE0M16dTpX0q9NqvJiE4
 nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
 nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
 nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
-nshtrainer/configs/trainer/trainer/__init__.py,sha256=
+nshtrainer/configs/trainer/trainer/__init__.py,sha256=gOyfE4LlKP-pDJB_ILf79--GztnkF_QmEcexHgqGxOI,646
 nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
 nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
 nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
@@ -135,7 +135,8 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/_config.py,sha256=Lt9tuzxgVzVnyEFz61xbaPudfsXbKYUphOg-qMDHO8g,33203
+nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
 nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
@@ -147,7 +148,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
 nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=6oky6E8cjGqUNzJGyyTO551pE9A6YueOv5oxg1fZVR0,24129
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -159,6 +160,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.2.0.dist-info/METADATA,sha256=
-nshtrainer-1.2.0.dist-info/WHEEL,sha256=
-nshtrainer-1.2.0.dist-info/RECORD,,
+nshtrainer-1.3.0.dist-info/METADATA,sha256=M84AwXCuoJp21_m2IQKYDC-SFWDAdhOy-2fDL1jk9Lw,960
+nshtrainer-1.3.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.3.0.dist-info/RECORD,,

{nshtrainer-1.2.0.dist-info → nshtrainer-1.3.0.dist-info}/WHEEL
File without changes