shadax 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shadax-0.2.0/LICENSE +21 -0
- shadax-0.2.0/PKG-INFO +508 -0
- shadax-0.2.0/README.md +481 -0
- shadax-0.2.0/pyproject.toml +51 -0
- shadax-0.2.0/setup.cfg +4 -0
- shadax-0.2.0/shadax/__init__.py +50 -0
- shadax-0.2.0/shadax/config.py +294 -0
- shadax-0.2.0/shadax/core.py +34 -0
- shadax-0.2.0/shadax/encoder.py +301 -0
- shadax-0.2.0/shadax/heads.py +285 -0
- shadax-0.2.0/shadax/model.py +606 -0
- shadax-0.2.0/shadax/modules.py +316 -0
- shadax-0.2.0/shadax/network.py +204 -0
- shadax-0.2.0/shadax/py.typed +0 -0
- shadax-0.2.0/shadax/ssl.py +325 -0
- shadax-0.2.0/shadax/training.py +346 -0
- shadax-0.2.0/shadax.egg-info/PKG-INFO +508 -0
- shadax-0.2.0/shadax.egg-info/SOURCES.txt +25 -0
- shadax-0.2.0/shadax.egg-info/dependency_links.txt +1 -0
- shadax-0.2.0/shadax.egg-info/requires.txt +5 -0
- shadax-0.2.0/shadax.egg-info/top_level.txt +1 -0
- shadax-0.2.0/tests/test_config.py +83 -0
- shadax-0.2.0/tests/test_encoder.py +95 -0
- shadax-0.2.0/tests/test_heads.py +91 -0
- shadax-0.2.0/tests/test_model.py +193 -0
- shadax-0.2.0/tests/test_pipeline.py +179 -0
- shadax-0.2.0/tests/test_ssl.py +80 -0
shadax-0.2.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 Omar Alghafri

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

shadax-0.2.0/PKG-INFO
ADDED
@@ -0,0 +1,508 @@
Metadata-Version: 2.4
Name: shadax
Version: 0.2.0
Summary: SHADA: a self-supervised hierarchical hybrid (CNN + Transformer) multi-modal model for vision and text
Author: Omar
License: MIT
Project-URL: Homepage, https://github.com/OmarAlghafri/SHADA-API-Core-Reference
Project-URL: Repository, https://github.com/OmarAlghafri/SHADA-API-Core-Reference
Keywords: deep-learning,pytorch,transformer,cnn,multimodal,self-supervised,computer-vision,nlp
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: numpy>=1.21.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Dynamic: license-file

# shadax

**SHADA** (Self-supervised Hierarchical Adaptive Hybrid Algorithm) is a single PyTorch architecture that
combines a convolutional stem and four hierarchical Transformer stages into one **hybrid** backbone. It
then shares that backbone across **two modalities** (images and text) and **four downstream tasks**
(classification, segmentation, language modeling, detection). It is **adaptive** to input resolution
because the image path uses conditional (depthwise-convolutional) positional encoding rather than a fixed
positional table. It is also **self-supervised**: it ships with masked-image-modeling (MAE-style) and
masked-language-modeling (BERT-style) objectives and a four-phase training pipeline
(pretrain → multitask → finetune → deploy). Everything is exposed behind a familiar scikit-learn-style
estimator (`fit` / `predict` / `score`).

This README documents the library exactly as implemented. There are **no** bundled pretrained weights,
**no** bundled datasets, and **no** CUDA-only features: the model runs on CPU or GPU.

## Key features

- **Hybrid hierarchical backbone** — a `/4` convolutional stem followed by four Transformer stages, with
  `/2` patch-merging downsamples between consecutive stages (total spatial reduction **32**).
- **One shared multi-modal encoder** — the *same* `HierarchicalEncoder` processes images `(B, C, H, W)`
  and text token ids `(B, L)`; the Transformer blocks are shared between the two paths.
- **Resolution-adaptive** — images use conditional positional encoding (a residual depthwise conv), so
  there is no fixed positional table for images; any `H` and `W` divisible by 32 work.
- **Four model tiers** — `nano` / `base` / `large` / `xl`.
- **Four task heads** — classification, segmentation, language model (`lm`), and detection (anchor-free,
  CenterNet-style); all four are fully implemented.
- **Self-supervised objectives** — masked image modeling (`mask_ratio=0.75` default) and masked language
  modeling (`text_mask_ratio=0.15` default).
- **Four-phase training pipeline** — `pretrain`, `multitask`, `finetune`, `deploy`.
- **scikit-learn-style API** — `fit` / `predict` / `predict_proba` / `score` / `extract_features` /
  `save` / `load`.

## Installation

PyTorch is a **hard dependency** (`torch>=2.0`), along with `numpy`. PyTorch is not bundled. If `pip` does
not resolve a suitable wheel automatically, install the build appropriate for your platform/accelerator
(see <https://pytorch.org>).

### From PyPI (recommended)

```bash
pip install shadax
```

### From GitHub

```bash
pip install git+https://github.com/OmarAlghafri/SHADA-API-Core-Reference.git
```

### From a local wheel

```bash
pip install dist/*.whl
```

## Quickstart (image classification)

```python
import numpy as np
from shadax import SHADA

# Tiny synthetic dataset: 16 RGB images, 64x64 (H, W must be divisible by 32).
X = np.random.randn(16, 3, 64, 64).astype("float32")
y = np.random.randint(0, 10, size=16)

# Build a classification model (tier + number of classes).
model = SHADA(tier="nano", task="classification", num_classes=10)

# fit accepts epochs= directly. With y given, the default phase is [FINETUNE].
model.fit(X, y, epochs=2)

# predict -> (N,) integer labels; score -> top-1 accuracy in [0, 1].
preds = model.predict(X)
acc = model.score(X, y)
print(preds.shape, acc)
```

## Architecture overview

The backbone is a `HierarchicalEncoder` driven entirely by a `SHADAConfig`. It has two input paths that
**share their per-stage Transformer blocks**: both paths ultimately run the blocks over a token sequence
`(B, N, dim)`.

**Image path**

```
image (B, C, H, W)
  |  ConvStem                   -> (B, dims[0], H/4,  W/4)    [/4]
  v
Stage 0   CPE + Transformer blocks (dims[0])
  |  StageDownsample (2x2, s2)  -> (B, dims[1], H/8,  W/8)    [/2]
  v
Stage 1   CPE + Transformer blocks (dims[1])
  |  StageDownsample            -> (B, dims[2], H/16, W/16)   [/2]
  v
Stage 2   CPE + Transformer blocks (dims[2])
  |  StageDownsample            -> (B, dims[3], H/32, W/32)   [/2]
  v
Stage 3   CPE + Transformer blocks (dims[3])
  |
  v
feature_maps:    4 maps at strides 4, 8, 16, 32
tokens:          (B, (H/32)*(W/32), dims[-1])
global_features: (B, dims[-1])   (mean-pooled over tokens)
```

- **Conv stem** (`/4`): two stride-2 `3x3` convolutions with normalization and GELU.
- **Conditional positional encoding (CPE)**: a residual depthwise `3x3` convolution applied at the start
  of each image stage. Because it is convolutional, it works at any input resolution that satisfies the
  divisibility constraint below, and there is no fixed positional table for images. This is what makes
  the image path **resolution-adaptive**.
- **Patch-merging downsample** (`/2`): a normalization plus a stride-2 `2x2` convolution between stages,
  which also widens the channels (`dims[i] -> dims[i+1]`).
- **Spatial-32 constraint**: the stem's `/4` times three stage-level `/2` downsamples gives a total of
  **`/32`** (`4 × 2 × 2 × 2`), so image `H` and `W` must each be divisible by 32. Both
  `SHADAConfig.validate` and the encoder enforce this.

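The stride arithmetic above can be checked without building a model. The helper below is a hypothetical sketch and is not part of `shadax`. It assumes exactly the layout described here (a `/4` stem, then a `/2` downsample before each of stages 1 to 3) and only computes shapes; it does not reproduce the library's own validation code.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int, int]


def image_stage_shapes(batch: int, height: int, width: int, dims: List[int]) -> List[Shape]:
    """Hypothetical helper: (B, C, H, W) of each stage's feature map on the image path."""
    if len(dims) != 4:
        raise ValueError(f"expected 4 stage widths, got {len(dims)}")
    if height % 32 or width % 32:
        raise ValueError(f"H and W must be divisible by 32, got {height}x{width}")
    shapes: List[Shape] = []
    stride = 4  # the conv stem reduces H and W by 4
    for stage, dim in enumerate(dims):
        if stage > 0:
            stride *= 2  # each StageDownsample halves H and W again
        shapes.append((batch, dim, height // stride, width // stride))
    return shapes


def num_image_tokens(height: int, width: int) -> int:
    """Length of the final token sequence: (H/32) * (W/32)."""
    return (height // 32) * (width // 32)


# nano widths from the tier table; a non-square 64x96 input is allowed.
for shape in image_stage_shapes(2, 64, 96, [64, 128, 256, 512]):
    print(shape)  # last line: (2, 512, 2, 3)
print("tokens:", num_image_tokens(64, 96))  # 2 * 3 = 6
```

An input such as 100x64 is rejected up front, which mirrors the constraint the README says `SHADAConfig.validate` and the encoder enforce.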
**Text path**

The same four stages of Transformer blocks run over a token sequence produced by a learned token +
positional `TextEmbedding`. Between stages, a `TextStageProject` (LayerNorm + Linear) changes the channel
width (`dims[i] -> dims[i+1]`) without changing the sequence length. For the language-model task a causal
attention mask is applied, so each position attends only to itself and earlier positions.

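The causal rule in the last sentence can be written as a lower-triangular boolean mask. This NumPy sketch illustrates the rule and is not the `shadax` code path. One caveat: PyTorch APIs disagree on boolean mask polarity. `torch.nn.functional.scaled_dot_product_attention` treats `True` as "may attend", while `torch.nn.MultiheadAttention` treats `True` as "blocked". Which convention `shadax` uses internally is not documented here.

```python
import numpy as np


def causal_allowed_mask(seq_len: int) -> np.ndarray:
    """(L, L) boolean mask; entry [i, j] is True when position i may attend to position j (j <= i)."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))


def masked_softmax(scores: np.ndarray, allowed: np.ndarray) -> np.ndarray:
    """Row-wise softmax over attention scores, giving disallowed (future) positions zero weight."""
    scores = np.where(allowed, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # finite: the diagonal is always allowed
    weights = np.exp(scores)  # exp(-inf) == 0 for future positions
    return weights / weights.sum(axis=-1, keepdims=True)


mask = causal_allowed_mask(4)
print(mask.astype(int))
# With equal scores, each row spreads its weight evenly over itself and earlier positions.
print(masked_softmax(np.zeros((4, 4)), mask).round(3))
```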
**Encoder output contract**: `encoder(x, modality=...)` returns a dict with keys:

| key               | image                                                  | text                        |
| ----------------- | ------------------------------------------------------ | --------------------------- |
| `feature_maps`    | list of 4 maps `(B, dims[i], H/s, W/s)`, s = 4/8/16/32 | list of 4 `(B, L, dims[i])` |
| `tokens`          | `(B, (H/32)*(W/32), dims[-1])`                         | `(B, L, dims[-1])`          |
| `global_features` | `(B, dims[-1])`                                        | `(B, dims[-1])`             |
| `hw`              | `(Hs, Ws) = (H/32, W/32)` int tuple                    | `None`                      |
| `modality`        | `"image"`                                              | `"text"`                    |

## The four tasks

Select a task with `task=` (or via `create_config(...)`). Each task has its own head, its own loss, and a
task-specific `predict` / `score` contract.

### Classification

- **Input** `X`: `(N, C, H, W)` images (`H`, `W` divisible by 32).
- **Target** `y`: `(N,)` integer labels (one-hot `(N, num_classes)` is also accepted and argmaxed).
- **`predict(X)`** → `(N,)` integer labels. With `return_probs=True`, returns `(labels, probs)` where
  `probs` is `(N, num_classes)`.
- **`score(X, y)`** → top-1 accuracy.

```python
import numpy as np
from shadax import SHADA

X = np.random.randn(16, 3, 64, 64).astype("float32")
y = np.random.randint(0, 5, size=16)

model = SHADA(tier="nano", task="classification", num_classes=5)
model.fit(X, y, epochs=2)

labels = model.predict(X)                            # (N,)
labels, probs = model.predict(X, return_probs=True)  # (N,), (N, 5)
print(labels.shape, probs.shape, model.score(X, y))
```

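The target and scoring rules above (integer labels or one-hot rows, top-1 accuracy) can be mirrored in plain NumPy. This is useful, for example, to score saved probabilities outside the estimator. The code is a hypothetical sketch of those rules rather than `shadax`'s implementation. It assumes that a 2-D `y` whose second axis equals `num_classes` is one-hot.

```python
import numpy as np


def to_class_indices(y, num_classes: int) -> np.ndarray:
    """Accept (N,) integer labels or (N, num_classes) one-hot rows; return (N,) integer labels."""
    y = np.asarray(y)
    if y.ndim == 2 and y.shape[1] == num_classes:
        return y.argmax(axis=1)  # one-hot row -> index of its hot entry
    if y.ndim != 1:
        raise ValueError(f"unsupported label shape {y.shape}")
    return y.astype(np.int64)


def top1_accuracy(probs, y) -> float:
    """Fraction of rows whose highest-probability class equals the target class."""
    probs = np.asarray(probs)
    labels = to_class_indices(y, probs.shape[1])
    return float((probs.argmax(axis=1) == labels).mean())


probs = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # e.g. from predict(X, return_probs=True)
print(top1_accuracy(probs, [1, 0, 0]))                  # integer labels: 2 of 3 correct
print(top1_accuracy(probs, [[0, 1], [1, 0], [1, 0]]))   # the same labels, one-hot encoded
```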
### Segmentation

- **Input** `X`: `(N, C, H, W)` images.
- **Target** `y`: `(N, H, W)` integer per-pixel class masks.
- **`predict(X)`** → `(N, H, W)` per-pixel argmax labels.
- **`score(X, y)`** → mean per-pixel accuracy.

```python
import numpy as np
from shadax import SHADA

X = np.random.randn(8, 3, 64, 64).astype("float32")
y = np.random.randint(0, 4, size=(8, 64, 64))  # (N, H, W) masks

model = SHADA(tier="nano", task="segmentation", num_classes=4)
model.fit(X, y, epochs=2)

masks = model.predict(X)  # (N, H, W)
print(masks.shape, model.score(X, y))
```

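Per-pixel argmax and mean per-pixel accuracy each reduce to a single NumPy expression. The sketch below assumes segmentation scores are laid out as `(N, num_classes, H, W)`. That is the usual PyTorch convention, but it is an assumption about the `shadax` head, not something this README states.

```python
import numpy as np


def masks_from_logits(logits: np.ndarray) -> np.ndarray:
    """(N, num_classes, H, W) scores -> (N, H, W) per-pixel argmax labels."""
    return logits.argmax(axis=1)


def pixel_accuracy(pred_masks, true_masks) -> float:
    """Mean of (predicted label == true label) over every pixel of every image."""
    pred, true = np.asarray(pred_masks), np.asarray(true_masks)
    if pred.shape != true.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {true.shape}")
    return float((pred == true).mean())


logits = np.zeros((1, 2, 2, 2))
logits[0, 1, 0, 0] = 1.0          # class 1 wins only at pixel (0, 0); ties elsewhere go to class 0
pred = masks_from_logits(logits)  # [[[1, 0], [0, 0]]]
true = np.array([[[1, 1], [0, 0]]])
print(pred.tolist(), pixel_accuracy(pred, true))  # 3 of the 4 pixels match
```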
### Language model (`lm`)

- **Input** `X`: `(N, L)` integer token ids. Reserve id `vocab_size - 1` as the `[MASK]` token (used by
  masked-LM pretraining); do not assign it to a real token.
- **Target** `y`: the `lm` loss is next-token prediction over `X` itself, so **no `y` is needed**. For the
  language-model task, even the supervised phases (`FINETUNE` / `MULTITASK`) accept `y=None`.
  Self-supervised pretraining likewise uses `fit(X)` / `pretrain(X)`.
- **`predict(X)`** → `(N, L)` per-position argmax (next-token) ids.
- **`score(X, y)`** → next-token accuracy (positions equal to `pad_token_id` are ignored). Pass the same
  token-id array as `y`.

```python
import numpy as np
from shadax import SHADA, TrainingPhase

vocab_size = 256
# Keep ids in [1, vocab_size-2]: id 0 is padding, id vocab_size-1 is [MASK].
X = np.random.randint(1, vocab_size - 1, size=(8, 32))

model = SHADA(tier="nano", task="lm", vocab_size=vocab_size, max_seq_len=64)
# Supervised next-token finetuning. lm derives its targets from X itself, so no y.
model.fit(X, phases=[TrainingPhase.FINETUNE], epochs=2)

next_ids = model.predict(X)  # (N, L)
print(next_ids.shape, model.score(X, X))
```

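One plausible reading of the `score` contract above is sketched below. The id predicted at position `t` is compared with the actual token at `t + 1`, and targets equal to `pad_token_id` are skipped. The exact alignment used inside `shadax` (for example, whether the last position is simply dropped) is an assumption here. Treat this as an illustration of next-token accuracy, not a reimplementation.

```python
import numpy as np


def next_token_accuracy(pred_ids, token_ids, pad_token_id: int = 0) -> float:
    """Accuracy of per-position next-token predictions, ignoring padding targets."""
    pred = np.asarray(pred_ids)[:, :-1]    # the prediction made at position t ...
    target = np.asarray(token_ids)[:, 1:]  # ... is scored against the token at t + 1
    valid = target != pad_token_id         # padding targets do not count
    if not valid.any():
        return float("nan")
    return float((pred[valid] == target[valid]).mean())


tokens = np.array([[5, 6, 7, 0]])  # one sequence, right-padded with id 0
pred = np.array([[6, 8, 9, 3]])    # 6 follows 5 (right), 8 follows 6 (wrong), 9 targets padding
print(next_token_accuracy(pred, tokens))  # 1 of 2 scored positions is correct
```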
### Detection (anchor-free, CenterNet-style)

The detection head predicts dense maps on the **stride-32 grid** `(Hs, Ws) = (H/32, W/32)`.

- **Input** `X`: `(N, C, H, W)` images.
- **Target** `y`: a dict of dense CenterNet targets on `(Hs, Ws)`:

  | key        | shape                      | meaning                                     |
  | ---------- | -------------------------- | ------------------------------------------- |
  | `heatmap`  | `(N, num_classes, Hs, Ws)` | per-class center heatmap, float in `[0, 1]` |
  | `wh`       | `(N, 2, Hs, Ws)`           | box width/height at each location           |
  | `offset`   | `(N, 2, Hs, Ws)`           | sub-pixel center offset                     |
  | `reg_mask` | `(N, 1, Hs, Ws)`           | `1` where a center exists, else `0`         |

  The heatmap has one channel per object category (`num_classes`), not one per input image channel `C`.
- **`predict(X)`** → a length-`N` list of dicts `{"boxes": (k, 4), "scores": (k,), "labels": (k,)}`, where
  each box is `(cx, cy, w, h)` in stride-32 grid coordinates. `k` is the number of top peaks kept,
  `min(20, num_classes * Hs * Ws)`, so on a tiny grid `k` may be smaller than 20.
- **`score(X, y)`** raises `NotImplementedError`; use a dedicated mAP metric instead (e.g.
  `torchmetrics.detection.MeanAveragePrecision` or `pycocotools`).

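The `predict` decoding step can be approximated as a top-`k` search over the flattened per-class heatmap. This sketch handles one image's maps. It assumes channels 0/1 of `wh` are width/height and channels 0/1 of `offset` are x/y. It also skips the 3x3 max-pool peak suppression that the original CenterNet paper applies, so it is not guaranteed to match `shadax`'s decoder.

```python
import numpy as np


def decode_centernet(heatmap, wh, offset, max_dets: int = 20) -> dict:
    """Top-k decode for ONE image (sketch).

    heatmap: (num_classes, Hs, Ws) scores in [0, 1]
    wh, offset: (2, Hs, Ws); channel 0 = width / x, channel 1 = height / y (assumed)
    Returns boxes (k, 4) as (cx, cy, w, h) in stride-32 grid units, scores (k,), labels (k,).
    """
    num_classes, hs, ws = heatmap.shape
    k = min(max_dets, num_classes * hs * ws)
    flat = heatmap.reshape(-1)
    top = np.argsort(-flat, kind="stable")[:k]  # flat indices of the k highest scores
    labels, cell = np.divmod(top, hs * ws)      # flat index -> (class, grid cell)
    gy, gx = np.divmod(cell, ws)                # grid cell -> (row, column)
    cx = gx + offset[0, gy, gx]                 # integer cell + sub-pixel offset
    cy = gy + offset[1, gy, gx]
    boxes = np.stack([cx, cy, wh[0, gy, gx], wh[1, gy, gx]], axis=1)
    return {"boxes": boxes, "scores": flat[top], "labels": labels}


heatmap = np.zeros((3, 2, 2))
heatmap[1, 0, 1] = 0.9          # a class-1 center in cell (row 0, column 1)
wh = np.ones((2, 2, 2))
offset = np.zeros((2, 2, 2))
offset[:, 0, 1] = [0.2, 0.1]
dets = decode_centernet(heatmap, wh, offset)
print(dets["boxes"].shape, dets["labels"][0], dets["boxes"][0])  # k = min(20, 3 * 2 * 2) = 12
```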
```python
import numpy as np
from shadax import SHADA

N, C_in, H, W = 4, 3, 64, 64
num_classes = 3
Hs, Ws = H // 32, W // 32  # stride-32 grid (here 2 x 2)

X = np.random.randn(N, C_in, H, W).astype("float32")

# Build the dense CenterNet target dict. Here we plant one object per image.
heatmap = np.zeros((N, num_classes, Hs, Ws), dtype="float32")
wh = np.zeros((N, 2, Hs, Ws), dtype="float32")
offset = np.zeros((N, 2, Hs, Ws), dtype="float32")
reg_mask = np.zeros((N, 1, Hs, Ws), dtype="float32")
for i in range(N):
    cls, gy, gx = i % num_classes, i % Hs, i % Ws
    heatmap[i, cls, gy, gx] = 1.0      # a center peak for class `cls`
    wh[i, :, gy, gx] = [1.5, 1.0]      # box size at that center
    offset[i, :, gy, gx] = [0.2, 0.1]  # sub-pixel offset
    reg_mask[i, 0, gy, gx] = 1.0       # mark the center location

y = {"heatmap": heatmap, "wh": wh, "offset": offset, "reg_mask": reg_mask}

model = SHADA(tier="nano", task="detection", num_classes=num_classes)
model.fit(X, y, epochs=2)

dets = model.predict(X)  # list of N dicts
print(len(dets), dets[0]["boxes"].shape, dets[0]["scores"].shape, dets[0]["labels"].shape)
```

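The example above plants a single `1.0` cell per object. CenterNet-style training usually splats a small Gaussian around each center instead, so that neighbouring cells receive partial credit. The builder below is a hypothetical helper that turns per-image lists of `(cx, cy, w, h)` boxes (in stride-32 grid units) into the four-key dict that `fit` expects. Its name and its fixed `sigma` are not part of `shadax`; the CenterNet paper instead derives the Gaussian radius from each box's size.

```python
import math

import numpy as np


def build_centernet_targets(boxes_per_image, labels_per_image, num_classes, hs, ws, sigma=0.5):
    """Dense targets on an (hs, ws) grid from per-image (cx, cy, w, h) boxes in grid units."""
    n = len(boxes_per_image)
    heatmap = np.zeros((n, num_classes, hs, ws), dtype=np.float32)
    wh = np.zeros((n, 2, hs, ws), dtype=np.float32)
    offset = np.zeros((n, 2, hs, ws), dtype=np.float32)
    reg_mask = np.zeros((n, 1, hs, ws), dtype=np.float32)
    ys, xs = np.mgrid[0:hs, 0:ws]  # row / column index of every grid cell
    for i, (boxes, labels) in enumerate(zip(boxes_per_image, labels_per_image)):
        for (cx, cy, w, h), cls in zip(boxes, labels):
            gx, gy = math.floor(cx), math.floor(cy)  # cell that contains the center
            if not (0 <= gx < ws and 0 <= gy < hs):
                continue  # center falls outside the grid
            gauss = np.exp(-((xs - gx) ** 2 + (ys - gy) ** 2) / (2 * sigma**2))
            heatmap[i, cls] = np.maximum(heatmap[i, cls], gauss)  # exactly 1.0 at the center cell
            wh[i, :, gy, gx] = (w, h)
            offset[i, :, gy, gx] = (cx - gx, cy - gy)  # sub-cell part of the center
            reg_mask[i, 0, gy, gx] = 1.0
    return {"heatmap": heatmap, "wh": wh, "offset": offset, "reg_mask": reg_mask}


y = build_centernet_targets([[(1.2, 0.1, 1.5, 1.0)]], [[1]], num_classes=3, hs=2, ws=2)
print({key: value.shape for key, value in y.items()})
print(y["heatmap"][0, 1].round(3))
```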
## Self-supervised pretraining

Two objectives ship with the library. The one used is selected automatically from the input modality:

- **Masked image modeling (MIM)** — MAE-style. A random fraction (`mask_ratio`, default `0.75`) of the
  image patches (at granularity 32) is replaced by a learnable mask token. The masked image is encoded,
  and a lightweight convolutional decoder reconstructs the original pixels. The loss is the MSE over
  masked pixels only.
- **Masked language modeling (MLM)** — BERT-style. A random fraction (`text_mask_ratio`, default `0.15`)
  of the non-padding tokens is replaced by the reserved `[MASK]` id (`vocab_size - 1`). The masked
  sequence is encoded, and the LM head predicts the original tokens. The loss is cross-entropy over the
  masked positions only.

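The two masking procedures can be sketched in NumPy as follows. This illustrates the described objectives and is not `shadax`'s code. For MIM, it assumes whole 32x32 cells are masked and, following the MAE reference code, floors the number of *kept* patches. For MLM, it assumes each non-padding token is picked independently with probability `text_mask_ratio`, so the masked fraction is correct only on average. The learnable mask token, the decoder and the LM head are omitted.

```python
import numpy as np


def mim_pixel_mask(n, h, w, mask_ratio=0.75, patch=32, seed=None):
    """(N, H, W) bool mask that is True inside the randomly masked patch x patch cells."""
    rng = np.random.default_rng(seed)
    gh, gw = h // patch, w // patch
    num_patches = gh * gw
    num_masked = num_patches - int(num_patches * (1 - mask_ratio))  # MAE-style: floor the kept count
    cells = np.zeros((n, num_patches), dtype=bool)
    for i in range(n):
        cells[i, rng.permutation(num_patches)[:num_masked]] = True
    cells = cells.reshape(n, gh, gw)
    return cells.repeat(patch, axis=1).repeat(patch, axis=2)  # cell mask -> pixel mask


def masked_mse(recon, target, pixel_mask):
    """MSE over masked pixels only; recon/target are (N, C, H, W), pixel_mask is (N, H, W)."""
    m = pixel_mask[:, None, :, :].astype(recon.dtype)  # broadcast the mask over channels
    denom = m.sum() * recon.shape[1]
    return float(((recon - target) ** 2 * m).sum() / max(denom, 1.0))


def mlm_mask(tokens, vocab_size, ratio=0.15, pad_token_id=0, seed=None):
    """Replace a random subset of non-padding ids with [MASK] = vocab_size - 1."""
    rng = np.random.default_rng(seed)
    tokens = np.asarray(tokens)
    chosen = (tokens != pad_token_id) & (rng.random(tokens.shape) < ratio)
    return np.where(chosen, vocab_size - 1, tokens), chosen  # the loss is taken where chosen


mask = mim_pixel_mask(2, 64, 64, seed=0)
print(mask.shape, mask.mean(axis=(1, 2)))  # 3 of the 4 cells are masked in each image
masked_ids, chosen = mlm_mask([[0, 5, 6, 7, 0]], vocab_size=256, ratio=1.0)
print(masked_ids)  # padding (id 0) is never masked
```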
Call `pretrain(X, ...)`, a convenience wrapper for `fit(X, y=None, phases=[PRETRAIN], ...)`:

```python
import numpy as np
from shadax import SHADA

# Image pretraining (MIM). No labels.
X_img = np.random.randn(16, 3, 64, 64).astype("float32")
img_model = SHADA(tier="nano", task="classification", num_classes=10)
img_model.pretrain(X_img, epochs=2)

# Text pretraining (MLM). Reserve id vocab_size-1 as [MASK].
vocab_size = 256
X_txt = np.random.randint(1, vocab_size - 1, size=(16, 32))
txt_model = SHADA(tier="nano", task="lm", vocab_size=vocab_size, max_seq_len=64)
txt_model.pretrain(X_txt, epochs=2)
```

## The four-phase training pipeline

`TrainingPhase` defines four phases; `fit(..., phases=[...])` runs any sequence of them in order. Each
phase that optimises (every phase except `DEPLOY`) uses a fresh `AdamW` optimiser and a cosine-annealing
schedule.

| phase       | what it optimises                                      | needs labels?             |
| ----------- | ------------------------------------------------------ | ------------------------- |
| `PRETRAIN`  | self-supervised loss only (MIM or MLM)                 | no                        |
| `MULTITASK` | task loss + `ssl_weight` × self-supervised loss        | yes (not for `task="lm"`) |
| `FINETUNE`  | task loss only                                         | yes (not for `task="lm"`) |
| `DEPLOY`    | nothing; switches to `eval` mode, with no optimisation | no                        |

**Default phases** when `phases` is not given: `[PRETRAIN]` if `y is None`, otherwise `[FINETUNE]`.
Because `lm` models are fitted without `y`, pass `phases=[TrainingPhase.FINETUNE]` explicitly to get
supervised next-token training.

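The phase table and the default rule amount to a small dispatch on the phase, plus a per-phase learning-rate schedule. The sketch below is hypothetical. `Phase` is a stand-in for `shadax.TrainingPhase`, and its string values are assumed. The value of `ssl_weight` is not documented here. The schedule uses the closed form of PyTorch's `CosineAnnealingLR` (without restarts); whether `shadax` steps it per batch or per epoch is not stated.

```python
import math
from enum import Enum
from typing import List, Optional, Sequence


class Phase(Enum):
    """Stand-in for shadax.TrainingPhase; the string values are assumptions."""
    PRETRAIN = "pretrain"
    MULTITASK = "multitask"
    FINETUNE = "finetune"
    DEPLOY = "deploy"


def resolve_phases(y, phases: Optional[Sequence[Phase]] = None) -> List[Phase]:
    """Explicit phases win; otherwise [PRETRAIN] without labels and [FINETUNE] with them."""
    if phases is not None:
        return list(phases)
    return [Phase.PRETRAIN] if y is None else [Phase.FINETUNE]


def phase_objective(phase: Phase, task_loss: float, ssl_loss: float, ssl_weight: float) -> float:
    """Scalar loss that each optimised phase minimises, following the phase table."""
    if phase is Phase.PRETRAIN:
        return ssl_loss
    if phase is Phase.MULTITASK:
        return task_loss + ssl_weight * ssl_loss
    if phase is Phase.FINETUNE:
        return task_loss
    raise ValueError("DEPLOY only switches to eval mode; there is nothing to optimise")


def cosine_lr(step: int, total_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Cosine annealing from base_lr at step 0 down to min_lr at step total_steps."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


print(resolve_phases(y=None))       # no labels: self-supervised pretraining
print(resolve_phases(y=[0, 1, 2]))  # labels given: finetuning
print([round(cosine_lr(s, 4, 1e-4), 8) for s in range(5)])
```

With `task="lm"`, `y` is `None`, so a rule like `resolve_phases` falls back to `[PRETRAIN]`. This is why the language-model example passes `phases=[TrainingPhase.FINETUNE]` explicitly.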
```python
import numpy as np
from shadax import SHADA, TrainingPhase

X = np.random.randn(16, 3, 64, 64).astype("float32")
y = np.random.randint(0, 10, size=16)

model = SHADA(tier="nano", task="classification", num_classes=10)

# Full pipeline: self-supervised pretrain, then joint multitask, then finetune.
model.fit(
    X, y,
    phases=[TrainingPhase.PRETRAIN, TrainingPhase.MULTITASK, TrainingPhase.FINETUNE],
    epochs=2,
)
print(model.score(X, y))
```

## Multi-modal usage

The lower-level `HierarchicalEncoder` is public and processes both modalities with the *same* weights.
Build it from a config and call it with `modality="image"` or `modality="text"`.

```python
import torch
from shadax import HierarchicalEncoder, create_config

config = create_config("nano", task="classification", num_classes=10)
encoder = HierarchicalEncoder(config).eval()

# Image batch: (B, C, H, W), H and W divisible by 32.
images = torch.randn(2, 3, 64, 64)
img_out = encoder(images, modality="image")
print("image global:", img_out["global_features"].shape)  # (2, dims[-1])
print("image tokens:", img_out["tokens"].shape)           # (2, (H/32)*(W/32), dims[-1])

# Text batch: (B, L) integer token ids.
tokens = torch.randint(0, config.vocab_size, (2, 16))
txt_out = encoder(tokens, modality="text")
print("text global:", txt_out["global_features"].shape)   # (2, dims[-1])
print("text tokens:", txt_out["tokens"].shape)            # (2, 16, dims[-1])
```

## Model tiers

In every tier, `encoder_dims`, `encoder_depths` and `num_heads` each have four entries (one per stage).
`max_seq_len` is the size of the text positional table for that tier.

| tier    | `encoder_dims`           | `encoder_depths` | `num_heads`       | `max_seq_len` |
| ------- | ------------------------ | ---------------- | ----------------- | ------------- |
| `nano`  | `[64, 128, 256, 512]`    | `[2, 2, 4, 2]`   | `[2, 4, 8, 16]`   | 512           |
| `base`  | `[128, 256, 512, 1024]`  | `[3, 4, 6, 3]`   | `[4, 8, 16, 32]`  | 1024          |
| `large` | `[192, 384, 768, 1536]`  | `[3, 4, 18, 3]`  | `[6, 12, 24, 48]` | 2048          |
| `xl`    | `[256, 512, 1024, 2048]` | `[3, 4, 24, 3]`  | `[8, 16, 32, 64]` | 4096          |

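Every tier satisfies the rule, enforced by `validate()`, that `encoder_dims[i]` must be divisible by `num_heads[i]`. Working it out shows that all tiers use a per-head width of 32. The tiers grow by adding heads, channels and (mostly in stage 2) blocks. The snippet below only re-checks the table above: the dictionary is copied from it, not imported from `shadax`.

```python
from typing import Dict, List, Tuple

# (encoder_dims, encoder_depths, num_heads), copied from the tier table above.
TIERS: Dict[str, Tuple[List[int], List[int], List[int]]] = {
    "nano": ([64, 128, 256, 512], [2, 2, 4, 2], [2, 4, 8, 16]),
    "base": ([128, 256, 512, 1024], [3, 4, 6, 3], [4, 8, 16, 32]),
    "large": ([192, 384, 768, 1536], [3, 4, 18, 3], [6, 12, 24, 48]),
    "xl": ([256, 512, 1024, 2048], [3, 4, 24, 3], [8, 16, 32, 64]),
}


def head_dims(dims: List[int], heads: List[int]) -> List[int]:
    """Per-stage attention head width, applying the divisibility rule described for validate()."""
    if len(dims) != 4 or len(heads) != 4:
        raise ValueError("each per-stage list needs 4 entries")
    for d, h in zip(dims, heads):
        if d % h:
            raise ValueError(f"{d} channels cannot be split evenly across {h} heads")
    return [d // h for d, h in zip(dims, heads)]


for tier, (dims, depths, heads) in TIERS.items():
    print(f"{tier:>5}: head dims {head_dims(dims, heads)}, {sum(depths)} Transformer blocks")
```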
## Feature extraction

Once fitted, `extract_features(X, layer=...)` returns encoder representations as numpy arrays:

- `layer="global"` → pooled global features `(N, final_dim)`.
- `layer="tokens"` → final token sequence `(N, N_tok, final_dim)`.
- `layer="spatial"` → last image feature map `(N, final_dim, Hs, Ws)` (image modality).

```python
import numpy as np
from shadax import SHADA

X = np.random.randn(8, 3, 64, 64).astype("float32")
y = np.random.randint(0, 10, size=8)
model = SHADA(tier="nano", task="classification", num_classes=10).fit(X, y, epochs=1)

g = model.extract_features(X, layer="global")   # (8, final_dim)
t = model.extract_features(X, layer="tokens")   # (8, N_tok, final_dim)
s = model.extract_features(X, layer="spatial")  # (8, final_dim, 2, 2)
print(g.shape, t.shape, s.shape)
```

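The three `layer` options appear to be views of the same final stage. The sketch below assumes two things: the token sequence is the row-major flattening of the last `(final_dim, Hs, Ws)` map, and `global` is its mean over tokens, as the architecture diagram describes. Under those assumptions the views convert into each other as shown. The arrays here are random stand-ins. Verify the layout on real `extract_features` outputs (for example with `np.allclose`) before relying on it, in case the implementation applies extra processing between the views.

```python
import numpy as np


def spatial_to_tokens(spatial: np.ndarray) -> np.ndarray:
    """(N, D, Hs, Ws) -> (N, Hs*Ws, D), flattening the grid row by row."""
    n, d, hs, ws = spatial.shape
    return spatial.reshape(n, d, hs * ws).transpose(0, 2, 1)


def tokens_to_spatial(tokens: np.ndarray, hs: int, ws: int) -> np.ndarray:
    """(N, Hs*Ws, D) -> (N, D, Hs, Ws); the inverse of spatial_to_tokens."""
    n, num_tokens, d = tokens.shape
    if num_tokens != hs * ws:
        raise ValueError(f"{num_tokens} tokens do not fill a {hs}x{ws} grid")
    return tokens.transpose(0, 2, 1).reshape(n, d, hs, ws)


def global_from_tokens(tokens: np.ndarray) -> np.ndarray:
    """(N, N_tok, D) -> (N, D) by mean pooling over tokens."""
    return tokens.mean(axis=1)


rng = np.random.default_rng(0)
s = rng.standard_normal((8, 512, 2, 2)).astype(np.float32)  # stand-in for layer="spatial"
t = spatial_to_tokens(s)                                     # stand-in for layer="tokens"
g = global_from_tokens(t)                                    # stand-in for layer="global"
print(t.shape, g.shape)
```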
## Save / load

```python
import numpy as np
from shadax import SHADA

X = np.random.randn(8, 3, 64, 64).astype("float32")
y = np.random.randint(0, 10, size=8)
model = SHADA(tier="nano", task="classification", num_classes=10).fit(X, y, epochs=1)

model.save("shada_model.pt")

reloaded = SHADA(tier="nano", task="classification", num_classes=10)
reloaded.load("shada_model.pt")
print(reloaded.is_fitted, reloaded.predict(X).shape)
```

## API reference

### `SHADA` (high-level estimator)

```python
SHADA(
    tier="base",            # tier string ("nano"/"base"/"large"/"xl") OR a SHADAConfig
    num_classes=1000,
    task="classification",  # "classification" / "segmentation" / "lm" / "detection"
    learning_rate=1e-4,
    weight_decay=0.05,
    epochs=100,             # default epochs per phase
    batch_size=64,
    device=None,            # None -> "cuda" if available else "cpu"
    phases=None,            # default phase list; resolved at fit() time
    **kwargs,               # extra SHADAConfig overrides (e.g. vocab_size, max_seq_len)
)
```

Methods:

| method | signature | summary |
| ------ | --------- | ------- |
| `fit` | `fit(X, y=None, eval_set=None, verbose=True, epochs=None, phases=None) -> self` | Train through the resolved phases; `epochs=` can be passed directly. |
| `pretrain` | `pretrain(X, epochs=None, verbose=True) -> self` | SSL-only shortcut: `fit(X, y=None, phases=[PRETRAIN], ...)`. |
| `predict` | `predict(X, return_probs=False)` | Task-specific predictions (see each task above). |
| `predict_proba` | `predict_proba(X) -> np.ndarray` | Probabilities (classification/segmentation/lm); raises `NotImplementedError` for detection. |
| `score` | `score(X, y) -> float` | Accuracy in `[0, 1]`; raises `NotImplementedError` for detection. |
| `extract_features` | `extract_features(X, layer="global")` | `"global"` / `"tokens"` / `"spatial"` features. |
| `save` | `save(path) -> None` | Save config + weights + hyper-parameters. |
| `load` | `load(path) -> self` | Restore a saved model. |
| `is_fitted` | property `-> bool` | Whether the model has been fitted. |

### `SHADAConfig` (dataclass: the model-shape contract)

Key fields (with defaults): `tier="base"`, `encoder_dims=[128,256,512,1024]`, `encoder_depths=[3,4,6,3]`,
`num_heads=[4,8,16,32]`, `mlp_ratio=4.0`, `dropout=0.1`, `in_channels=3`, `image_size=224`,
`max_seq_len=1024`, `vocab_size=50257`, `task="classification"`, `num_classes=1000`, `mask_ratio=0.75`,
`text_mask_ratio=0.15`, `decoder_dim=256`, `decoder_depth=2`, `pad_token_id=0`.

Properties: `embed_dim`, `final_dim`, `num_stages`, `task_type`.

`validate()` enforces the per-stage list lengths (4 entries each), divisibility of each `encoder_dims[i]`
by `num_heads[i]`, and the mask-ratio / `num_classes` ranges.

### `create_config`

```python
create_config(tier="base", task="classification", num_classes=1000, **overrides) -> SHADAConfig
```

Builds a validated `SHADAConfig` from a tier preset, applying any field overrides passed as keyword
arguments (`**overrides`).

### `HierarchicalEncoder` (shared backbone, `nn.Module`)

```python
HierarchicalEncoder(config: SHADAConfig)
encoder(x, modality="image", causal=False) -> dict                  # see the encoder output contract above
encoder.forward_features(x, modality="image", causal=False) -> list  # the 4 feature maps only
```

### `SHADANet` (unified encoder + head + SSL, `nn.Module`)

```python
SHADANet(config: SHADAConfig)
net(x, modality=None) -> dict           # routes the encoder output through the task head
net.ssl_loss(x, modality=None) -> dict  # the modality-matched self-supervised loss
net.encode(x, modality=None) -> dict    # raw encoder output (feature extraction)
```

### Enums (controlled vocabularies)

- `ModelTier`: `NANO`, `BASE`, `LARGE`, `XL`.
- `TaskType`: `CLASSIFICATION`, `DETECTION`, `SEGMENTATION`, `LANGUAGE_MODEL` (value `"lm"`).
- `TrainingPhase`: `PRETRAIN`, `MULTITASK`, `FINETUNE`, `DEPLOY`.
- `Modality`: `IMAGE`, `TEXT`.

## Requirements

- **Python** `>=3.8`.
- **PyTorch** `>=2.0` (hard dependency).
- **NumPy** `>=1.21`.

## License

Released under the **MIT License**. See [LICENSE](LICENSE).