nshtrainer 0.10.9__py3-none-any.whl → 0.10.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
+ import logging
1
2
  import math
2
- from logging import getLogger
3
3
 
4
4
  from lightning.fabric.utilities.rank_zero import _get_rank
5
5
  from lightning.pytorch import Trainer
@@ -7,7 +7,7 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
7
7
  from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
8
8
  from typing_extensions import override
9
9
 
10
- log = getLogger(__name__)
10
+ log = logging.getLogger(__name__)
11
11
 
12
12
 
13
13
  class EarlyStopping(_EarlyStopping):
@@ -1,4 +1,4 @@
1
- from logging import getLogger
1
+ import logging
2
2
  from typing import Literal
3
3
 
4
4
  import torch
@@ -7,7 +7,7 @@ from typing_extensions import override
7
7
 
8
8
  from .base import CallbackConfigBase
9
9
 
10
- log = getLogger(__name__)
10
+ log = logging.getLogger(__name__)
11
11
 
12
12
 
13
13
  def finite_checks(
@@ -1,6 +1,6 @@
1
+ import logging
1
2
  import re
2
3
  from datetime import timedelta
3
- from logging import getLogger
4
4
  from pathlib import Path
5
5
  from typing import TYPE_CHECKING, Literal
6
6
 
@@ -15,7 +15,7 @@ from .base import CallbackConfigBase
15
15
  if TYPE_CHECKING:
16
16
  from ..model.config import BaseConfig
17
17
 
18
- log = getLogger(__name__)
18
+ log = logging.getLogger(__name__)
19
19
 
20
20
 
21
21
  def _convert_string(input_string: str):
@@ -1,4 +1,4 @@
1
- from logging import getLogger
1
+ import logging
2
2
  from typing import Literal, cast
3
3
 
4
4
  import torch
@@ -9,7 +9,7 @@ from typing_extensions import override
9
9
 
10
10
  from .base import CallbackConfigBase
11
11
 
12
- log = getLogger(__name__)
12
+ log = logging.getLogger(__name__)
13
13
 
14
14
 
15
15
  def grad_norm(
@@ -1,4 +1,4 @@
1
- from logging import getLogger
1
+ import logging
2
2
  from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable
3
3
 
4
4
  from typing_extensions import NotRequired, override
@@ -6,7 +6,7 @@ from typing_extensions import NotRequired, override
6
6
  from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
7
7
  from .base import CallbackConfigBase
8
8
 
9
- log = getLogger(__name__)
9
+ log = logging.getLogger(__name__)
10
10
 
11
11
 
12
12
  class ThroughputMonitorBatchStats(TypedDict):
@@ -1,4 +1,4 @@
1
- from logging import getLogger
1
+ import logging
2
2
  from typing import Literal, Protocol, cast, runtime_checkable
3
3
 
4
4
  import torch.nn as nn
@@ -9,7 +9,7 @@ from typing_extensions import override
9
9
 
10
10
  from .base import CallbackConfigBase
11
11
 
12
- log = getLogger(__name__)
12
+ log = logging.getLogger(__name__)
13
13
 
14
14
 
15
15
  @runtime_checkable
@@ -1,6 +1,6 @@
1
1
  import heapq
2
+ import logging
2
3
  from functools import cached_property
3
- from logging import getLogger
4
4
  from typing import Any, Protocol, runtime_checkable
5
5
 
6
6
  import numpy as np
@@ -10,7 +10,7 @@ from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
10
10
  from torch.utils.data import BatchSampler, Dataset, DistributedSampler
11
11
  from typing_extensions import override
12
12
 
13
- log = getLogger(__name__)
13
+ log = logging.getLogger(__name__)
14
14
 
15
15
 
16
16
  def _all_gather(tensor: torch.Tensor, device: torch.device | None = None):
nshtrainer/model/base.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import inspect
2
+ import logging
2
3
  from abc import ABC, abstractmethod
3
4
  from collections.abc import MutableMapping
4
- from logging import getLogger
5
5
  from typing import IO, TYPE_CHECKING, Any, Generic, cast
6
6
 
7
7
  import torch
@@ -21,7 +21,7 @@ from .modules.profiler import ProfilerMixin
21
21
  from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
22
22
  from .modules.shared_parameters import SharedParametersModuleMixin
23
23
 
24
- log = getLogger(__name__)
24
+ log = logging.getLogger(__name__)
25
25
 
26
26
  THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
27
27
 
@@ -1,4 +1,5 @@
1
1
  import copy
2
+ import logging
2
3
  import os
3
4
  import string
4
5
  import time
@@ -6,7 +7,6 @@ import warnings
6
7
  from abc import ABC, abstractmethod
7
8
  from collections.abc import Iterable, Sequence
8
9
  from datetime import timedelta
9
- from logging import getLogger
10
10
  from pathlib import Path
11
11
  from typing import (
12
12
  Annotated,
@@ -46,7 +46,7 @@ from ..callbacks.base import CallbackConfigBase
46
46
  from ..metrics import MetricConfig
47
47
  from ._environment import EnvironmentConfig
48
48
 
49
- log = getLogger(__name__)
49
+ log = logging.getLogger(__name__)
50
50
 
51
51
 
52
52
  class IdSeedWarning(Warning):
@@ -1,6 +1,6 @@
1
+ import logging
1
2
  from collections import abc
2
3
  from collections.abc import Callable, Iterable
3
- from logging import getLogger
4
4
  from typing import Any, TypeAlias, cast, final
5
5
 
6
6
  from lightning.pytorch import Callback, LightningModule
@@ -9,7 +9,7 @@ from typing_extensions import override
9
9
 
10
10
  from ...util.typing_utils import mixin_base_type
11
11
 
12
- log = getLogger(__name__)
12
+ log = logging.getLogger(__name__)
13
13
 
14
14
  CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
15
15
 
@@ -1,9 +1,9 @@
1
- from logging import getLogger
1
+ import logging
2
2
 
3
3
  import torch
4
4
  import torch.distributed
5
5
 
6
- log = getLogger(__name__)
6
+ log = logging.getLogger(__name__)
7
7
 
8
8
 
9
9
  class DebugModuleMixin:
@@ -1,5 +1,5 @@
1
+ import logging
1
2
  from collections.abc import Mapping
2
- from logging import getLogger
3
3
  from typing import cast
4
4
 
5
5
  import torch
@@ -14,7 +14,7 @@ from ...util.typing_utils import mixin_base_type
14
14
  from ..config import BaseConfig
15
15
  from .callback import CallbackModuleMixin
16
16
 
17
- log = getLogger(__name__)
17
+ log = logging.getLogger(__name__)
18
18
 
19
19
 
20
20
  def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
@@ -1,5 +1,5 @@
1
+ import logging
1
2
  from collections.abc import Sequence
2
- from logging import getLogger
3
3
  from typing import cast
4
4
 
5
5
  import torch.nn as nn
@@ -10,7 +10,7 @@ from ...util.typing_utils import mixin_base_type
10
10
  from ..config import BaseConfig
11
11
  from .callback import CallbackRegistrarModuleMixin
12
12
 
13
- log = getLogger(__name__)
13
+ log = logging.getLogger(__name__)
14
14
 
15
15
 
16
16
  def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
@@ -1,8 +1,8 @@
1
+ import logging
1
2
  import os
2
3
  from contextlib import contextmanager
3
- from logging import getLogger
4
4
 
5
- log = getLogger(__name__)
5
+ log = logging.getLogger(__name__)
6
6
 
7
7
 
8
8
  @contextmanager
nshtrainer/util/seed.py CHANGED
@@ -1,8 +1,8 @@
1
- from logging import getLogger
1
+ import logging
2
2
 
3
3
  import lightning.fabric.utilities.seed as LS
4
4
 
5
- log = getLogger(__name__)
5
+ log = logging.getLogger(__name__)
6
6
 
7
7
 
8
8
  def seed_everything(seed: int | None, *, workers: bool = False):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.9
3
+ Version: 0.10.10
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -9,22 +9,22 @@ nshtrainer/callbacks/__init__.py,sha256=ifXQRwtccznl4lMKwKLSuuAQC4bKFBgfzQ4rx9gO
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=aY6T_NAzaFAVU8WMHOXnWL5wd2bi8eVxeU2S0iAs70c,4446
11
11
  nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
12
- nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
12
+ nshtrainer/callbacks/early_stopping.py,sha256=LGn3rdbvkFfUo9kwMzK4eMGlPAqD9uFdowDx6VdfozQ,3761
13
13
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
14
- nshtrainer/callbacks/finite_checks.py,sha256=AO5fa51uANAjAkeJfTquOjK6W_4RSU5Kky3f5jmAPlQ,2084
14
+ nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
15
15
  nshtrainer/callbacks/gradient_skipping.py,sha256=pqu5AELx4ctJxR2Y7YSSiGd5oGauVCTZFCEIIS6s88w,3665
16
16
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
17
17
  nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=zCRAUsqW-2PaoIwVKlXOqdh2uF_B_YUUTmQO1wSomR8,2489
18
18
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
19
- nshtrainer/callbacks/model_checkpoint.py,sha256=N0raLsHlCVSbO3QU5eNFUXUDqxxW3C73oQwceMnFE_k,5955
20
- nshtrainer/callbacks/norm_logging.py,sha256=EWyrfkp8iHjQi9iAAXHxb0xStw2RwkdpKG2_gLarQRA,6281
19
+ nshtrainer/callbacks/model_checkpoint.py,sha256=wkT8sHGkIuatm4gSn4W-fqicA_HG5FBdaGA5THPFj4M,5948
20
+ nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
21
21
  nshtrainer/callbacks/on_exception_checkpoint.py,sha256=x42BYZ2ejf2rhqPLCmT5nyWKhA9qBEosiV8ZNhhZ6lI,3355
22
22
  nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
23
- nshtrainer/callbacks/throughput_monitor.py,sha256=4EF3b79HdHiRgBGIFDyD4O1oywb5h1tV8nml7NuuDjU,1845
23
+ nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
24
24
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
25
- nshtrainer/callbacks/wandb_watch.py,sha256=0JKDEPSDmiKverFm00lFfufOvtvr49akFXScvdUQnqc,2930
25
+ nshtrainer/callbacks/wandb_watch.py,sha256=EJ93mtJlph4BZsXh8HJPNiw2VNSm2N6TOwpCwqRAeKI,2923
26
26
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
27
- nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
27
+ nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
28
28
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
29
29
  nshtrainer/ll/__init__.py,sha256=dD0ISxHJ2lg1HLSM0b3db7TBlsPpQCtChnuYO-c2oqI,2635
30
30
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
@@ -51,15 +51,15 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
51
51
  nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
52
52
  nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
53
53
  nshtrainer/model/_environment.py,sha256=oTtecQeF5oY2RV7UkkSLnzDy3clz4AUkf9oocD6-e54,23115
54
- nshtrainer/model/base.py,sha256=Bmw-t70TydDbE9P0ee-lTibGoUhrCx5Qke-upa7FGVM,17512
55
- nshtrainer/model/config.py,sha256=FRjn2pVbCzitBeIFcKXXPQ0TCmqvAwSylAvfedf1NHg,53356
56
- nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
57
- nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
54
+ nshtrainer/model/base.py,sha256=WtCj0-nLWeW04Tu2TWVjIq0D-jW_kMN2hg--4VWVnvE,17505
55
+ nshtrainer/model/config.py,sha256=orGBrp8TXnHksfAzXxNJVDdo0X_iIn_nda6BZDS9N70,53349
56
+ nshtrainer/model/modules/callback.py,sha256=K0-cyEtBcQhI7Q2e-AGTE8T-GghUPY9DYmneU6ULV6g,6401
57
+ nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
58
58
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
59
59
  nshtrainer/model/modules/logger.py,sha256=YYhehQysqTjuVFcd_EREYDh57CIlezidFBS2Ohp_xKo,5661
60
60
  nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
61
- nshtrainer/model/modules/rlp_sanity_checks.py,sha256=o6gUceFwsuDHmL8eLOYuT3JGXFzq_qc4awl2RWaBygU,8900
62
- nshtrainer/model/modules/shared_parameters.py,sha256=mD5wrlBE3c025vzVdTpnSyC8yxzuI-aUWMmPhqPT0a0,2694
61
+ nshtrainer/model/modules/rlp_sanity_checks.py,sha256=I_ralr2ThQ-D_FkVQTwbdXLLlgHJEr7-s01I5wSDjps,8893
62
+ nshtrainer/model/modules/shared_parameters.py,sha256=ZiRKkZXr6RwdwLCdZCJPl3dXe7bnT8Z9yTeRK5bXBGk,2687
63
63
  nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
64
64
  nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
65
65
  nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
@@ -73,11 +73,11 @@ nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2f
73
73
  nshtrainer/trainer/checkpoint_connector.py,sha256=xoqI2dcPnlNFPPLVIU6dBOvRPC9PtfX5qu__xV1lx0Y,2124
74
74
  nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
75
75
  nshtrainer/trainer/trainer.py,sha256=tFyzIsF8c-FABTH6wwDOR9y8kydVJqeVO7PDNFMvhSU,16950
76
- nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
77
- nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
76
+ nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
77
+ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
78
78
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
79
79
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
80
80
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
81
- nshtrainer-0.10.9.dist-info/METADATA,sha256=zpMnR1Jwc9kX7BhnHT6NA3nsPUwQzMAraBPOGsUSD1w,695
82
- nshtrainer-0.10.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
83
- nshtrainer-0.10.9.dist-info/RECORD,,
81
+ nshtrainer-0.10.10.dist-info/METADATA,sha256=AQipVj-dOXT3cPwzbfvg6u3KdgLocc1BeaSLjndnL1A,696
82
+ nshtrainer-0.10.10.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
83
+ nshtrainer-0.10.10.dist-info/RECORD,,