nshtrainer 0.6.0.tar.gz → 0.6.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/PKG-INFO +1 -1
  2. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/pyproject.toml +1 -1
  3. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/nn/__init__.py +2 -0
  4. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/nn/mlp.py +67 -1
  5. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/README.md +0 -0
  6. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/__init__.py +0 -0
  7. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  8. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  9. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  10. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  11. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/_snoop.py +0 -0
  12. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/actsave/__init__.py +0 -0
  13. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/actsave/_callback.py +0 -0
  14. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  15. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  16. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/base.py +0 -0
  17. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  18. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/ema.py +0 -0
  19. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  20. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  21. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/interval.py +0 -0
  22. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -0
  23. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  24. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  25. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
  26. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  27. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  28. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/timer.py +0 -0
  29. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  30. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/config.py +0 -0
  31. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/data/__init__.py +0 -0
  32. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  33. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/data/transform.py +0 -0
  34. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/ll/__init__.py +0 -0
  35. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  36. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  37. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  38. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  39. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/__init__.py +0 -0
  40. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/base.py +0 -0
  41. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/config.py +0 -0
  42. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/modules/callback.py +0 -0
  43. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/modules/debug.py +0 -0
  44. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/modules/distributed.py +0 -0
  45. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/modules/logger.py +0 -0
  46. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/modules/profiler.py +0 -0
  47. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  48. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  49. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/nn/module_dict.py +0 -0
  50. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/nn/module_list.py +0 -0
  51. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  52. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/optimizer.py +0 -0
  53. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/runner.py +0 -0
  54. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/scripts/check_env.py +0 -0
  55. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  56. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/trainer/__init__.py +0 -0
  57. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  58. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/trainer/trainer.py +0 -0
  59. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/typecheck.py +0 -0
  60. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/util/environment.py +0 -0
  61. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/util/seed.py +0 -0
  62. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/util/slurm.py +0 -0
  63. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/util/typed.py +0 -0
  64. {nshtrainer-0.6.0 → nshtrainer-0.6.1}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.6.0
3
+ Version: 0.6.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.6.0"
3
+ version = "0.6.1"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -1,4 +1,6 @@
1
1
  from .mlp import MLP as MLP
2
+ from .mlp import MLPConfig as MLPConfig
3
+ from .mlp import MLPConfigDict as MLPConfigDict
2
4
  from .mlp import ResidualSequential as ResidualSequential
3
5
  from .module_dict import TypedModuleDict as TypedModuleDict
4
6
  from .module_list import TypedModuleList as TypedModuleList
@@ -2,9 +2,10 @@ import copy
2
2
  from collections.abc import Callable, Sequence
3
3
  from typing import Literal, Protocol, runtime_checkable
4
4
 
5
+ import nshconfig as C
5
6
  import torch
6
7
  import torch.nn as nn
7
- from typing_extensions import override
8
+ from typing_extensions import TypedDict, override
8
9
 
9
10
  from .nonlinearity import BaseNonlinearityConfig
10
11
 
@@ -22,6 +23,71 @@ class ResidualSequential(nn.Sequential):
22
23
  return input + super().forward(input)
23
24
 
24
25
 
26
class MLPConfigDict(TypedDict, total=True):
    """Keyword arguments forwarded to :func:`MLP`.

    ``MLPConfig.to_kwargs`` returns an instance of this dict, and
    ``MLPConfig.create_module`` unpacks it into the ``MLP`` call.
    Every key is required (``total=True``).
    """

    # Whether to include bias terms in the linear layers.
    bias: bool
    # Whether to exclude bias terms when the output dimension is 1.
    no_bias_scalar: bool
    # Activation function to use between layers (``None`` for none).
    nonlinearity: BaseNonlinearityConfig | None
    # Whether to apply layer normalization before or after the linear layers.
    ln: bool | Literal["pre", "post"]
    # Dropout probability to apply between layers (``None`` for none).
    dropout: float | None
    # Whether to use residual connections between layers.
    residual: bool
44
+
45
+
46
class MLPConfig(C.Config):
    """Serializable settings for building a multi-layer perceptron via :func:`MLP`.

    Each field mirrors a keyword argument of :func:`MLP`. :meth:`to_kwargs`
    exports them as an :class:`MLPConfigDict`, and :meth:`create_module`
    forwards them (plus the layer dimensions) to :func:`MLP`.
    """

    bias: bool = True
    """Whether to include bias terms in the linear layers."""

    no_bias_scalar: bool = True
    """Whether to exclude bias terms when the output dimension is 1."""

    nonlinearity: BaseNonlinearityConfig | None = None
    """Activation function to use between layers."""

    ln: bool | Literal["pre", "post"] = False
    """Whether to apply layer normalization before or after the linear layers."""

    dropout: float | None = None
    """Dropout probability to apply between layers."""

    residual: bool = False
    """Whether to use residual connections between layers."""

    def to_kwargs(self) -> MLPConfigDict:
        """Return this config's fields as keyword arguments for :func:`MLP`.

        Returns:
            A dict with one entry per config field, keyed by field name.
        """
        return {
            "bias": self.bias,
            "no_bias_scalar": self.no_bias_scalar,
            "nonlinearity": self.nonlinearity,
            "ln": self.ln,
            "dropout": self.dropout,
            "residual": self.residual,
        }

    def create_module(
        self,
        dims: Sequence[int],
        pre_layers: Sequence[nn.Module] | None = None,
        post_layers: Sequence[nn.Module] | None = None,
        linear_cls: LinearModuleConstructor = nn.Linear,
    ):
        """Build the MLP described by this config.

        Args:
            dims: Layer dimensions, passed to :func:`MLP` as its first
                positional argument.
            pre_layers: Modules passed to :func:`MLP` as ``pre_layers``.
                ``None`` (the default) is treated as an empty list.
            post_layers: Modules passed to :func:`MLP` as ``post_layers``.
                ``None`` (the default) is treated as an empty list.
            linear_cls: Constructor for the linear layers, passed to
                :func:`MLP` as ``linear_cls``.

        Returns:
            The module returned by :func:`MLP` for these arguments.
        """
        # Build a fresh list per call instead of using a shared ``[]``
        # default. Otherwise any mutation of the sequence downstream would
        # leak into every later call.
        if pre_layers is None:
            pre_layers = []
        if post_layers is None:
            post_layers = []
        return MLP(
            dims,
            **self.to_kwargs(),
            pre_layers=pre_layers,
            post_layers=post_layers,
            linear_cls=linear_cls,
        )
89
+
90
+
25
91
  def MLP(
26
92
  dims: Sequence[int],
27
93
  activation: BaseNonlinearityConfig
File without changes