shadax 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
shadax-0.2.0/LICENSE ADDED

MIT License

Copyright (c) 2026 Omar Alghafri

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

shadax-0.2.0/PKG-INFO ADDED

Metadata-Version: 2.4
Name: shadax
Version: 0.2.0
Summary: SHADA: a self-supervised hierarchical hybrid (CNN + Transformer) multi-modal model for vision and text
Author: Omar
License: MIT
Project-URL: Homepage, https://github.com/OmarAlghafri/SHADA-API-Core-Reference
Project-URL: Repository, https://github.com/OmarAlghafri/SHADA-API-Core-Reference
Keywords: deep-learning,pytorch,transformer,cnn,multimodal,self-supervised,computer-vision,nlp
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: numpy>=1.21.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Dynamic: license-file

# shadax

**SHADA** (Self-supervised Hierarchical Adaptive Hybrid Algorithm) is a single PyTorch architecture that
pairs a convolutional stem with four hierarchical Transformer stages into one **hybrid** backbone, then
shares that backbone across **two modalities** (images and text) and **four downstream tasks**
(classification, segmentation, language modeling, detection). It is **adaptive** to input resolution
because the image path uses conditional (depthwise-convolutional) positional encoding rather than a fixed
positional table. It is also **self-supervised**: it ships with masked-image-modeling (MAE-style) and
masked-language-modeling (BERT-style) objectives and a four-phase training pipeline
(pretrain → multitask → finetune → deploy). Everything is exposed behind a familiar scikit-learn-style
estimator (`fit` / `predict` / `score`).

This README documents the library exactly as implemented. There are **no** bundled pretrained weights,
**no** bundled datasets, and **no** CUDA-only features; the model runs on CPU or GPU.

## Key features

- **Hybrid hierarchical backbone**: a `/4` convolutional stem followed by four Transformer stages, with
  `/2` patch-merging downsamples between consecutive stages (total spatial reduction of **32**).
- **One shared multi-modal encoder**: the *same* `HierarchicalEncoder` processes images `(B, C, H, W)`
  and text token ids `(B, L)`; the Transformer blocks are shared between the two paths.
- **Resolution-adaptive**: images use conditional positional encoding (a residual depthwise conv), so
  there is no fixed positional table for images; any `H`, `W` divisible by 32 works.
- **Four model tiers**: `nano` / `base` / `large` / `xl`.
- **Four task heads**: classification, segmentation, language model (`lm`), and detection (anchor-free,
  CenterNet-style), all fully implemented.
- **Self-supervised objectives**: masked image modeling (`mask_ratio=0.75` default) and masked language
  modeling (`text_mask_ratio=0.15` default).
- **Four-phase training pipeline**: `pretrain`, `multitask`, `finetune`, `deploy`.
- **scikit-learn-style API**: `fit` / `predict` / `predict_proba` / `score` / `extract_features` /
  `save` / `load`.

## Installation

PyTorch is a **hard dependency** (`torch>=2.0`), along with `numpy`. PyTorch is not bundled: install the
build appropriate for your platform/accelerator (see <https://pytorch.org>) if `pip` does not resolve a
suitable wheel automatically.

### From PyPI (recommended)
```bash
pip install shadax
```

### From GitHub
```bash
pip install git+https://github.com/OmarAlghafri/SHADA-API-Core-Reference.git
```

### From a local wheel
```bash
pip install dist/*.whl
```

## Quickstart (image classification)

```python
import numpy as np
from shadax import SHADA

# Tiny synthetic dataset: 16 RGB images, 64x64 (H, W must be divisible by 32).
X = np.random.randn(16, 3, 64, 64).astype("float32")
y = np.random.randint(0, 10, size=16)

# Build a classification model (tier + number of classes).
model = SHADA(tier="nano", task="classification", num_classes=10)

# fit accepts epochs= directly. With y given, the default phase is [FINETUNE].
model.fit(X, y, epochs=2)

# predict -> (N,) integer labels; score -> top-1 accuracy in [0, 1].
preds = model.predict(X)
acc = model.score(X, y)
print(preds.shape, acc)
```

## Architecture overview

The backbone is a `HierarchicalEncoder` driven entirely by a `SHADAConfig`. It has two input paths that
**share their per-stage Transformer blocks**: both paths ultimately run the blocks over a token sequence
`(B, N, dim)`.

**Image path**

```
image (B, C, H, W)
   |  ConvStem                    -> (B, dims[0], H/4,  W/4)    [/4]
   v
Stage 0   CPE + Transformer blocks (dims[0])
   |  StageDownsample (2x2, s2)   -> (B, dims[1], H/8,  W/8)    [/2]
   v
Stage 1   CPE + Transformer blocks (dims[1])
   |  StageDownsample             -> (B, dims[2], H/16, W/16)   [/2]
   v
Stage 2   CPE + Transformer blocks (dims[2])
   |  StageDownsample             -> (B, dims[3], H/32, W/32)   [/2]
   v
Stage 3   CPE + Transformer blocks (dims[3])
   |
   v
feature_maps:     4 maps at strides 4, 8, 16, 32
tokens:           (B, (H/32)*(W/32), dims[-1])
global_features:  (B, dims[-1])   (mean-pooled over tokens)
```

- **Conv stem** (`/4`): two stride-2 `3x3` convolutions with normalization and GELU.
- **Conditional positional encoding (CPE)**: a residual depthwise `3x3` convolution applied at the start
  of each image stage. Because it is convolutional, it adapts to arbitrary resolution; there is no fixed
  positional table for images. This is what makes the image path **resolution-adaptive**.
- **Patch-merging downsample** (`/2`): a normalization plus a stride-2 `2x2` convolution between stages,
  which also widens the channels (`dims[i] -> dims[i+1]`).
- **Spatial-32 constraint**: the stem's `/4` followed by three `/2` stage downsamples gives a total stride
  of **32** (4 × 2 × 2 × 2), so image `H` and `W` must each be divisible by 32. Both
  `SHADAConfig.validate` and the encoder enforce this.
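
The stride arithmetic above can be checked with a small helper. `image_stage_shapes` is a hypothetical name, not part of shadax. The sketch assumes exactly the `/4` stem and three `/2` downsamples described here. It computes the `(B, C, H_i, W_i)` shape of each entry in `feature_maps` and rejects image sizes that are not divisible by 32.

```python
from typing import List, Tuple

TOTAL_STRIDE = 32  # /4 conv stem followed by three /2 stage downsamples


def image_stage_shapes(
    batch: int, height: int, width: int, dims: List[int]
) -> List[Tuple[int, int, int, int]]:
    """Return the (B, C, H_i, W_i) shape of each stage's feature map."""
    if len(dims) != 4:
        raise ValueError("expected one channel width per stage (4 entries)")
    if height % TOTAL_STRIDE or width % TOTAL_STRIDE:
        raise ValueError(
            f"H and W must be divisible by {TOTAL_STRIDE}, got {height}x{width}"
        )
    shapes = []
    stride = 4  # resolution reduction after the conv stem
    for dim in dims:
        shapes.append((batch, dim, height // stride, width // stride))
        stride *= 2  # the next StageDownsample halves H and W
    return shapes


# Non-square input: 64x96 is valid because both sides are multiples of 32.
for shape in image_stage_shapes(2, 64, 96, [64, 128, 256, 512]):
    print(shape)
```

Because nothing in the image path depends on a fixed positional table, the same helper works for any valid resolution. Only the divisibility check limits the input size.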

**Text path**

The same four stages of Transformer blocks run over a token sequence produced by a learned token +
positional `TextEmbedding`. Between stages, a length-preserving `TextStageProject` (LayerNorm + Linear)
changes the channel width (`dims[i] -> dims[i+1]`) without changing the sequence length. For the
language-model task, a causal attention mask is applied so that each position attends only to itself and
earlier positions.
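
Below is a minimal sketch of such a causal mask as a plain boolean matrix. `mask[q][k]` is `True` when query position `q` may attend to key position `k`. How shadax builds its mask internally is not documented here, and `causal_mask` is a hypothetical helper.

PyTorch APIs also disagree on the boolean convention. `torch.nn.MultiheadAttention` treats `True` in a boolean `attn_mask` as *not allowed*. `torch.nn.functional.scaled_dot_product_attention` treats `True` as allowed. Invert the matrix if your attention call expects the opposite convention.

```python
from typing import List


def causal_mask(seq_len: int) -> List[List[bool]]:
    """mask[q][k] is True when query q may attend to key k, i.e. when k <= q."""
    return [[k <= q for k in range(seq_len)] for q in range(seq_len)]


# Print the mask as rows of 1 (allowed) and 0 (blocked).
for row in causal_mask(4):
    print("".join("1" if allowed else "0" for allowed in row))
```

For length 4, the rows print as the lower triangle `1000`, `1100`, `1110`, `1111`: position 0 sees only itself, and the last position sees the whole prefix.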

**Encoder output contract**: `encoder(x, modality=...)` returns a dict with keys:

| key               | image                              | text                          |
| ----------------- | ---------------------------------- | ----------------------------- |
| `feature_maps`    | list of 4 maps, strides 4/8/16/32  | list of 4 `(B, L, dims[i])`   |
| `tokens`          | `(B, (H/32)*(W/32), dims[-1])`     | `(B, L, dims[-1])`            |
| `global_features` | `(B, dims[-1])`                    | `(B, dims[-1])`               |
| `hw`              | `(Hs, Ws)` int tuple               | `None`                        |
| `modality`        | `"image"`                          | `"text"`                      |
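
The `global_features` entry is the mean of `tokens` over the token axis. The sketch below shows that reduction on plain nested lists standing in for tensors; `mean_pool` is a hypothetical name. In PyTorch the same reduction is `tokens.mean(dim=1)`. The sketch averages every position; this README does not say whether shadax excludes padding positions on the text path.

```python
from typing import List

Tokens = List[List[List[float]]]  # (B, N, dim) as nested lists


def mean_pool(tokens: Tokens) -> List[List[float]]:
    """(B, N, dim) -> (B, dim): average each channel over the N positions."""
    pooled = []
    for sequence in tokens:  # one (N, dim) sequence per batch element
        n, dim = len(sequence), len(sequence[0])
        pooled.append([sum(tok[d] for tok in sequence) / n for d in range(dim)])
    return pooled


batch = [
    [[1.0, 2.0], [3.0, 4.0]],   # element 0: N=2 tokens, dim=2
    [[0.0, 0.0], [2.0, -2.0]],  # element 1
]
print(mean_pool(batch))  # [[2.0, 3.0], [1.0, -1.0]]
```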

## The four tasks

Select a task with `task=` (or via `create_config(...)`). Each task has a dedicated head and loss, plus a
task-specific `predict` / `score` contract.

### Classification

- **Input** `X`: `(N, C, H, W)` images (`H`, `W` divisible by 32).
- **Target** `y`: `(N,)` integer labels (one-hot `(N, num_classes)` is also accepted and argmaxed).
- **`predict(X)`** → `(N,)` integer labels. With `return_probs=True`, returns `(labels, probs)` where
  `probs` is `(N, num_classes)`.
- **`score(X, y)`** → top-1 accuracy.

```python
import numpy as np
from shadax import SHADA

X = np.random.randn(16, 3, 64, 64).astype("float32")
y = np.random.randint(0, 5, size=16)

model = SHADA(tier="nano", task="classification", num_classes=5)
model.fit(X, y, epochs=2)

labels = model.predict(X)                            # (N,)
labels, probs = model.predict(X, return_probs=True)  # (N,), (N, 5)
print(labels.shape, probs.shape, model.score(X, y))
```
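
The classification `score` contract is top-1 accuracy, with one-hot targets reduced by argmax. It can be sketched in plain Python as follows. `top1_accuracy` and `argmax` are hypothetical helpers for illustration, not the shadax implementation.

```python
from typing import List, Sequence, Union

Label = Union[int, Sequence[float]]  # an integer label or a one-hot row


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def top1_accuracy(probs: List[Sequence[float]], targets: List[Label]) -> float:
    """Fraction of rows whose highest-probability class equals the target."""
    # One-hot targets are reduced to integer labels via argmax.
    labels = [t if isinstance(t, int) else argmax(t) for t in targets]
    preds = [argmax(p) for p in probs]
    return sum(p == t for p, t in zip(preds, labels)) / len(labels)


probs = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1], [0.2, 0.2, 0.6]]
targets = [1, [0.0, 1.0, 0.0], 2]  # mixes integer and one-hot targets
print(top1_accuracy(probs, targets))  # 2 of the 3 rows are correct
```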

### Segmentation

- **Input** `X`: `(N, C, H, W)` images.
- **Target** `y`: `(N, H, W)` integer per-pixel class masks.
- **`predict(X)`** → `(N, H, W)` per-pixel argmax labels.
- **`score(X, y)`** → mean per-pixel accuracy.

```python
import numpy as np
from shadax import SHADA

X = np.random.randn(8, 3, 64, 64).astype("float32")
y = np.random.randint(0, 4, size=(8, 64, 64))  # (N, H, W) masks

model = SHADA(tier="nano", task="segmentation", num_classes=4)
model.fit(X, y, epochs=2)

masks = model.predict(X)  # (N, H, W)
print(masks.shape, model.score(X, y))
```
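
Mean per-pixel accuracy pools every pixel of every mask and reports the fraction whose predicted class matches the target. Here is a plain-Python sketch; `pixel_accuracy` is a hypothetical helper. When all masks have the same size, as in the example above, the pooled value equals the average of the per-image accuracies.

```python
from typing import List

Mask = List[List[int]]  # one (H, W) mask of integer class ids


def pixel_accuracy(pred: List[Mask], target: List[Mask]) -> float:
    """Fraction of pixels, pooled over the batch, where pred equals target."""
    correct = total = 0
    for p_mask, t_mask in zip(pred, target):
        for p_row, t_row in zip(p_mask, t_mask):
            correct += sum(p == t for p, t in zip(p_row, t_row))
            total += len(t_row)
    return correct / total


pred = [[[0, 1], [1, 1]]]    # N=1, H=W=2
target = [[[0, 1], [0, 1]]]
print(pixel_accuracy(pred, target))  # 0.75 (3 of 4 pixels match)
```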

### Language model (`lm`)

- **Input** `X`: `(N, L)` integer token ids. Reserve id `vocab_size - 1` as the `[MASK]` token (used by
  masked-LM pretraining); do not assign it to a real token.
- **Target** `y`: the `lm` loss is next-token prediction over `X` itself, so **no `y` is needed**. For the
  language-model task, even the supervised phases (`FINETUNE` / `MULTITASK`) accept `y=None`.
  Self-supervised pretraining likewise uses `fit(X)` / `pretrain(X)`.
- **`predict(X)`** → `(N, L)` per-position argmax (next-token) ids.
- **`score(X, y)`** → next-token accuracy (positions equal to `pad_token_id` are ignored). Pass the same
  token-id array as `y`.

```python
import numpy as np
from shadax import SHADA, TrainingPhase

vocab_size = 256
# Keep ids in [1, vocab_size-2]: id 0 is padding, id vocab_size-1 is [MASK].
X = np.random.randint(1, vocab_size - 1, size=(8, 32))

model = SHADA(tier="nano", task="lm", vocab_size=vocab_size, max_seq_len=64)
# Supervised next-token finetuning. lm derives its targets from X itself, so no y.
model.fit(X, phases=[TrainingPhase.FINETUNE], epochs=2)

next_ids = model.predict(X)  # (N, L)
print(next_ids.shape, model.score(X, X))
```
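
Here is a sketch of next-token accuracy with padding excluded. It rests on one assumption the README does not spell out: the prediction at position `t` is scored against the token at position `t + 1`, so the last position has no target. `next_token_accuracy` is a hypothetical helper. Returning `0.0` when no position qualifies is an arbitrary choice made in this sketch.

```python
from typing import List, Sequence


def next_token_accuracy(
    pred_ids: List[Sequence[int]],
    token_ids: List[Sequence[int]],
    pad_token_id: int = 0,
) -> float:
    """Share of non-padding next tokens predicted correctly."""
    correct = total = 0
    for pred_row, tok_row in zip(pred_ids, token_ids):
        # Assumed alignment: pred_row[t] is the guess for tok_row[t + 1].
        for t in range(len(tok_row) - 1):
            target = tok_row[t + 1]
            if target == pad_token_id:
                continue  # padding positions are not scored
            total += 1
            correct += int(pred_row[t] == target)
    return correct / total if total else 0.0


tokens = [[5, 6, 7, 0]]  # trailing 0 is padding
preds = [[6, 9, 0, 3]]   # right at t=0, wrong at t=1, t=2 targets padding
print(next_token_accuracy(preds, tokens))  # 0.5
```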

### Detection (anchor-free, CenterNet-style)

The detection head predicts dense maps on the **stride-32 grid** `(Hs, Ws) = (H/32, W/32)`.

- **Input** `X`: `(N, C, H, W)` images.
- **Target** `y`: a dict of dense CenterNet targets on `(Hs, Ws)`:

| key        | shape                      | meaning                                     |
| ---------- | -------------------------- | ------------------------------------------- |
| `heatmap`  | `(N, num_classes, Hs, Ws)` | per-class center heatmap, float in `[0, 1]` |
| `wh`       | `(N, 2, Hs, Ws)`           | box width/height at each location           |
| `offset`   | `(N, 2, Hs, Ws)`           | sub-pixel center offset                     |
| `reg_mask` | `(N, 1, Hs, Ws)`           | `1` where a center exists, else `0`         |

Here `num_classes` is the number of object categories (not the image channel count `C` of `X`).

- **`predict(X)`** → a length-`N` list of dicts `{"boxes": (k, 4), "scores": (k,), "labels": (k,)}`, where
  each box is `(cx, cy, w, h)` in stride-32 grid coordinates. `k` is the number of top-scoring peaks kept
  (`min(20, num_classes * Hs * Ws)`), so on a tiny grid `k` may be smaller than 20.
- **`score(X, y)`** raises `NotImplementedError`; use a dedicated mAP metric instead (e.g.
  `torchmetrics.detection.MeanAveragePrecision` or `pycocotools`).

```python
import numpy as np
from shadax import SHADA

N, C_in, H, W = 4, 3, 64, 64
num_classes = 3
Hs, Ws = H // 32, W // 32  # stride-32 grid (here 2 x 2)

X = np.random.randn(N, C_in, H, W).astype("float32")

# Build the dense CenterNet target dict. Here we plant one object per image.
heatmap = np.zeros((N, num_classes, Hs, Ws), dtype="float32")
wh = np.zeros((N, 2, Hs, Ws), dtype="float32")
offset = np.zeros((N, 2, Hs, Ws), dtype="float32")
reg_mask = np.zeros((N, 1, Hs, Ws), dtype="float32")
for i in range(N):
    cls, gy, gx = i % num_classes, i % Hs, i % Ws
    heatmap[i, cls, gy, gx] = 1.0      # a center peak for class `cls`
    wh[i, :, gy, gx] = [1.5, 1.0]      # box size at that center
    offset[i, :, gy, gx] = [0.2, 0.1]  # sub-pixel offset
    reg_mask[i, 0, gy, gx] = 1.0       # mark the center location

y = {"heatmap": heatmap, "wh": wh, "offset": offset, "reg_mask": reg_mask}

model = SHADA(tier="nano", task="detection", num_classes=num_classes)
model.fit(X, y, epochs=2)

dets = model.predict(X)  # list of N dicts
print(len(dets), dets[0]["boxes"].shape, dets[0]["scores"].shape, dets[0]["labels"].shape)
```
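
The detection `predict` output comes from picking the highest peaks in the heatmap. The sketch below decodes one image's dense maps, given as nested lists, into the documented `{"boxes", "scores", "labels"}` layout with `k = min(20, num_classes * Hs * Ws)`.

`decode_centernet` is a simplified, hypothetical helper, not shadax's code. It makes three assumptions:
- `offset` and `wh` store `(x, y)` and `(w, h)` in channel order.
- The offset is added to the integer grid cell to get the box center.
- The 3x3 max-pool peak suppression that CenterNet decoders commonly apply can be skipped.

```python
import heapq
from typing import Dict, List

Grid = List[List[float]]  # one (Hs, Ws) map


def decode_centernet(
    heatmap: List[Grid], wh: List[Grid], offset: List[Grid], max_k: int = 20
) -> Dict[str, list]:
    """Decode one image: heatmap is (num_classes, Hs, Ws); wh, offset are (2, Hs, Ws)."""
    num_classes, hs, ws = len(heatmap), len(heatmap[0]), len(heatmap[0][0])
    k = min(max_k, num_classes * hs * ws)
    # Every (class, cell) pair is a candidate peak, scored by its heatmap value.
    candidates = (
        (heatmap[c][y][x], c, y, x)
        for c in range(num_classes)
        for y in range(hs)
        for x in range(ws)
    )
    boxes, scores, labels = [], [], []
    for score, c, y, x in heapq.nlargest(k, candidates):
        cx = x + offset[0][y][x]  # assumed channel order: (x, y)
        cy = y + offset[1][y][x]
        boxes.append((cx, cy, wh[0][y][x], wh[1][y][x]))  # assumed order: (w, h)
        scores.append(score)
        labels.append(c)
    return {"boxes": boxes, "scores": scores, "labels": labels}


heatmap = [[[0.1, 0.0], [0.0, 0.2]],  # class 0
           [[0.0, 0.9], [0.0, 0.0]]]  # class 1: strong peak at row 0, col 1
wh = [[[0.0, 1.5], [0.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]]
offset = [[[0.0, 0.2], [0.0, 0.0]], [[0.0, 0.1], [0.0, 0.0]]]
dets = decode_centernet(heatmap, wh, offset)
print(len(dets["boxes"]), dets["labels"][0], dets["boxes"][0])
```

With 2 classes on a 2 x 2 grid, `k = min(20, 8) = 8`, which illustrates why `k` can fall below 20 on small inputs.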

## Self-supervised pretraining

Two objectives ship with the library; the one used is selected automatically from the input modality:

- **Masked image modeling (MIM)**: MAE-style. A random fraction (`mask_ratio`, default `0.75`) of the
  image patches (at a granularity of 32) is replaced by a learnable mask token. The masked image is then
  encoded, and a lightweight convolutional decoder reconstructs the original pixels. The loss is the MSE
  over masked pixels only.
- **Masked language modeling (MLM)**: BERT-style. A random fraction (`text_mask_ratio`, default `0.15`)
  of the non-padding tokens is replaced by the reserved `[MASK]` id (`vocab_size - 1`). The masked
  sequence is then encoded, and the LM head predicts the original tokens. The loss is cross-entropy over
  the masked positions only.
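
MIM has two ingredients: choosing which patches to hide, and restricting the MSE to those patches. Both can be sketched in plain Python for a single-channel image. Replacing the hidden patches with the learnable mask token happens inside the model and is not shown.

The helper names are hypothetical. The sketch assumes "granularity 32" means 32 × 32-pixel patches, and the rounding rule for the number of masked patches is also an assumption.

```python
import random
from typing import List, Set, Tuple

PATCH = 32  # assumed meaning of "granularity 32": 32x32-pixel patches


def sample_patch_mask(
    height: int, width: int, mask_ratio: float = 0.75, seed: int = 0
) -> Set[Tuple[int, int]]:
    """Choose the (row, col) indices of the patches to hide."""
    grid = [(r, c) for r in range(height // PATCH) for c in range(width // PATCH)]
    num_masked = round(mask_ratio * len(grid))  # rounding rule is an assumption
    return set(random.Random(seed).sample(grid, num_masked))


def masked_mse(
    pred: List[List[float]], target: List[List[float]], masked: Set[Tuple[int, int]]
) -> float:
    """Mean squared error over pixels inside masked patches only (one channel)."""
    sq_err, count = 0.0, 0
    for y, (p_row, t_row) in enumerate(zip(pred, target)):
        for x, (p, t) in enumerate(zip(p_row, t_row)):
            if (y // PATCH, x // PATCH) in masked:
                sq_err += (p - t) ** 2
                count += 1
    return sq_err / count if count else 0.0


masked = sample_patch_mask(64, 64)  # 2x2 patch grid -> 3 of 4 patches hidden
recon = [[0.0] * 64 for _ in range(64)]
image = [[1.0] * 64 for _ in range(64)]
print(len(masked), masked_mse(recon, image, masked))
```

Errors on visible pixels never reach the loss. This forces the model to infer hidden content from context instead of copying its input.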

Call `pretrain(X, ...)` (a convenience wrapper for `fit(X, y=None, phases=[PRETRAIN], ...)`):

```python
import numpy as np
from shadax import SHADA

# Image pretraining (MIM). No labels.
X_img = np.random.randn(16, 3, 64, 64).astype("float32")
img_model = SHADA(tier="nano", task="classification", num_classes=10)
img_model.pretrain(X_img, epochs=2)

# Text pretraining (MLM). Reserve id vocab_size-1 as [MASK].
vocab_size = 256
X_txt = np.random.randint(1, vocab_size - 1, size=(16, 32))
txt_model = SHADA(tier="nano", task="lm", vocab_size=vocab_size, max_seq_len=64)
txt_model.pretrain(X_txt, epochs=2)
```
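
The MLM corruption step from the objectives list can be sketched for one sequence. `mlm_corrupt` is a hypothetical helper, not shadax's code, and it makes these choices:
- It always substitutes the `[MASK]` id; the README does not mention BERT's 80/10/10 replacement split.
- It never touches padding.
- It rounds the masked count to the nearest integer and masks at least one token when any are available.

```python
import random
from typing import List, Tuple


def mlm_corrupt(
    token_ids: List[int],
    vocab_size: int,
    mask_ratio: float = 0.15,
    pad_token_id: int = 0,
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Return (corrupted ids, masked positions) for one sequence."""
    mask_id = vocab_size - 1  # reserved [MASK] id
    candidates = [i for i, tok in enumerate(token_ids) if tok != pad_token_id]
    if not candidates:
        return list(token_ids), []
    num_masked = max(1, round(mask_ratio * len(candidates)))
    positions = sorted(random.Random(seed).sample(candidates, num_masked))
    corrupted = list(token_ids)
    for i in positions:
        corrupted[i] = mask_id
    return corrupted, positions


seq = [7] * 20 + [0] * 4  # 20 real tokens followed by 4 padding tokens
corrupted, positions = mlm_corrupt(seq, vocab_size=256)
print(positions, corrupted.count(255))
```

The returned positions are where the cross-entropy would be evaluated against the original ids in `seq`.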

## The four-phase training pipeline

`TrainingPhase` defines four phases; `fit(..., phases=[...])` runs any sequence of them in order. Each
optimised phase uses a fresh `AdamW` optimiser and a cosine-annealing schedule.

| phase       | what it optimises                                  | needs labels?       |
| ----------- | -------------------------------------------------- | ------------------- |
| `PRETRAIN`  | self-supervised loss only (MIM or MLM)             | no                  |
| `MULTITASK` | task loss `+ ssl_weight *` self-supervised loss    | yes (not for `lm`)  |
| `FINETUNE`  | task loss only                                     | yes (not for `lm`)  |
| `DEPLOY`    | nothing; switches to `eval` mode, no optimisation  | no                  |

**Default phases** when `phases` is not given: `[PRETRAIN]` if `y is None`, otherwise `[FINETUNE]`.
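
The default-phase rule, the per-phase objective from the table, and the cosine-annealing learning-rate curve fit in one short sketch. A few things in it are assumptions or unknowns:
- `Phase` is a hypothetical stand-in for `TrainingPhase`.
- `ssl_weight` has no documented default in this README.
- This README does not say whether shadax steps its schedule per batch or per epoch.

`cosine_lr` is the standard closed form that `torch.optim.lr_scheduler.CosineAnnealingLR` also documents.

```python
import math
from enum import Enum, auto
from typing import List, Optional


class Phase(Enum):  # hypothetical stand-in for shadax.TrainingPhase
    PRETRAIN = auto()
    MULTITASK = auto()
    FINETUNE = auto()
    DEPLOY = auto()


def resolve_phases(y, phases: Optional[List[Phase]] = None) -> List[Phase]:
    """Explicit phases win; otherwise [PRETRAIN] without labels, [FINETUNE] with them."""
    if phases is not None:
        return list(phases)
    return [Phase.PRETRAIN] if y is None else [Phase.FINETUNE]


def phase_objective(phase: Phase, task_loss: float, ssl_loss: float, ssl_weight: float) -> float:
    """The quantity each optimised phase minimises."""
    if phase is Phase.PRETRAIN:
        return ssl_loss
    if phase is Phase.MULTITASK:
        return task_loss + ssl_weight * ssl_loss
    if phase is Phase.FINETUNE:
        return task_loss
    raise ValueError("DEPLOY only switches to eval mode; nothing is optimised")


def cosine_lr(step: int, total_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Cosine annealing from base_lr (step 0) down to min_lr (step total_steps)."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


print(resolve_phases(y=None), resolve_phases(y=[0, 1]))
print([round(cosine_lr(s, 4, 1e-4), 8) for s in range(5)])
```

Restarting the optimiser and schedule per phase means each phase begins again at the full `learning_rate` and decays independently.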

```python
import numpy as np
from shadax import SHADA, TrainingPhase

X = np.random.randn(16, 3, 64, 64).astype("float32")
y = np.random.randint(0, 10, size=16)

model = SHADA(tier="nano", task="classification", num_classes=10)

# Full pipeline: self-supervised pretrain, then joint multitask, then finetune.
model.fit(
    X, y,
    phases=[TrainingPhase.PRETRAIN, TrainingPhase.MULTITASK, TrainingPhase.FINETUNE],
    epochs=2,
)
print(model.score(X, y))
```

## Multi-modal usage

The lower-level `HierarchicalEncoder` is public and processes both modalities with the *same* weights.
Build it from a config and call it with `modality="image"` or `modality="text"`.

```python
import torch
from shadax import HierarchicalEncoder, create_config

config = create_config("nano", task="classification", num_classes=10)
encoder = HierarchicalEncoder(config).eval()

# Image batch: (B, C, H, W), H and W divisible by 32.
images = torch.randn(2, 3, 64, 64)
img_out = encoder(images, modality="image")
print("image global:", img_out["global_features"].shape)  # (2, dims[-1])
print("image tokens:", img_out["tokens"].shape)           # (2, (H/32)*(W/32), dims[-1])

# Text batch: (B, L) integer token ids.
tokens = torch.randint(0, config.vocab_size, (2, 16))
txt_out = encoder(tokens, modality="text")
print("text global:", txt_out["global_features"].shape)   # (2, dims[-1])
print("text tokens:", txt_out["tokens"].shape)            # (2, 16, dims[-1])
```

## Model tiers

In every tier, `encoder_dims`, `encoder_depths`, and `num_heads` each have 4 entries (one per stage).
`max_seq_len` is the size of the text positional table for that tier.

| tier    | `encoder_dims`           | `encoder_depths` | `num_heads`       | `max_seq_len` |
| ------- | ------------------------ | ---------------- | ----------------- | ------------- |
| `nano`  | `[64, 128, 256, 512]`    | `[2, 2, 4, 2]`   | `[2, 4, 8, 16]`   | 512           |
| `base`  | `[128, 256, 512, 1024]`  | `[3, 4, 6, 3]`   | `[4, 8, 16, 32]`  | 1024          |
| `large` | `[192, 384, 768, 1536]`  | `[3, 4, 18, 3]`  | `[6, 12, 24, 48]` | 2048          |
| `xl`    | `[256, 512, 1024, 2048]` | `[3, 4, 24, 3]`  | `[8, 16, 32, 64]` | 4096          |
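
Read row by row, the table shows that every tier keeps a per-head width of 32 channels (`encoder_dims[i] / num_heads[i]`). The tiers differ mainly in channel width and Transformer depth. This sketch transcribes the table into a dictionary (the `TIERS` name and helper functions are illustrative only) and derives those figures.

```python
from typing import Dict, List

# Transcribed from the tier table above.
TIERS: Dict[str, Dict[str, List[int]]] = {
    "nano": {"dims": [64, 128, 256, 512], "depths": [2, 2, 4, 2], "heads": [2, 4, 8, 16]},
    "base": {"dims": [128, 256, 512, 1024], "depths": [3, 4, 6, 3], "heads": [4, 8, 16, 32]},
    "large": {"dims": [192, 384, 768, 1536], "depths": [3, 4, 18, 3], "heads": [6, 12, 24, 48]},
    "xl": {"dims": [256, 512, 1024, 2048], "depths": [3, 4, 24, 3], "heads": [8, 16, 32, 64]},
}


def head_dims(tier: str) -> List[int]:
    """Channels per attention head at each stage."""
    t = TIERS[tier]
    return [d // h for d, h in zip(t["dims"], t["heads"])]


def num_blocks(tier: str) -> int:
    """Total Transformer blocks across the four stages."""
    return sum(TIERS[tier]["depths"])


for name in TIERS:
    print(f"{name}: head_dim per stage = {head_dims(name)}, blocks = {num_blocks(name)}")
```

For example, `large` stacks 28 blocks and `xl` stacks 34; in both, most of the extra depth sits in stage 2.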

## Feature extraction

Once the model is fitted, `extract_features(X, layer=...)` returns encoder representations as NumPy arrays:

- `layer="global"` → pooled global features `(N, final_dim)`.
- `layer="tokens"` → final token sequence `(N, N_tok, final_dim)`.
- `layer="spatial"` → last image feature map `(N, final_dim, Hs, Ws)` (image modality).
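
For images, the `"tokens"` and `"spatial"` layers appear to describe the same final-stage grid in two layouts. Both have `final_dim` channels and `Hs * Ws` positions. Here is a sketch of converting one sample's `(Hs * Ws, dim)` tokens back to a `(dim, Hs, Ws)` map; `tokens_to_spatial` is a hypothetical helper.

The sketch assumes the tokens are a row-major flattening of the grid. The README does not state the ordering. Under that assumption, the batched PyTorch equivalent would be `tokens.transpose(1, 2).reshape(N, dim, Hs, Ws)`.

```python
from typing import List


def tokens_to_spatial(tokens: List[List[float]], hs: int, ws: int) -> List[List[List[float]]]:
    """Map one sample's (Hs*Ws, dim) tokens to a (dim, Hs, Ws) grid (row-major order assumed)."""
    if len(tokens) != hs * ws:
        raise ValueError(f"expected {hs * ws} tokens, got {len(tokens)}")
    dim = len(tokens[0])
    return [
        [[tokens[y * ws + x][c] for x in range(ws)] for y in range(hs)]
        for c in range(dim)
    ]


# 2x2 grid (as for a 64x64 input) with 2 channels per token.
tokens = [[0.0, 10.0], [1.0, 11.0], [2.0, 12.0], [3.0, 13.0]]
print(tokens_to_spatial(tokens, hs=2, ws=2))
# [[[0.0, 1.0], [2.0, 3.0]], [[10.0, 11.0], [12.0, 13.0]]]
```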

```python
import numpy as np
from shadax import SHADA

X = np.random.randn(8, 3, 64, 64).astype("float32")
y = np.random.randint(0, 10, size=8)
model = SHADA(tier="nano", task="classification", num_classes=10).fit(X, y, epochs=1)

g = model.extract_features(X, layer="global")   # (8, final_dim)
t = model.extract_features(X, layer="tokens")   # (8, N_tok, final_dim)
s = model.extract_features(X, layer="spatial")  # (8, final_dim, 2, 2)
print(g.shape, t.shape, s.shape)
```

## Save / load

```python
import numpy as np
from shadax import SHADA

X = np.random.randn(8, 3, 64, 64).astype("float32")
y = np.random.randint(0, 10, size=8)
model = SHADA(tier="nano", task="classification", num_classes=10).fit(X, y, epochs=1)

model.save("shada_model.pt")

reloaded = SHADA(tier="nano", task="classification", num_classes=10)
reloaded.load("shada_model.pt")
print(reloaded.is_fitted, reloaded.predict(X).shape)
```

## API reference

### `SHADA` (high-level estimator)

```python
SHADA(
    tier="base",            # tier string ("nano"/"base"/"large"/"xl") OR a SHADAConfig
    num_classes=1000,
    task="classification",  # "classification" / "segmentation" / "lm" / "detection"
    learning_rate=1e-4,
    weight_decay=0.05,
    epochs=100,             # default epochs per phase
    batch_size=64,
    device=None,            # None -> "cuda" if available else "cpu"
    phases=None,            # default phase list; resolved at fit() time
    **kwargs,               # extra SHADAConfig overrides (e.g. vocab_size, max_seq_len)
)
```

Methods:

| method | signature | summary |
| ------ | --------- | ------- |
| `fit` | `fit(X, y=None, eval_set=None, verbose=True, epochs=None, phases=None) -> self` | Train through the resolved phases. `fit` also accepts `epochs=` directly. |
| `pretrain` | `pretrain(X, epochs=None, verbose=True) -> self` | SSL-only shortcut: `fit(X, y=None, phases=[PRETRAIN], ...)`. |
| `predict` | `predict(X, return_probs=False)` | Task-specific predictions (see each task above). |
| `predict_proba` | `predict_proba(X) -> np.ndarray` | Probabilities (classification/segmentation/lm); raises `NotImplementedError` for detection. |
| `score` | `score(X, y) -> float` | Accuracy in `[0, 1]`; raises `NotImplementedError` for detection. |
| `extract_features` | `extract_features(X, layer="global")` | `"global"` / `"tokens"` / `"spatial"` features. |
| `save` | `save(path) -> None` | Save config + weights + hyper-parameters. |
| `load` | `load(path) -> self` | Restore a saved model. |
| `is_fitted` | property `-> bool` | Whether the model has been fitted. |
459
+ ### `SHADAConfig` (dataclass — the model-shape contract)
460
+
461
+ Key fields (with defaults): `tier="base"`, `encoder_dims=[128,256,512,1024]`, `encoder_depths=[3,4,6,3]`,
462
+ `num_heads=[4,8,16,32]`, `mlp_ratio=4.0`, `dropout=0.1`, `in_channels=3`, `image_size=224`,
463
+ `max_seq_len=1024`, `vocab_size=50257`, `task="classification"`, `num_classes=1000`, `mask_ratio=0.75`,
464
+ `text_mask_ratio=0.15`, `decoder_dim=256`, `decoder_depth=2`, `pad_token_id=0`. Properties: `embed_dim`,
465
+ `final_dim`, `num_stages`, `task_type`. `validate()` enforces the per-stage list lengths (4 entries each),
466
+ divisibility of each `encoder_dims[i]` by `num_heads[i]`, and the mask-ratio / `num_classes` ranges.
467
+
468
+ ### `create_config`
469
+
470
+ ```python
471
+ create_config(tier="base", task="classification", num_classes=1000, **overrides) -> SHADAConfig
472
+ ```
473
+
474
+ Builds a validated `SHADAConfig` from a tier preset, applying any field `**overrides`.
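
Here is a trimmed-down sketch of how preset-plus-overrides construction and the listed `validate()` checks could fit together. `MiniConfig`, `PRESETS`, and `make_config` are hypothetical stand-ins for `SHADAConfig` and `create_config`. Only a subset of fields is modelled, and the exact bounds for the ratio and `num_classes` checks are assumptions.

```python
from dataclasses import dataclass, field, replace
from typing import List


@dataclass
class MiniConfig:  # hypothetical subset of SHADAConfig fields
    encoder_dims: List[int] = field(default_factory=lambda: [128, 256, 512, 1024])
    encoder_depths: List[int] = field(default_factory=lambda: [3, 4, 6, 3])
    num_heads: List[int] = field(default_factory=lambda: [4, 8, 16, 32])
    num_classes: int = 1000
    mask_ratio: float = 0.75
    text_mask_ratio: float = 0.15

    def validate(self) -> "MiniConfig":
        for name in ("encoder_dims", "encoder_depths", "num_heads"):
            if len(getattr(self, name)) != 4:
                raise ValueError(f"{name} needs 4 entries (one per stage)")
        for i, (d, h) in enumerate(zip(self.encoder_dims, self.num_heads)):
            if d % h:
                raise ValueError(f"stage {i}: encoder_dims {d} not divisible by num_heads {h}")
        for name in ("mask_ratio", "text_mask_ratio"):
            if not 0.0 < getattr(self, name) < 1.0:  # open interval is an assumption
                raise ValueError(f"{name} must lie in (0, 1)")
        if self.num_classes < 1:  # lower bound is an assumption
            raise ValueError("num_classes must be positive")
        return self


PRESETS = {
    "nano": MiniConfig(encoder_dims=[64, 128, 256, 512], encoder_depths=[2, 2, 4, 2],
                       num_heads=[2, 4, 8, 16]),
    "base": MiniConfig(),
}


def make_config(tier: str = "base", **overrides) -> MiniConfig:
    # dataclasses.replace copies the preset with the given fields swapped in.
    return replace(PRESETS[tier], **overrides).validate()


print(make_config("nano", num_classes=5))
```

Validating after the overrides are applied is what catches an override that breaks a preset, such as a head count that no longer divides the stage width.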
475
+
476
+ ### `HierarchicalEncoder` (shared backbone, `nn.Module`)
477
+
478
+ ```python
479
+ HierarchicalEncoder(config: SHADAConfig)
480
+ encoder(x, modality="image", causal=False) -> dict # see the encoder output contract above
481
+ encoder.forward_features(x, modality="image", causal=False) -> list # the 4 feature maps only
482
+ ```
483
+
484
+ ### `SHADANet` (unified encoder + head + SSL, `nn.Module`)
485
+
486
+ ```python
487
+ SHADANet(config: SHADAConfig)
488
+ net(x, modality=None) -> dict # routes the encoder output through the task head
489
+ net.ssl_loss(x, modality=None) -> dict # the modality-matched self-supervised loss
490
+ net.encode(x, modality=None) -> dict # raw encoder output (feature extraction)
491
+ ```
492
+
493
+ ### Enums (controlled vocabularies)
494
+
495
+ - `ModelTier`: `NANO`, `BASE`, `LARGE`, `XL`.
496
+ - `TaskType`: `CLASSIFICATION`, `DETECTION`, `SEGMENTATION`, `LANGUAGE_MODEL` (value `"lm"`).
497
+ - `TrainingPhase`: `PRETRAIN`, `MULTITASK`, `FINETUNE`, `DEPLOY`.
498
+ - `Modality`: `IMAGE`, `TEXT`.
499
+
500
+ ## Requirements
501
+
502
+ - **Python** `>=3.8`.
503
+ - **PyTorch** `>=2.0` (hard dependency).
504
+ - **NumPy**.
505
+
506
+ ## License
507
+
508
+ Released under the **MIT License**. See [LICENSE](LICENSE).