stackformer 0.0.1__tar.gz → 0.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. stackformer-0.1.9/PKG-INFO +185 -0
  2. stackformer-0.1.9/README.md +151 -0
  3. stackformer-0.1.9/pyproject.toml +61 -0
  4. stackformer-0.1.9/stackformer/__init__.py +74 -0
  5. stackformer-0.1.9/stackformer/amp/__init__.py +5 -0
  6. stackformer-0.1.9/stackformer/amp/scaler.py +84 -0
  7. stackformer-0.1.9/stackformer/config.py +41 -0
  8. stackformer-0.1.9/stackformer/distributed/__init__.py +31 -0
  9. stackformer-0.1.9/stackformer/distributed/ddp.py +94 -0
  10. stackformer-0.1.9/stackformer/engine/__init__.py +6 -0
  11. stackformer-0.1.9/stackformer/engine/checkpoint.py +176 -0
  12. stackformer-0.1.9/stackformer/engine/engine.py +213 -0
  13. stackformer-0.1.9/stackformer/engine/state.py +81 -0
  14. stackformer-0.1.9/stackformer/engine/trainer.py +325 -0
  15. stackformer-0.1.9/stackformer/generate.py +120 -0
  16. stackformer-0.1.9/stackformer/logging/__init__.py +20 -0
  17. stackformer-0.1.9/stackformer/logging/csv_logger.py +63 -0
  18. stackformer-0.1.9/stackformer/logging/logger.py +82 -0
  19. stackformer-0.1.9/stackformer/logging/metrics.py +132 -0
  20. stackformer-0.1.9/stackformer/logging/tensorboard_logger.py +58 -0
  21. stackformer-0.1.9/stackformer/logging/wandb_logger.py +74 -0
  22. stackformer-0.1.9/stackformer/logging/wb_logger.py +5 -0
  23. stackformer-0.1.9/stackformer/metrics.py +21 -0
  24. stackformer-0.1.9/stackformer/models/Google.py +244 -0
  25. stackformer-0.1.9/stackformer/models/Meta.py +229 -0
  26. stackformer-0.1.9/stackformer/models/OpenAI.py +363 -0
  27. stackformer-0.1.9/stackformer/models/Transformer.py +166 -0
  28. stackformer-0.1.9/stackformer/models/__init__.py +16 -0
  29. stackformer-0.1.9/stackformer/modules/Attention.py +1167 -0
  30. stackformer-0.1.9/stackformer/modules/Feed_forward.py +403 -0
  31. stackformer-0.1.9/stackformer/modules/Masking.py +244 -0
  32. stackformer-0.1.9/stackformer/modules/Normalization.py +103 -0
  33. stackformer-0.1.9/stackformer/modules/__init__.py +67 -0
  34. stackformer-0.1.9/stackformer/modules/position_embedding.py +177 -0
  35. stackformer-0.1.9/stackformer/optim/__init__.py +21 -0
  36. stackformer-0.1.9/stackformer/optim/factories.py +295 -0
  37. stackformer-0.1.9/stackformer/optim/loss_fn.py +121 -0
  38. stackformer-0.1.9/stackformer/training/__init__.py +3 -0
  39. stackformer-0.1.9/stackformer/training/loops.py +75 -0
  40. stackformer-0.1.9/stackformer/utils/__init__.py +31 -0
  41. stackformer-0.1.9/stackformer/utils/device.py +70 -0
  42. stackformer-0.1.9/stackformer/utils/utils.py +83 -0
  43. stackformer-0.1.9/stackformer/vision/__init__.py +9 -0
  44. stackformer-0.1.9/stackformer/vision/segformer.py +316 -0
  45. stackformer-0.1.9/stackformer/vision/vit.py +182 -0
  46. stackformer-0.1.9/stackformer.egg-info/PKG-INFO +185 -0
  47. stackformer-0.1.9/stackformer.egg-info/SOURCES.txt +53 -0
  48. stackformer-0.1.9/stackformer.egg-info/requires.txt +10 -0
  49. stackformer-0.1.9/stackformer.egg-info/top_level.txt +1 -0
  50. stackformer-0.1.9/tests/test_distributed.py +18 -0
  51. stackformer-0.1.9/tests/test_vision.py +90 -0
  52. stackformer-0.0.1/PKG-INFO +0 -75
  53. stackformer-0.0.1/README.md +0 -54
  54. stackformer-0.0.1/models/GPT_2.py +0 -181
  55. stackformer-0.0.1/modules/Attention.py +0 -533
  56. stackformer-0.0.1/modules/Feed_forward.py +0 -59
  57. stackformer-0.0.1/modules/Normalization.py +0 -41
  58. stackformer-0.0.1/modules/__init__.py +0 -0
  59. stackformer-0.0.1/modules/mask.py +0 -36
  60. stackformer-0.0.1/modules/position_embedding.py +0 -43
  61. stackformer-0.0.1/modules/tokenizer.py +0 -25
  62. stackformer-0.0.1/pyproject.toml +0 -25
  63. stackformer-0.0.1/setup.py +0 -31
  64. stackformer-0.0.1/stackformer.egg-info/PKG-INFO +0 -75
  65. stackformer-0.0.1/stackformer.egg-info/SOURCES.txt +0 -18
  66. stackformer-0.0.1/stackformer.egg-info/requires.txt +0 -2
  67. stackformer-0.0.1/stackformer.egg-info/top_level.txt +0 -2
  68. {stackformer-0.0.1 → stackformer-0.1.9}/LICENSE +0 -0
  69. {stackformer-0.0.1 → stackformer-0.1.9}/setup.cfg +0 -0
  70. /stackformer-0.0.1/models/__init__.py → /stackformer-0.1.9/stackformer/py.typed +0 -0
  71. {stackformer-0.0.1 → stackformer-0.1.9}/stackformer.egg-info/dependency_links.txt +0 -0
@@ -0,0 +1,185 @@
1
+ Metadata-Version: 2.4
2
+ Name: stackformer
3
+ Version: 0.1.9
4
+ Summary: Modular transformer blocks built in PyTorch
5
+ Author-email: Gurumurthy <gurumurthy.00300@gmail.com>
6
+ License: MIT
7
+ Project-URL: Repository, https://github.com/Gurumurthy30/Stackformer
8
+ Project-URL: Releases, https://github.com/Gurumurthy30/Stackformer/releases
9
+ Project-URL: Issue-Tracker, https://github.com/Gurumurthy30/Stackformer/issues
10
+ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussions
11
+ Project-URL: Documentation, https://github.com/Gurumurthy30/Stackformer/tree/main/docs
12
+ Keywords: transformer,pytorch,deep-learning,attention,llm,machine-learning
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Typing :: Typed
18
+ Classifier: Operating System :: OS Independent
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
21
+ Requires-Python: >=3.10
22
+ Description-Content-Type: text/markdown
23
+ License-File: LICENSE
24
+ Requires-Dist: torch<2.7,>=2.0
25
+ Requires-Dist: numpy<3.0,>=1.23
26
+ Requires-Dist: tqdm<5.0,>=4.5
27
+ Requires-Dist: tensorboard<3.0,>=2.14
28
+ Requires-Dist: wandb<1.0,>=0.17
29
+ Provides-Extra: dev
30
+ Requires-Dist: pytest<9.0,>=8.0; extra == "dev"
31
+ Requires-Dist: pytest-cov<6.0,>=5.0; extra == "dev"
32
+ Requires-Dist: build<2.0,>=1.2; extra == "dev"
33
+ Dynamic: license-file
34
+
35
+ <p align="center">
36
+ <img src="assets/logo.png" alt="StackFormer logo" width="560" />
37
+ </p>
38
+
39
+ <p align="center">
40
+ <a href="https://pypi.org/project/stackformer/"><img src="https://img.shields.io/pypi/v/stackformer.svg" alt="PyPI version" /></a>
41
+ <a href="https://pypi.org/project/stackformer/"><img src="https://img.shields.io/pypi/pyversions/stackformer.svg" alt="Python versions" /></a>
42
+ <a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License" /></a>
43
+ <a href="https://github.com/Gurumurthy30/Stackformer/actions">
44
+ <img src="https://img.shields.io/github/actions/workflow/status/Gurumurthy30/Stackformer/core-tests.yml?branch=main&label=CI" alt="CI status" />
45
+ </a>
46
+ </p>
47
+
48
+ # StackFormer
49
+
50
+ StackFormer is a modular PyTorch framework for building, training, and experimenting with Transformer architectures.
51
+
52
+ ## Overview
53
+
54
+ StackFormer is designed for fast experimentation with reusable Transformer building blocks and model implementations. It supports both language and vision workflows in a single modular codebase. The framework is built for research, prototyping, and iterative model development with practical training infrastructure.
55
+
56
+ ## Key Features
57
+
58
+ - Modular transformer components
59
+ - GPT / LLaMA / Gemma-style model implementations
60
+ - Vision models (ViT, SegFormer)
61
+ - Trainer infrastructure with AMP mixed precision and DDP support
62
+ - Logging and metrics utilities
63
+ - Checkpointing and resume training
64
+ - CI-tested training infrastructure
65
+
66
+ ## Project Structure
67
+
68
+ ```text
69
+ Stackformer/
70
+ ├── assets/
71
+ ├── docs/
72
+ │ ├── user_docs/
73
+ │ └── developer_docs/
74
+ ├── examples/
75
+ ├── stackformer/
76
+ │ ├── __init__.py
77
+ │ ├── generate.py
78
+ │ ├── metrics.py
79
+ │ ├── py.typed
80
+ │ ├── amp/
81
+ │ │ └── scaler.py
82
+ │ ├── distributed/
83
+ │ │ └── ddp.py
84
+ │ ├── engine/
85
+ │ │ ├── checkpoint.py
86
+ │ │ ├── engine.py
87
+ │ │ ├── state.py
88
+ │ │ └── trainer.py
89
+ │ ├── logging/
90
+ │ │ ├── csv_logger.py
91
+ │ │ ├── logger.py
92
+ │ │ ├── metrics.py
93
+ │ │ ├── tensorboard_logger.py
94
+ │ │ ├── wandb_logger.py
95
+ │ │ └── wb_logger.py
96
+ │ ├── models/
97
+ │ │ ├── OpenAI.py
98
+ │ │ ├── Meta.py
99
+ │ │ ├── Google.py
100
+ │ │ └── Transformer.py
101
+ │ ├── modules/
102
+ │ │ ├── Attention.py
103
+ │ │ ├── Feed_forward.py
104
+ │ │ ├── Masking.py
105
+ │ │ ├── Normalization.py
106
+ │ │ ├── position_embedding.py
107
+ │ ├── optim/
108
+ │ │ ├── factories.py
109
+ │ │ └── loss_fn.py
110
+ │ ├── training/
111
+ │ │ └── loops.py
112
+ │ ├── utils/
113
+ │ │ ├── device.py
114
+ │ │ └── utils.py
115
+ │ └── vision/
116
+ │ ├── vit.py
117
+ │ └── segformer.py
118
+ ├── tests/
119
+ │ ├── integration/
120
+ │ ├── models/
121
+ │ ├── modules/
122
+ │ ├── trainer/
123
+ │ ├── utils/
124
+ │ ├── conftest.py
125
+ │ ├── test_distributed.py
126
+ │ └── test_vision.py
127
+ ├── LICENSE
128
+ ├── pyproject.toml
129
+ └── README.md
130
+ ```
131
+
132
+ ## Installation
133
+
134
+ Requires Python 3.10 or newer.
135
+
136
+ ### Install from PyPI
137
+
138
+ ```bash
139
+ pip install stackformer
140
+ ```
141
+
142
+ ### Install from source
143
+
144
+ ```bash
145
+ git clone https://github.com/Gurumurthy30/Stackformer.git
146
+ cd Stackformer
147
+ pip install -e .
148
+ ```
149
+
150
+ ## Quick Start
151
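+ The snippet below assumes `dataset` is training data you have already prepared; see `examples/` for complete, runnable scripts.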
152
+ ```python
153
+ from stackformer.engine import Trainer
154
+ import torch.nn as nn
155
+
156
+ model = nn.Linear(10, 1)
157
+
158
+ trainer = Trainer(model=model)
159
+ trainer.fit(dataset)
160
+ ```
161
+
162
+ ## Examples
163
+
164
+ More runnable examples are available in:
165
+
166
+ ```text
167
+ examples/
168
+ ```
169
+
170
+ ```text
171
+ examples/simple_train.py
172
+ examples/simple_trainer_v2.py
173
+ examples/train_ddp.py
174
+ ```
175
+
176
+ ## Documentation
177
+
178
+ - User documentation: [docs/user_docs/installation.md](docs/user_docs/installation.md)
179
+ - Developer documentation: [docs/developer_docs/architecture.md](docs/developer_docs/architecture.md)
180
+
181
+ ## Community
182
+
183
+ - Issues: https://github.com/Gurumurthy30/Stackformer/issues
184
+ - Discussions: https://github.com/Gurumurthy30/Stackformer/discussions
185
+ - Releases: https://github.com/Gurumurthy30/Stackformer/releases
@@ -0,0 +1,151 @@
1
+ <p align="center">
2
+ <img src="assets/logo.png" alt="StackFormer logo" width="560" />
3
+ </p>
4
+
5
+ <p align="center">
6
+ <a href="https://pypi.org/project/stackformer/"><img src="https://img.shields.io/pypi/v/stackformer.svg" alt="PyPI version" /></a>
7
+ <a href="https://pypi.org/project/stackformer/"><img src="https://img.shields.io/pypi/pyversions/stackformer.svg" alt="Python versions" /></a>
8
+ <a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License" /></a>
9
+ <a href="https://github.com/Gurumurthy30/Stackformer/actions">
10
+ <img src="https://img.shields.io/github/actions/workflow/status/Gurumurthy30/Stackformer/core-tests.yml?branch=main&label=CI" alt="CI status" />
11
+ </a>
12
+ </p>
13
+
14
+ # StackFormer
15
+
16
+ StackFormer is a modular PyTorch framework for building, training, and experimenting with Transformer architectures.
17
+
18
+ ## Overview
19
+
20
+ StackFormer is designed for fast experimentation with reusable Transformer building blocks and model implementations. It supports both language and vision workflows in a single modular codebase. The framework is built for research, prototyping, and iterative model development with practical training infrastructure.
21
+
22
+ ## Key Features
23
+
24
+ - Modular transformer components
25
+ - GPT / LLaMA / Gemma-style model implementations
26
+ - Vision models (ViT, SegFormer)
27
+ - Trainer infrastructure with AMP mixed precision and DDP support
28
+ - Logging and metrics utilities
29
+ - Checkpointing and resume training
30
+ - CI-tested training infrastructure
31
+
32
+ ## Project Structure
33
+
34
+ ```text
35
+ Stackformer/
36
+ ├── assets/
37
+ ├── docs/
38
+ │ ├── user_docs/
39
+ │ └── developer_docs/
40
+ ├── examples/
41
+ ├── stackformer/
42
+ │ ├── __init__.py
43
+ │ ├── generate.py
44
+ │ ├── metrics.py
45
+ │ ├── py.typed
46
+ │ ├── amp/
47
+ │ │ └── scaler.py
48
+ │ ├── distributed/
49
+ │ │ └── ddp.py
50
+ │ ├── engine/
51
+ │ │ ├── checkpoint.py
52
+ │ │ ├── engine.py
53
+ │ │ ├── state.py
54
+ │ │ └── trainer.py
55
+ │ ├── logging/
56
+ │ │ ├── csv_logger.py
57
+ │ │ ├── logger.py
58
+ │ │ ├── metrics.py
59
+ │ │ ├── tensorboard_logger.py
60
+ │ │ ├── wandb_logger.py
61
+ │ │ └── wb_logger.py
62
+ │ ├── models/
63
+ │ │ ├── OpenAI.py
64
+ │ │ ├── Meta.py
65
+ │ │ ├── Google.py
66
+ │ │ └── Transformer.py
67
+ │ ├── modules/
68
+ │ │ ├── Attention.py
69
+ │ │ ├── Feed_forward.py
70
+ │ │ ├── Masking.py
71
+ │ │ ├── Normalization.py
72
+ │ │ ├── position_embedding.py
73
+ │ ├── optim/
74
+ │ │ ├── factories.py
75
+ │ │ └── loss_fn.py
76
+ │ ├── training/
77
+ │ │ └── loops.py
78
+ │ ├── utils/
79
+ │ │ ├── device.py
80
+ │ │ └── utils.py
81
+ │ └── vision/
82
+ │ ├── vit.py
83
+ │ └── segformer.py
84
+ ├── tests/
85
+ │ ├── integration/
86
+ │ ├── models/
87
+ │ ├── modules/
88
+ │ ├── trainer/
89
+ │ ├── utils/
90
+ │ ├── conftest.py
91
+ │ ├── test_distributed.py
92
+ │ └── test_vision.py
93
+ ├── LICENSE
94
+ ├── pyproject.toml
95
+ └── README.md
96
+ ```
97
+
98
+ ## Installation
99
+
100
+ Requires Python 3.10 or newer.
101
+
102
+ ### Install from PyPI
103
+
104
+ ```bash
105
+ pip install stackformer
106
+ ```
107
+
108
+ ### Install from source
109
+
110
+ ```bash
111
+ git clone https://github.com/Gurumurthy30/Stackformer.git
112
+ cd Stackformer
113
+ pip install -e .
114
+ ```
115
+
116
+ ## Quick Start
117
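+ The snippet below assumes `dataset` is training data you have already prepared; see `examples/` for complete, runnable scripts.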
118
+ ```python
119
+ from stackformer.engine import Trainer
120
+ import torch.nn as nn
121
+
122
+ model = nn.Linear(10, 1)
123
+
124
+ trainer = Trainer(model=model)
125
+ trainer.fit(dataset)
126
+ ```
127
+
128
+ ## Examples
129
+
130
+ More runnable examples are available in:
131
+
132
+ ```text
133
+ examples/
134
+ ```
135
+
136
+ ```text
137
+ examples/simple_train.py
138
+ examples/simple_trainer_v2.py
139
+ examples/train_ddp.py
140
+ ```
141
+
142
+ ## Documentation
143
+
144
+ - User documentation: [docs/user_docs/installation.md](docs/user_docs/installation.md)
145
+ - Developer documentation: [docs/developer_docs/architecture.md](docs/developer_docs/architecture.md)
146
+
147
+ ## Community
148
+
149
+ - Issues: https://github.com/Gurumurthy30/Stackformer/issues
150
+ - Discussions: https://github.com/Gurumurthy30/Stackformer/discussions
151
+ - Releases: https://github.com/Gurumurthy30/Stackformer/releases
@@ -0,0 +1,61 @@
1
+ [project]
2
+ name = "stackformer"
3
+ version = "0.1.9"
4
+ description = "Modular transformer blocks built in PyTorch"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ license = { text = "MIT" }
8
+
9
+ authors = [
10
+ { name = "Gurumurthy", email = "gurumurthy.00300@gmail.com" }
11
+ ]
12
+
13
+ keywords = ["transformer", "pytorch", "deep-learning", "attention", "llm", "machine-learning"]
14
+
15
+ dependencies = [
16
+ "torch>=2.0,<2.7",
17
+ "numpy>=1.23,<3.0",
18
+ "tqdm>=4.5,<5.0",
19
+ "tensorboard>=2.14,<3.0",
20
+ "wandb>=0.17,<1.0"
21
+ ]
22
+
23
+ classifiers = [
24
+ "Programming Language :: Python :: 3",
25
+ "Programming Language :: Python :: 3.10",
26
+ "Programming Language :: Python :: 3.11",
27
+ "Programming Language :: Python :: 3.12",
28
+ "Typing :: Typed",
29
+ "Operating System :: OS Independent",
30
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
31
+ "Topic :: Software Development :: Libraries :: Python Modules"
32
+ ]
33
+
34
+ [project.optional-dependencies]
35
+ dev = [
36
+ "pytest>=8.0,<9.0",
37
+ "pytest-cov>=5.0,<6.0",
38
+ "build>=1.2,<2.0"
39
+ ]
40
+
41
+ [project.urls]
42
+ Repository = "https://github.com/Gurumurthy30/Stackformer"
43
+ Releases = "https://github.com/Gurumurthy30/Stackformer/releases"
44
+ Issue-Tracker = "https://github.com/Gurumurthy30/Stackformer/issues"
45
+ Discussions = "https://github.com/Gurumurthy30/Stackformer/discussions"
46
+ Documentation = "https://github.com/Gurumurthy30/Stackformer/tree/main/docs"
47
+
48
+ [tool.setuptools.packages.find]
49
+ include = ["stackformer*"]
50
+ exclude = ["assets*"]
51
+
52
+ [tool.setuptools.package-data]
53
+ stackformer = ["py.typed"]
54
+
55
+ [build-system]
56
+ requires = ["setuptools>=61", "wheel"]
57
+ build-backend = "setuptools.build_meta"
58
+
59
+ [tool.pytest.ini_options]
60
+ testpaths = ["tests"]
61
+ addopts = "-q"
@@ -0,0 +1,74 @@
1
+ """Top-level public API for Stackformer."""
2
+
3
+ from .generate import text_generate
4
+ from .config import GenerationConfig, ModelConfig, TrainingConfig
5
+ from .models import GPT_1, GPT_2, gemma_1_2b, gemma_1_7b, llama_1, llama_2, transformer
6
+ from .modules import (
7
+ AbsolutePositionEmbedding,
8
+ Cross_MultiHead_Attention,
9
+ FF_GELU,
10
+ FF_GeGLU,
11
+ FF_LeakyReLU,
12
+ FF_ReLU,
13
+ FF_Sigmoid,
14
+ FF_SiLU,
15
+ FF_SwiGLU,
16
+ Group_query_Attention,
17
+ Group_query_Attention_With_RoPE,
18
+ LayerNormalization,
19
+ Multi_Head_Attention,
20
+ Multi_Head_Attention_With_RoPE,
21
+ Multi_query_Attention,
22
+ Multi_query_Attention_With_RoPE,
23
+ RMSNormalization,
24
+ RoPE,
25
+ Self_Attention,
26
+ SinusoidalPositionalEmbedding,
27
+ kv_cache_group_query,
28
+ kv_cache_multihead,
29
+ make_mask,
30
+ )
31
+ from .vision import SegFormerB0, ViT
32
+
33
+ from .engine import Trainer
34
+
35
# Names exported by ``from stackformer import *``; every entry is imported above.
__all__ = [
    "AbsolutePositionEmbedding",
    "Cross_MultiHead_Attention",
    "FF_GELU",
    "FF_GeGLU",
    "FF_LeakyReLU",
    "FF_ReLU",
    "FF_Sigmoid",
    "FF_SiLU",
    "FF_SwiGLU",
    "GPT_1",
    "GPT_2",
    "Group_query_Attention",
    "Group_query_Attention_With_RoPE",
    "LayerNormalization",
    "Multi_Head_Attention",
    "Multi_Head_Attention_With_RoPE",
    "Multi_query_Attention",
    "Multi_query_Attention_With_RoPE",
    "RMSNormalization",
    "RoPE",
    "SegFormerB0",
    "Self_Attention",
    "SinusoidalPositionalEmbedding",
    "ViT",
    "gemma_1_2b",
    "gemma_1_7b",
    "kv_cache_group_query",
    "kv_cache_multihead",
    "llama_1",
    "llama_2",
    "make_mask",
    "text_generate",
    "transformer",
    "GenerationConfig",
    "ModelConfig",
    "TrainingConfig",
    "Trainer",
]
@@ -0,0 +1,5 @@
1
+ """Automatic mixed precision utilities."""
2
+
3
+ from .scaler import AMPScaler, initialize_scaler, scale_loss, step_optimizer, update_scaler
4
+
5
# Public AMP helpers, all re-exported from the ``scaler`` submodule.
__all__ = [
    "AMPScaler",
    "initialize_scaler",
    "scale_loss",
    "step_optimizer",
    "update_scaler",
]
@@ -0,0 +1,84 @@
1
+ """Utilities for automatic mixed precision (AMP)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from contextlib import nullcontext
6
+
7
+ import torch
8
+
9
+
10
class AMPScaler:
    """Small wrapper around PyTorch AMP gradient scaling and autocast.

    AMP is only active when it is requested *and* CUDA is available, so
    CPU-only runs fall back to ordinary full-precision training. When AMP is
    inactive every method degrades to its plain equivalent: ``autocast``
    yields a no-op context, ``scale`` returns the loss untouched, ``step``
    calls ``optimizer.step()`` directly, ``state_dict`` returns ``{}`` and the
    remaining methods do nothing.
    """

    def __init__(self, enabled: bool = True):
        # Requesting AMP on a machine without CUDA silently disables it.
        self.enabled = bool(enabled and torch.cuda.is_available())

        # Prefer the device-generic torch.amp API; fall back to the
        # CUDA-specific torch.cuda.amp API on PyTorch versions that lack
        # torch.amp.GradScaler.
        if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
            self.scaler = torch.amp.GradScaler("cuda", enabled=self.enabled)
            self._autocast = lambda: torch.amp.autocast(device_type="cuda", enabled=self.enabled)
        else:
            self.scaler = torch.cuda.amp.GradScaler(enabled=self.enabled)
            self._autocast = lambda: torch.cuda.amp.autocast(enabled=self.enabled)

    def autocast(self):
        """Return a CUDA autocast context, or a no-op context when AMP is off."""
        if not self.enabled:
            return nullcontext()
        return self._autocast()

    def scale(self, loss: torch.Tensor) -> torch.Tensor:
        """Return *loss* scaled by the grad scaler (unchanged when AMP is off)."""
        if not self.enabled:
            return loss
        return self.scaler.scale(loss)

    def step(self, optimizer: torch.optim.Optimizer) -> None:
        """Step *optimizer*, routing through the grad scaler when AMP is on."""
        if not self.enabled:
            optimizer.step()
            return
        self.scaler.step(optimizer)

    def update(self) -> None:
        """Update the grad scaler's scale factor (no-op when AMP is off)."""
        if self.enabled:
            self.scaler.update()

    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
        """Unscale the gradients held by *optimizer* in place (no-op when AMP is off)."""
        if self.enabled:
            self.scaler.unscale_(optimizer)

    def state_dict(self):
        """Return the scaler state for checkpointing (``{}`` when AMP is off)."""
        if not self.enabled:
            return {}
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict: dict | None) -> None:
        """Restore scaler state previously produced by :meth:`state_dict`.

        Nothing is loaded when AMP is off, or when *state_dict* is empty or
        ``None``. The empty case is exactly what a disabled scaler saves (see
        :meth:`state_dict`). Skipping it lets a checkpoint from a CPU-only or
        AMP-off run be resumed with AMP enabled. Otherwise torch's
        ``GradScaler.load_state_dict`` would reject the empty dict with a
        RuntimeError.
        """
        if self.enabled and state_dict:
            self.scaler.load_state_dict(state_dict)

    @property
    def is_enabled(self) -> bool:
        """Whether AMP is actually active for this instance."""
        return self.enabled
63
+
64
+
65
def initialize_scaler(enabled: bool = True) -> AMPScaler:
    """Create a fresh :class:`AMPScaler`.

    Args:
        enabled: Whether AMP is requested. The scaler still turns itself off
            when CUDA is unavailable.

    Returns:
        The newly constructed scaler.
    """
    new_scaler = AMPScaler(enabled=enabled)
    return new_scaler
67
+
68
+
69
def scale_loss(loss: torch.Tensor, scaler: AMPScaler | None) -> torch.Tensor:
    """Return ``scaler.scale(loss)``, or *loss* itself when *scaler* is ``None``."""
    return loss if scaler is None else scaler.scale(loss)
73
+
74
+
75
def step_optimizer(optimizer: torch.optim.Optimizer, scaler: AMPScaler | None) -> None:
    """Apply one optimizer update, going through *scaler* when one is given.

    Args:
        optimizer: The optimizer to step.
        scaler: Optional scaler. When present, ``scaler.step(optimizer)`` is
            called instead of ``optimizer.step()``.
    """
    if scaler is not None:
        scaler.step(optimizer)
    else:
        optimizer.step()
80
+
81
+
82
def update_scaler(scaler: AMPScaler | None) -> None:
    """Call ``scaler.update()`` when a scaler is given; ``None`` is ignored."""
    if scaler is None:
        return
    scaler.update()
@@ -0,0 +1,41 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
+ @dataclass(slots=True)
7
+ class ModelConfig:
8
+ vocab_size: int
9
+ embed_dim: int
10
+ num_layers: int
11
+ num_heads: int
12
+ seq_len: int
13
+ hidden_dim: int
14
+ dropout: float = 0.0
15
+
16
+
17
+ @dataclass(slots=True)
18
+ class TrainingConfig:
19
+ max_epochs: int = 1
20
+ max_train_steps: int | None = None
21
+ max_eval_steps: int | None = None
22
+ eval_every_n_epochs: int = 1
23
+ save_every_n_epochs: int = 1
24
+ grad_accumulation_step: int = 1
25
+ max_grad_norm: float | None = None
26
+ lr: float = 3e-4
27
+ weight_decay: float = 0.01
28
+ optimizer_name: str = "adamw"
29
+ scheduler_name: str = "none"
30
+ warmup_steps: int = 0
31
+ total_steps: int | None = None
32
+
33
+
34
+ @dataclass(slots=True)
35
+ class GenerationConfig:
36
+ max_context_len: int = 128
37
+ max_new_tokens: int = 50
38
+ temperature: float = 1.0
39
+ top_k: int | None = None
40
+ top_p: float = 1.0
41
+ eos_token_id: int | None = None
@@ -0,0 +1,31 @@
1
+ """Distributed training helpers for StackFormer."""
2
+
3
+ from .ddp import (
4
+ barrier,
5
+ cleanup,
6
+ cleanup_distributed,
7
+ distributed_sampler,
8
+ get_local_rank,
9
+ get_rank,
10
+ get_world_size,
11
+ init_distributed,
12
+ is_distributed,
13
+ is_main_process,
14
+ setup_ddp,
15
+ wrap_model_ddp,
16
+ )
17
+
18
# Public distributed-training helpers, all re-exported from the ``ddp`` submodule.
__all__ = [
    "init_distributed", "setup_ddp", "wrap_model_ddp", "distributed_sampler",
    "get_rank", "get_world_size", "get_local_rank",
    "is_distributed", "is_main_process",
    "barrier", "cleanup", "cleanup_distributed",
]