nshtrainer 1.2.1__tar.gz → 1.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/PKG-INFO +1 -1
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/pyproject.toml +1 -1
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/distributed_prediction_writer.py +22 -11
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/__init__.py +3 -3
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -4
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/trainer/__init__.py +4 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/__init__.py +6 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/_config.py +0 -10
- nshtrainer-1.3.1/src/nshtrainer/trainer/_distributed_prediction_result.py +80 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/trainer.py +66 -2
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/README.md +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/.nshconfig.generated.json +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/metric_validation.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/.gitattributes +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/data/datamodule.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/loggers/base.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/lr_scheduler/base.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/nn/rng.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/_log_hparams.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/accelerator.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/plugin/base.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/plugin/environment.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/plugin/io.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/plugin/precision.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/trainer/strategy.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/distributed_prediction_writer.py
RENAMED
@@ -4,15 +4,19 @@ import functools
|
|
4
4
|
import logging
|
5
5
|
from collections.abc import Iterator, Sequence
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import
|
7
|
+
from typing import TYPE_CHECKING, ClassVar, Generic, Literal, cast, overload
|
8
8
|
|
9
9
|
import torch
|
10
10
|
from lightning.fabric.utilities.apply_func import move_data_to_device
|
11
11
|
from lightning.pytorch.callbacks import BasePredictionWriter
|
12
|
-
from typing_extensions import final, override
|
12
|
+
from typing_extensions import TypeVar, final, override
|
13
13
|
|
14
14
|
from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry
|
15
15
|
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from ..model.base import IndividualSample
|
18
|
+
|
19
|
+
|
16
20
|
log = logging.getLogger(__name__)
|
17
21
|
|
18
22
|
|
@@ -130,7 +134,15 @@ class DistributedPredictionWriter(BasePredictionWriter):
|
|
130
134
|
save(sample, output_dir / f"{sample['index']}.pt")
|
131
135
|
|
132
136
|
|
133
|
-
|
137
|
+
SampleT = TypeVar(
|
138
|
+
"SampleT",
|
139
|
+
bound="IndividualSample",
|
140
|
+
default="IndividualSample",
|
141
|
+
infer_variance=True,
|
142
|
+
)
|
143
|
+
|
144
|
+
|
145
|
+
class DistributedPredictionReader(Sequence[SampleT], Generic[SampleT]):
|
134
146
|
def __init__(self, output_dir: Path):
|
135
147
|
self.output_dir = output_dir
|
136
148
|
|
@@ -139,15 +151,13 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
|
|
139
151
|
return len(list(self.output_dir.glob("*.pt")))
|
140
152
|
|
141
153
|
@overload
|
142
|
-
def __getitem__(self, index: int) ->
|
154
|
+
def __getitem__(self, index: int) -> SampleT: ...
|
143
155
|
|
144
156
|
@overload
|
145
|
-
def __getitem__(self, index: slice) -> list[
|
157
|
+
def __getitem__(self, index: slice) -> list[SampleT]: ...
|
146
158
|
|
147
159
|
@override
|
148
|
-
def __getitem__(
|
149
|
-
self, index: int | slice
|
150
|
-
) -> tuple[Any, Any] | list[tuple[Any, Any]]:
|
160
|
+
def __getitem__(self, index: int | slice) -> SampleT | list[SampleT]:
|
151
161
|
if isinstance(index, slice):
|
152
162
|
# Handle slice indexing
|
153
163
|
indices = range(*index.indices(len(self)))
|
@@ -157,10 +167,11 @@ class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
|
|
157
167
|
path = self.output_dir / f"{index}.pt"
|
158
168
|
if not path.exists():
|
159
169
|
raise FileNotFoundError(f"File {path} does not exist.")
|
160
|
-
|
161
|
-
|
170
|
+
|
171
|
+
sample = cast(SampleT, torch.load(path))
|
172
|
+
return sample
|
162
173
|
|
163
174
|
@override
|
164
|
-
def __iter__(self) -> Iterator[
|
175
|
+
def __iter__(self) -> Iterator[SampleT]:
|
165
176
|
for i in range(len(self)):
|
166
177
|
yield self[i]
|
@@ -22,9 +22,6 @@ from nshtrainer.trainer._config import (
|
|
22
22
|
DebugFlagCallbackConfig as DebugFlagCallbackConfig,
|
23
23
|
)
|
24
24
|
from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
|
25
|
-
from nshtrainer.trainer._config import (
|
26
|
-
DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
|
27
|
-
)
|
28
25
|
from nshtrainer.trainer._config import (
|
29
26
|
EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
|
30
27
|
)
|
@@ -126,6 +123,9 @@ from nshtrainer.trainer.plugin.precision import (
|
|
126
123
|
)
|
127
124
|
from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
|
128
125
|
from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
|
126
|
+
from nshtrainer.trainer.trainer import (
|
127
|
+
DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
|
128
|
+
)
|
129
129
|
from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
|
130
130
|
|
131
131
|
from . import _config as _config
|
@@ -18,9 +18,6 @@ from nshtrainer.trainer._config import (
|
|
18
18
|
DebugFlagCallbackConfig as DebugFlagCallbackConfig,
|
19
19
|
)
|
20
20
|
from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
|
21
|
-
from nshtrainer.trainer._config import (
|
22
|
-
DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
|
23
|
-
)
|
24
21
|
from nshtrainer.trainer._config import (
|
25
22
|
EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
|
26
23
|
)
|
@@ -73,7 +70,6 @@ __all__ = [
|
|
73
70
|
"CheckpointSavingConfig",
|
74
71
|
"DebugFlagCallbackConfig",
|
75
72
|
"DirectoryConfig",
|
76
|
-
"DistributedPredictionWriterConfig",
|
77
73
|
"EarlyStoppingCallbackConfig",
|
78
74
|
"EnvironmentConfig",
|
79
75
|
"GradientClippingConfig",
|
@@ -3,12 +3,16 @@ from __future__ import annotations
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
5
|
from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
|
6
|
+
from nshtrainer.trainer.trainer import (
|
7
|
+
DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
|
8
|
+
)
|
6
9
|
from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
|
7
10
|
from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
|
8
11
|
from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
|
9
12
|
|
10
13
|
__all__ = [
|
11
14
|
"AcceleratorConfigBase",
|
15
|
+
"DistributedPredictionWriterConfig",
|
12
16
|
"EnvironmentConfig",
|
13
17
|
"StrategyConfigBase",
|
14
18
|
"TrainerConfig",
|
@@ -1,7 +1,13 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from ..callbacks import callback_registry as callback_registry
|
4
|
+
from ..callbacks.distributed_prediction_writer import (
|
5
|
+
DistributedPredictionReader as DistributedPredictionReader,
|
6
|
+
)
|
4
7
|
from ._config import TrainerConfig as TrainerConfig
|
8
|
+
from ._distributed_prediction_result import (
|
9
|
+
DistributedPredictionResult as DistributedPredictionResult,
|
10
|
+
)
|
5
11
|
from .accelerator import accelerator_registry as accelerator_registry
|
6
12
|
from .plugin import plugin_registry as plugin_registry
|
7
13
|
from .trainer import Trainer as Trainer
|
@@ -31,7 +31,6 @@ from .._hf_hub import HuggingFaceHubConfig
|
|
31
31
|
from ..callbacks import (
|
32
32
|
BestCheckpointCallbackConfig,
|
33
33
|
CallbackConfig,
|
34
|
-
DistributedPredictionWriterConfig,
|
35
34
|
EarlyStoppingCallbackConfig,
|
36
35
|
LastCheckpointCallbackConfig,
|
37
36
|
NormLoggingCallbackConfig,
|
@@ -702,14 +701,6 @@ class TrainerConfig(C.Config):
|
|
702
701
|
auto_validate_metrics: MetricValidationCallbackConfig | None = None
|
703
702
|
"""If enabled, will automatically validate the metrics before starting the training routine."""
|
704
703
|
|
705
|
-
distributed_predict: DistributedPredictionWriterConfig | None = (
|
706
|
-
DistributedPredictionWriterConfig()
|
707
|
-
)
|
708
|
-
"""If enabled, will use a custom BasePredictionWriter callback to automatically
|
709
|
-
handle distributed prediction. This is useful for running prediction on multiple GPUs
|
710
|
-
seamlessly.
|
711
|
-
"""
|
712
|
-
|
713
704
|
lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
|
714
705
|
"""
|
715
706
|
Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
|
@@ -778,7 +769,6 @@ class TrainerConfig(C.Config):
|
|
778
769
|
yield self.reduce_lr_on_plateau_sanity_checking
|
779
770
|
yield self.auto_set_debug_flag
|
780
771
|
yield self.auto_validate_metrics
|
781
|
-
yield self.distributed_predict
|
782
772
|
yield from self.callbacks
|
783
773
|
|
784
774
|
def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
|
@@ -0,0 +1,80 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
log = logging.getLogger(__name__)
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
|
11
|
+
class DistributedPredictionResult:
|
12
|
+
"""Represents the results of a distributed prediction run.
|
13
|
+
|
14
|
+
This dataclass provides easy access to both raw and processed prediction data.
|
15
|
+
"""
|
16
|
+
|
17
|
+
root_dir: Path
|
18
|
+
"""Root directory where predictions are stored."""
|
19
|
+
|
20
|
+
@property
|
21
|
+
def raw_dir(self) -> Path:
|
22
|
+
"""Directory containing raw prediction data."""
|
23
|
+
return self.root_dir / "raw"
|
24
|
+
|
25
|
+
@property
|
26
|
+
def processed_dir(self) -> Path:
|
27
|
+
"""Directory containing processed prediction data."""
|
28
|
+
return self.root_dir / "processed"
|
29
|
+
|
30
|
+
def get_raw_predictions(self, dataloader_idx: int = 0) -> Path:
|
31
|
+
"""Get the directory containing raw predictions for a specific dataloader.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
dataloader_idx: Index of the dataloader
|
35
|
+
|
36
|
+
Returns:
|
37
|
+
Path to the raw predictions directory for the specified dataloader
|
38
|
+
"""
|
39
|
+
raw_loader_dir = self.raw_dir / f"dataloader_{dataloader_idx}"
|
40
|
+
if not raw_loader_dir.exists():
|
41
|
+
log.warning(f"Raw predictions directory {raw_loader_dir} does not exist.")
|
42
|
+
return raw_loader_dir
|
43
|
+
|
44
|
+
def get_processed_reader(self, dataloader_idx: int = 0):
|
45
|
+
"""Get a reader for processed predictions from a specific dataloader.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
dataloader_idx: Index of the dataloader
|
49
|
+
|
50
|
+
Returns:
|
51
|
+
A DistributedPredictionReader for the processed predictions, or None if no data exists
|
52
|
+
"""
|
53
|
+
from ..callbacks.distributed_prediction_writer import (
|
54
|
+
DistributedPredictionReader,
|
55
|
+
)
|
56
|
+
|
57
|
+
processed_loader_dir = self.processed_dir / f"dataloader_{dataloader_idx}"
|
58
|
+
if not processed_loader_dir.exists():
|
59
|
+
log.warning(
|
60
|
+
f"Processed predictions directory {processed_loader_dir} does not exist."
|
61
|
+
)
|
62
|
+
return None
|
63
|
+
|
64
|
+
return DistributedPredictionReader(processed_loader_dir)
|
65
|
+
|
66
|
+
@classmethod
|
67
|
+
def load(cls, path: Path | str):
|
68
|
+
"""Load prediction results from a directory.
|
69
|
+
|
70
|
+
Args:
|
71
|
+
path: Path to the predictions directory
|
72
|
+
|
73
|
+
Returns:
|
74
|
+
A DistributedPredictionResult instance
|
75
|
+
"""
|
76
|
+
path = Path(path)
|
77
|
+
if not path.exists():
|
78
|
+
raise FileNotFoundError(f"Predictions directory {path} does not exist.")
|
79
|
+
|
80
|
+
return cls(root_dir=path)
|
@@ -4,7 +4,7 @@ import logging
|
|
4
4
|
import os
|
5
5
|
from collections.abc import Callable, Mapping, Sequence
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import TYPE_CHECKING, Any, cast
|
7
|
+
from typing import TYPE_CHECKING, Any, cast, overload
|
8
8
|
|
9
9
|
import torch
|
10
10
|
from lightning.fabric.plugins.environments.lsf import LSFEnvironment
|
@@ -24,9 +24,14 @@ from typing_extensions import Never, Unpack, assert_never, deprecated, override
|
|
24
24
|
|
25
25
|
from .._checkpoint.metadata import write_checkpoint_metadata
|
26
26
|
from ..callbacks.base import resolve_all_callbacks
|
27
|
+
from ..callbacks.distributed_prediction_writer import (
|
28
|
+
DistributedPredictionWriter,
|
29
|
+
DistributedPredictionWriterConfig,
|
30
|
+
)
|
27
31
|
from ..util._environment_info import EnvironmentConfig
|
28
32
|
from ..util.bf16 import is_bf16_supported_no_emulation
|
29
33
|
from ._config import LightningTrainerKwargs, TrainerConfig
|
34
|
+
from ._distributed_prediction_result import DistributedPredictionResult
|
30
35
|
from ._log_hparams import patch_log_hparams_function
|
31
36
|
from ._runtime_callback import RuntimeTrackerCallback, Stage
|
32
37
|
from .accelerator import AcceleratorConfigBase
|
@@ -537,13 +542,66 @@ class Trainer(LightningTrainer):
|
|
537
542
|
)
|
538
543
|
return cls(hparams)
|
539
544
|
|
545
|
+
@overload
|
540
546
|
def distributed_predict(
|
541
547
|
self,
|
542
548
|
model: LightningModule | None = None,
|
543
549
|
dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
|
544
550
|
datamodule: LightningDataModule | None = None,
|
545
551
|
ckpt_path: str | Path | None = None,
|
546
|
-
|
552
|
+
*,
|
553
|
+
config: DistributedPredictionWriterConfig,
|
554
|
+
) -> DistributedPredictionResult: ...
|
555
|
+
|
556
|
+
@overload
|
557
|
+
def distributed_predict(
|
558
|
+
self,
|
559
|
+
model: LightningModule | None = None,
|
560
|
+
dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
|
561
|
+
datamodule: LightningDataModule | None = None,
|
562
|
+
ckpt_path: str | Path | None = None,
|
563
|
+
*,
|
564
|
+
dirpath: Path | None = None,
|
565
|
+
move_to_cpu_on_save: bool = True,
|
566
|
+
save_raw: bool = True,
|
567
|
+
save_processed: bool = True,
|
568
|
+
) -> DistributedPredictionResult: ...
|
569
|
+
|
570
|
+
def distributed_predict(
    self,
    model: LightningModule | None = None,
    dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
    datamodule: LightningDataModule | None = None,
    ckpt_path: str | Path | None = None,
    *,
    config: DistributedPredictionWriterConfig | None = None,
    dirpath: Path | None = None,
    move_to_cpu_on_save: bool = True,
    save_raw: bool = True,
    save_processed: bool = True,
) -> DistributedPredictionResult:
    """Run ``Trainer.predict`` across all ranks, writing every prediction to
    disk through a ``DistributedPredictionWriter`` callback instead of
    collecting predictions in memory (``return_predictions=False``).

    Args:
        model: Module to predict with; ``None`` defers to Lightning's default
            resolution (the module attached to the trainer).
        dataloaders: Prediction dataloaders, or a datamodule.
        datamodule: Datamodule providing ``predict_dataloader``.
        ckpt_path: Checkpoint to load before predicting, forwarded to
            ``self.predict``.
        config: Fully-built writer configuration. When given, the individual
            keyword options below are ignored.
        dirpath: Output directory for the writer (used only when ``config``
            is ``None``).
        move_to_cpu_on_save: Move tensors to CPU before serialization (used
            only when ``config`` is ``None``).
        save_raw: Persist raw per-batch outputs (used only when ``config``
            is ``None``).
        save_processed: Persist processed outputs (used only when ``config``
            is ``None``).

    Returns:
        A ``DistributedPredictionResult`` pointing at the writer's output
        directory; all ranks have synchronized before it is returned.
    """
    # Build a writer config from the individual keyword options unless the
    # caller supplied one explicitly.
    if config is None:
        config = DistributedPredictionWriterConfig(
            dirpath=dirpath,
            move_to_cpu_on_save=move_to_cpu_on_save,
            save_raw=save_raw,
            save_processed=save_processed,
        )

    # Remove any DistributedPredictionWriter callbacks that are already set
    # and add the new one.
    callbacks = self.callbacks.copy()
    callbacks = [
        callback
        for callback in callbacks
        if not isinstance(callback, DistributedPredictionWriter)
    ]
    # The config is expected to yield exactly one writer callback.
    writer_callbacks = list(config.create_callbacks(self.hparams))
    assert len(writer_callbacks) == 1
    callback = writer_callbacks[0]
    callbacks.append(callback)
    # Let Lightning's connector restore its required callback ordering.
    self.callbacks = self._callback_connector._reorder_callbacks(callbacks)

    self.predict(
        model,
        dataloaders,
        datamodule,  # NOTE(review): this argument line is elided in the diff view; presumed `datamodule` from Lightning's predict() signature — confirm against the 1.3.1 source
        return_predictions=False,
        ckpt_path=ckpt_path,
    )

    # Wait for all processes to finish
    self.strategy.barrier("Trainer.distributed_predict")

    # Return an object that contains information about the predictions
    return DistributedPredictionResult(root_dir=callback.output_dir)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/print_table/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/loggers/tensorboard/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/accelerator/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/base/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py
RENAMED
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py
RENAMED
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/util/_environment_info/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{nshtrainer-1.2.1 → nshtrainer-1.3.1}/src/nshtrainer/configs/util/config/duration/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|