softs 0.1.0__tar.gz

softs-0.1.0/PKG-INFO ADDED
Metadata-Version: 2.4
Name: softs
Version: 0.1.0
Summary: Async on-the-fly training data generation pipeline for PyTorch
Author: Ayoub G.
Author-email: develop@ayghri.me
Requires-Python: >=3.11,<3.14
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Provides-Extra: dev
Provides-Extra: examples
Requires-Dist: datasets (>=2.0) ; extra == "examples"
Requires-Dist: hydra-core (>=1.3) ; extra == "examples"
Requires-Dist: msgpack (>=1.0.0)
Requires-Dist: numpy (>=1.26.0,<3.0.0)
Requires-Dist: omegaconf (>=2.3) ; extra == "examples"
Requires-Dist: pytest (>=9.0.2,<10.0.0) ; extra == "dev"
Requires-Dist: pyyaml (>=6.0,<7.0)
Requires-Dist: pyzmq (>=26.0.0)
Requires-Dist: torch (>=2.9.1,<3.0.0)
Requires-Dist: tqdm ; extra == "examples"
Requires-Dist: transformers (>=4.40) ; extra == "examples"
Project-URL: Documentation, https://softs.readthedocs.io
Project-URL: Repository, https://github.com/ayghri/softs
Description-Content-Type: text/markdown

# softs

A broker-based data pipeline for distributed teacher-student training in PyTorch.

## Overview

`softs` provides a **data-agnostic** message routing system for teacher-student workflows:

- **Broker**: Routes messages between students and workers. Knows nothing about the data.
- **Workers**: Generate samples on demand and write raw bytes to student-owned memory.
- **Students**: Own memory slots, request samples, then read and decode the bytes.

The library only moves bytes. What those bytes represent is entirely up to your application. Use `BatchConfig` to encode and decode PyTorch tensors.

## Key Features

- **Zero-copy transfer**: Workers write directly to student shared memory
- **Async pipeline**: Students train while workers generate the next batch
- **Model switching**: Change teacher models mid-training (e.g., layer-by-layer distillation)
- **Fault tolerance**: Workers/students can crash and restart independently
- **DDP compatible**: Works with PyTorch's DistributedDataParallel

## Architecture

```
┌──────────────────────────────────────────────────────────────────┐
│                              BROKER                              │
│                 (message router, data-agnostic)                  │
│                                                                  │
│  Frontend         Backend        Control       ControlPub        │
│  (ROUTER)        (ROUTER)       (ROUTER)          (PUB)          │
│      ▲               ▲              ▲               │            │
└──────┼───────────────┼──────────────┼───────────────┼────────────┘
       │               │              │               │
 ┌─────┴─────┐   ┌─────┴─────┐   ┌────┴────┐     ┌────┴────┐
 │  Student  │   │  Worker   │   │Student 0│     │ Workers │
 │ DataLoader│   │ (teacher) │   │(leader) │     │  (SUB)  │
 │  workers  │   │           │   │         │     │         │
 └───────────┘   └───────────┘   └─────────┘     └─────────┘
```

### Message Flow

1. **Student** creates shared memory slots and sends `REQUEST` with `{token, shm_name, slot_offset}`
2. **Broker** queues the request and assigns it to an available **Worker** via `WORK`
3. **Worker** generates sample bytes using your `generator_fn` and writes them directly to shared memory
4. **Worker** sends `DONE` to the broker, which sends `COMPLETE` to the student
5. **Student** reads the bytes from shared memory, decodes them into tensors, and trains
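The five steps above can be traced in a single process. This is a minimal sketch, not the softs implementation: `deque` objects stand in for the ZeroMQ sockets, a `bytearray` stands in for the student's shared memory, and `run_flow` and its request dictionaries are hypothetical.

```python
from collections import deque

SLOT_SIZE = 4  # bytes per slot in this toy example


def run_flow(num_requests: int, num_workers: int) -> tuple[bytearray, list[str]]:
    """Simulate REQUEST -> WORK -> DONE -> COMPLETE with in-memory queues.

    Hypothetical sketch: deques stand in for ZeroMQ sockets and a bytearray
    stands in for the student's shared memory segment.
    """
    shm = bytearray(SLOT_SIZE * num_requests)  # student-owned memory
    log: list[str] = []
    pending: deque = deque()                   # broker's request queue
    idle_workers = deque(range(num_workers))   # workers that sent READY

    # 1. Student sends one REQUEST per slot, carrying addressing info only.
    for token in range(num_requests):
        pending.append({"token": token, "slot_offset": token * SLOT_SIZE})
        log.append(f"REQUEST {token}")

    while pending:
        # 2. Broker pairs the oldest request with an idle worker (WORK).
        req = pending.popleft()
        worker = idle_workers.popleft()
        log.append(f"WORK {req['token']} -> worker {worker}")

        # 3. Worker generates bytes and writes them at the routed offset.
        sample = bytes([req["token"]]) * SLOT_SIZE
        off = req["slot_offset"]
        shm[off:off + SLOT_SIZE] = sample

        # 4. Worker reports DONE; broker forwards COMPLETE to the student.
        log.append(f"DONE {req['token']}")
        log.append(f"COMPLETE {req['token']}")
        idle_workers.append(worker)  # the worker is READY again

    return shm, log


# 5. Student reads its slots back.
shm, log = run_flow(num_requests=3, num_workers=2)
for token in range(3):
    print(token, bytes(shm[token * SLOT_SIZE:(token + 1) * SLOT_SIZE]))
```

Only the token and offset pass through the broker's queue; the payload goes straight into the student's memory, which is what keeps the broker data-agnostic.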

### Model Switching

Student rank 0 (the leader) can change the model at any time:

```python
client.set_model("layer_5")  # Workers now generate for layer_5
```

Calling `set_model()`:
1. Increments a **generation counter**
2. Broadcasts the new model to all workers via PUB/SUB
3. Discards any pending work from the old generation
4. Has workers start generating for the new model immediately
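The generation counter can be illustrated with a small, self-contained sketch. It assumes, without confirmation from the softs source, that each queued request and each `DONE` carries the generation it was issued under; `GenerationGate` and its methods are hypothetical names, and the PUB/SUB broadcast is left out.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class GenerationGate:
    """Hypothetical broker-side bookkeeping for model switches."""

    generation: int = 0
    model_id: Optional[str] = None
    pending: deque = field(default_factory=deque)  # queued (generation, token) pairs

    def set_model(self, model_id: str) -> int:
        self.generation += 1   # step 1: bump the generation counter
        self.model_id = model_id
        self.pending.clear()   # step 3: drop work queued for the old model
        # Steps 2 and 4 (broadcasting model_id to workers) are omitted here.
        return self.generation

    def request(self, token: int) -> None:
        # New requests are tagged with the generation they were issued under.
        self.pending.append((self.generation, token))

    def accept_done(self, generation: int) -> bool:
        # A DONE produced under an older generation is stale and ignored.
        return generation == self.generation


gate = GenerationGate()
gate.set_model("layer_4")
gate.request(1)
gate.request(2)
in_flight_gen = gate.generation  # a worker took a job at this generation
gate.set_model("layer_5")        # the leader switches models mid-flight
print(len(gate.pending))                # 0: queued layer_4 work was discarded
print(gate.accept_done(in_flight_gen))  # False: the late layer_4 result is rejected
```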
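
## Installation

```bash
pip install softs
# or
poetry add softs

# Optional extra with the dependencies of the example scripts
pip install "softs[examples]"
```

**Dependencies**: `pyzmq`, `torch`, `numpy`, `msgpack`, `pyyaml` (Python >=3.11,<3.14). The optional `examples` extra adds `datasets`, `hydra-core`, `omegaconf`, `tqdm`, and `transformers`; the `dev` extra adds `pytest`.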

## Quick Start

### 1. Define your data format with BatchConfig

```python
from softs import BatchConfig, TensorSpec

config = BatchConfig([
    TensorSpec("x", (3, 224, 224), "float32"),
    TensorSpec("y", (1000,), "float32"),
])
```

### 2. Start the broker

```python
from softs import Broker, setup_logging

setup_logging("INFO")
Broker().run()
```

### 3. Start worker(s)

```python
import torch
from softs import BatchConfig, TensorSpec, Worker, setup_logging

setup_logging("INFO")

# Workers run in their own process, so they need the same BatchConfig as the students
config = BatchConfig([
    TensorSpec("x", (3, 224, 224), "float32"),
    TensorSpec("y", (1000,), "float32"),
])

def generate_sample(model_id: str, model_cfg: dict | None) -> bytes:
    # Your generation logic; runs once per sample
    x = torch.randn(3, 224, 224)
    y = torch.randn(1000)
    return config.encode(x=x, y=y)

Worker(
    generator_fn=generate_sample,
    slot_size=config.nbytes(),
).run()
```
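`slot_size` tells the worker how many bytes each sample occupies, so a generator that returns a payload of a different length would overflow or under-fill the student's slot. Whether `Worker` already rejects such payloads is not stated here; the hypothetical wrapper below (not part of softs) shows one way to fail fast in your own code, relying only on the documented `(model_id, model_cfg) -> bytes` contract of `generator_fn`.

```python
from typing import Callable, Optional

GeneratorFn = Callable[[str, Optional[dict]], bytes]


def checked_generator(fn: GeneratorFn, slot_size: int) -> GeneratorFn:
    """Wrap a generator_fn so every payload must be exactly slot_size bytes.

    Hypothetical helper, not part of softs.
    """
    def wrapper(model_id: str, model_cfg: Optional[dict]) -> bytes:
        data = fn(model_id, model_cfg)
        if len(data) != slot_size:
            raise ValueError(
                f"generator for {model_id!r} returned {len(data)} bytes, "
                f"expected {slot_size}"
            )
        return data
    return wrapper


def toy_generator(model_id: str, model_cfg: Optional[dict]) -> bytes:
    # Stand-in for real teacher output: 8 zero bytes.
    return bytes(8)


gen = checked_generator(toy_generator, slot_size=8)
print(len(gen("my_model", None)))  # 8: passes the check
try:
    checked_generator(toy_generator, slot_size=16)("my_model", None)
except ValueError as exc:
    print(exc)  # generator for 'my_model' returned 8 bytes, expected 16
```

With softs, you would then pass `generator_fn=checked_generator(generate_sample, config.nbytes())` to `Worker`.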

### 4. Run student training

```python
from softs import BatchConfig, DistillIterableDataset, StudentClient, TensorSpec, setup_logging

setup_logging("INFO")

# Same BatchConfig as the workers
config = BatchConfig([
    TensorSpec("x", (3, 224, 224), "float32"),
    TensorSpec("y", (1000,), "float32"),
])

client = StudentClient(
    student_rank=0,
    slot_count=16,
    batch_config=config,
)
client.hello()
client.set_model("my_model")

dataset = DistillIterableDataset(
    student_rank=0,
    generation_value=client.generation_value,
    slot_count=8,
    batch_config=config,
)

num_steps = 1000  # the dataset is infinite, so stop explicitly
try:
    for step, batch in enumerate(dataset):
        if step >= num_steps:
            break
        x, y = batch["x"], batch["y"]
        # Training loop...
finally:
    client.close()
```

## BatchConfig API

`BatchConfig` describes tensors and handles encoding/decoding:

```python
import torch
from softs import BatchConfig, TensorSpec

# Define specs
config = BatchConfig([
    TensorSpec("hidden", (512, 768), "bfloat16"),
    TensorSpec("labels", (512,), "int64"),
])

# Total bytes: 512*768*2 (bfloat16) + 512*8 (int64)
config.nbytes()  # -> 790528

# Encode tensors to bytes
hidden_tensor = torch.randn(512, 768, dtype=torch.bfloat16)
label_tensor = torch.randint(0, 1000, (512,), dtype=torch.int64)
data = config.encode(hidden=hidden_tensor, labels=label_tensor)

# Decode bytes to dict of tensors
tensors = config.decode(data)

# Decode a single tensor
hidden = config.decode_single(data, "hidden")

# Properties
config.tensor_names        # ['hidden', 'labels']
config.get_spec("hidden")  # TensorSpec object
```

**Supported dtypes**: `float64`, `float32`, `float16`, `bfloat16`, `int64`, `int32`, `int16`, `int8`, `uint8`, `bool`
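The `790528` above is the sum of each tensor's element count times its dtype width. The sketch below reproduces that arithmetic in plain Python and also derives per-tensor byte offsets. It assumes, without confirmation from the softs source, that `encode` packs tensors back to back in declaration order with no padding; `DTYPE_SIZES` and `layout` are hypothetical names, and the widths are the standard element sizes (with `bool` stored as one byte, as in PyTorch).

```python
from math import prod

# Element width in bytes for each supported dtype.
DTYPE_SIZES = {
    "float64": 8, "float32": 4, "float16": 2, "bfloat16": 2,
    "int64": 8, "int32": 4, "int16": 2, "int8": 1, "uint8": 1, "bool": 1,
}


def layout(
    specs: list[tuple[str, tuple[int, ...], str]],
) -> tuple[dict[str, tuple[int, int]], int]:
    """Return {name: (offset, nbytes)} and the total size, assuming tight packing."""
    offsets: dict[str, tuple[int, int]] = {}
    cursor = 0
    for name, shape, dtype in specs:
        nbytes = prod(shape) * DTYPE_SIZES[dtype]
        offsets[name] = (cursor, nbytes)
        cursor += nbytes
    return offsets, cursor


offsets, total = layout([
    ("hidden", (512, 768), "bfloat16"),
    ("labels", (512,), "int64"),
])
print(offsets)  # {'hidden': (0, 786432), 'labels': (786432, 4096)}
print(total)    # 790528
```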

### Loading from YAML

```python
from softs import BatchConfig

config = BatchConfig.from_yaml("config.yaml")

# Or from dict
config = BatchConfig.from_dict({
    "specs": [
        {"name": "x", "shape": [512, 768], "dtype": "bfloat16"},
        {"name": "y", "shape": [512, 768], "dtype": "bfloat16"},
    ]
})
```

### Hydra Integration

```yaml
# config.yaml
batch_config:
  _target_: softs.BatchConfig
  specs:
    - name: x
      shape: [512, 768]
      dtype: bfloat16
```

```python
from hydra.utils import instantiate
# `cfg` is the DictConfig that Hydra passes to your @hydra.main entry point
config = instantiate(cfg.batch_config)
```

## Transfer Mediums

By default, softs uses **POSIX shared memory** for zero-copy data transfer. The architecture supports other mediums through the `Medium` protocol.

### How Mediums Work

1. **Students** create and own the medium (e.g., a shared memory segment)
2. **Broker** routes opaque addressing info (`shm_name`, `slot_offset`) to workers
3. **Workers** write directly to the medium using the addressing info
4. **Students** read from the medium after receiving the completion notification

The broker never touches the actual data; it only routes metadata.
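The four steps can be reproduced with Python's standard `multiprocessing.shared_memory` module. This is not the `SharedMemoryManager` implementation; it only illustrates that the `(shm_name, slot_offset)` pair is all a writer needs. The slot sizes are arbitrary, both sides run in one process for brevity, and `worker_write` is a hypothetical name.

```python
from multiprocessing import shared_memory

SLOT_COUNT = 4
SLOT_STRIDE = 16  # bytes reserved per slot


def worker_write(shm_name: str, slot_offset: int, payload: bytes) -> None:
    # Step 3: attach by name and write at the routed offset.
    shm = shared_memory.SharedMemory(name=shm_name)
    try:
        shm.buf[slot_offset:slot_offset + len(payload)] = payload
    finally:
        shm.close()  # detach only; the student still owns the segment


# Step 1: the student creates and owns the segment.
owner = shared_memory.SharedMemory(create=True, size=SLOT_COUNT * SLOT_STRIDE)
try:
    # Step 2: only this addressing info would travel through the broker.
    slot_id = 2
    address = {"shm_name": owner.name, "slot_offset": slot_id * SLOT_STRIDE}

    worker_write(address["shm_name"], address["slot_offset"], b"sample-for-slot2")

    # Step 4: after COMPLETE, the student reads its slot.
    start = address["slot_offset"]
    data = bytes(owner.buf[start:start + SLOT_STRIDE])
    print(data)  # b'sample-for-slot2'
finally:
    owner.close()
    owner.unlink()  # the owner is responsible for removing the segment
```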

### Default: SharedMemoryManager

```python
from softs.mediums import SharedMemoryManager

# Student side: creates and owns the segment, and only reads from it (read_only=True)
shm = SharedMemoryManager(slot_count=16, slot_stride=1024, read_only=True)

# Worker side: attaches to the existing segment by name and writes into it (read_only=False)
shm = SharedMemoryManager(slot_count=16, slot_stride=1024, read_only=False, run_id=run_id)
```

### Custom Mediums

Implement the `Medium` protocol or extend `MediumBase`:

```python
import os

from softs.mediums import MediumBase


class FileMedium(MediumBase):
    """File-based medium (example for network filesystems)."""

    def __init__(self, slot_count: int, slot_stride: int, path: str):
        self._path = path
        self._slot_count = slot_count
        self._slot_stride = slot_stride
        # Create (or truncate) the backing file and size it for all slots
        self._file = open(path, "w+b")
        self._file.truncate(slot_count * slot_stride)

    @property
    def buf_name(self) -> str:
        return self._path

    @property
    def slot_count(self) -> int:
        return self._slot_count

    @property
    def slot_stride(self) -> int:
        return self._slot_stride

    def read_slot_tensors(self, slot_id: int) -> bytes:
        # Slot i starts at byte i * slot_stride
        offset = slot_id * self._slot_stride
        self._file.seek(offset)
        return self._file.read(self._slot_stride)

    def close(self) -> None:
        self._file.close()

    def unlink(self) -> None:
        os.unlink(self._path)
```

Potential medium implementations:
- **GPU Direct**: Use CUDA IPC for GPU-to-GPU transfer
- **Network**: Use RDMA or TCP for multi-node setups
- **Memory-mapped files**: For persistence or network filesystems
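As an illustration of the memory-mapped-file option, the sketch below maps a fixed-size file with Python's `mmap` module and reads and writes slots at `slot_id * slot_stride`. It deliberately does not subclass `MediumBase`, since the full `Medium` protocol is not listed here; `MmapSlots` and its methods are hypothetical names.

```python
import mmap
import os
import tempfile


class MmapSlots:
    """Hypothetical fixed-size slot store backed by a memory-mapped file."""

    def __init__(self, path: str, slot_count: int, slot_stride: int, create: bool):
        self.slot_stride = slot_stride
        self._file = open(path, "w+b" if create else "r+b")
        if create:
            self._file.truncate(slot_count * slot_stride)  # size the file once
        # Shared, read/write mapping: writes are visible to every mapping of the file.
        self._map = mmap.mmap(self._file.fileno(), slot_count * slot_stride)

    def write_slot(self, slot_id: int, payload: bytes) -> None:
        start = slot_id * self.slot_stride
        self._map[start:start + len(payload)] = payload

    def read_slot(self, slot_id: int) -> bytes:
        start = slot_id * self.slot_stride
        return self._map[start:start + self.slot_stride]

    def close(self) -> None:
        self._map.close()
        self._file.close()


tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "slots.bin")
owner = MmapSlots(path, slot_count=4, slot_stride=8, create=True)    # student side
writer = MmapSlots(path, slot_count=4, slot_stride=8, create=False)  # worker side
writer.write_slot(1, b"ABCDEFGH")
got = owner.read_slot(1)
print(got)  # b'ABCDEFGH'
writer.close()
owner.close()
os.unlink(path)
os.rmdir(tmp_dir)
```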

## Running with DDP

```bash
# Terminal 1: Broker
python -c "from softs import Broker; Broker().run()"

# Terminal 2: Worker(s); you can run several
python worker.py

# Terminal 3: DDP students
torchrun --nproc_per_node=2 student.py
```

For multi-GPU students:
- Only rank 0 (the leader) calls `set_model()`
- The other ranks call `start_sub_listener()` on their own `StudentClient` to receive model changes via PUB/SUB

```python
import torch.distributed as dist
rank = dist.get_rank()  # requires dist.init_process_group() to have run
client = StudentClient(student_rank=rank, slot_count=16, batch_config=config)
client.hello()
if rank == 0:
    client.set_model("layer_0")
else:
    client.start_sub_listener()
```

## Example: Layer-by-Layer LLM Distillation

See `examples/distill_llm.py` for a complete example that:
1. Loads a teacher LLM
2. Distills layer-by-layer (switches model per layer)
3. Uses Hydra for configuration
4. Supports DDP training

```bash
# Start broker
python distill_llm.py mode=broker

# Start worker (loads teacher model)
python distill_llm.py mode=worker device.worker_gpu=0

# Start student training (DDP)
torchrun --nproc_per_node=2 distill_llm.py mode=student
```

## API Reference

### setup_logging

```python
setup_logging(level: int | str = "INFO") -> None
```

Configure logging for all softs modules.

### Broker

```python
Broker(
    frontend_endpoint: str = "ipc:///tmp/softs_frontend.sock",
    backend_endpoint: str = "ipc:///tmp/softs_backend.sock",
    control_endpoint: str = "ipc:///tmp/softs_control.sock",
    control_pub_endpoint: str = "ipc:///tmp/softs_control_pub.sock",
)

broker.run()    # Blocking
broker.start()  # Non-blocking (background thread)
broker.stop()
broker.stats    # BrokerStats with metrics
```

### Worker

```python
Worker(
    generator_fn: Callable[[str, dict | None], bytes],  # (model_id, model_cfg) -> bytes
    slot_size: int,  # Expected bytes per sample
    backend_endpoint: str = ...,
    control_pub_endpoint: str = ...,
    worker_id: int | None = None,  # Defaults to PID
)

worker.run()       # Blocking
worker.start()     # Non-blocking
worker.stop()
worker.generation  # Current generation counter
worker.model_id    # Current model ID
```

### StudentClient

```python
StudentClient(
    student_rank: int,  # 0 = leader
    slot_count: int,    # Shared memory slots
    batch_config: BatchConfig,
    frontend_endpoint: str = ...,
    control_endpoint: str = ...,
    control_pub_endpoint: str = ...,
)

client.hello() -> dict  # Register with broker
client.set_model(model_id, model_cfg=None) -> int  # Set model (leader only); returns generation
client.request_sample(timeout_ms=1000) -> SampleRef | None
client.release_slot(slot_id)  # Return slot to pool
client.start_sub_listener()   # Listen for model changes (non-leader)
client.generation             # Current generation
client.generation_value       # multiprocessing.Value for sharing with dataset
client.close()
```

### DistillIterableDataset

```python
DistillIterableDataset(
    student_rank: int,
    generation_value: Value,  # From client.generation_value
    slot_count: int,
    batch_config: BatchConfig,
    frontend_endpoint: str = ...,
    max_retries: int = 10,
    retry_delay: float = 0.01,
)
```

Infinite `IterableDataset` yielding `dict[str, Tensor]`.
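`request_sample` is typed to return `SampleRef | None`, and the dataset exposes `max_retries` and `retry_delay`. Their exact semantics are not documented here, so the following is only a sketch under the assumption that the dataset retries up to `max_retries` empty polls, sleeping `retry_delay` seconds between them, before giving up; `poll_with_retries` and the fake `request` callable are hypothetical.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def poll_with_retries(
    request: Callable[[], Optional[T]],
    max_retries: int = 10,
    retry_delay: float = 0.01,
) -> T:
    """Call `request` until it returns a value, or raise after max_retries misses."""
    for _ in range(max_retries):
        result = request()
        if result is not None:
            return result
        time.sleep(retry_delay)  # back off before polling again
    raise TimeoutError(f"no sample after {max_retries} attempts")


# A fake request that times out twice, then yields a sample reference.
replies = iter([None, None, "slot-3"])
print(poll_with_retries(lambda: next(replies), retry_delay=0.0))  # slot-3
```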

### BatchConfig / TensorSpec

```python
TensorSpec(name: str, shape: tuple[int, ...], dtype: str)
spec.nbytes       # Bytes for this tensor
spec.torch_dtype  # torch.dtype

BatchConfig(specs: list[TensorSpec])
config.nbytes() -> int
config.encode(**tensors) -> bytes
config.decode(data: bytes) -> dict[str, Tensor]
config.decode_single(data: bytes, name: str) -> Tensor
config.tensor_names -> list[str]
config.get_spec(name) -> TensorSpec
```

## Protocol Details

The broker uses ZeroMQ with four sockets:

| Socket | Type | Purpose |
|--------|------|---------|
| Frontend | ROUTER | Student requests (REQUEST, HELLO, STATS) |
| Backend | ROUTER | Worker communication (READY, WORK, DONE) |
| Control | ROUTER | Leader commands (SET_MODEL, STOP) |
| ControlPub | PUB | Broadcasts (MODEL changes, STOP) |

Commands:
- `HELLO`: Register a student or worker
- `REQUEST`: Student requests that a sample slot be filled
- `READY`: Worker is available for work
- `WORK`: Broker assigns work to a worker
- `DONE`: Worker has finished writing to a slot
- `COMPLETE`: Broker notifies the student that a slot is ready
- `SET_MODEL`: Leader sets a new model
- `STOP`: Shut down workers

## Troubleshooting

### "Shared memory name too long" (macOS)

macOS limits shared memory names to 31 characters, so the library uses a short prefix (`sl_`) to keep names within that limit.
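If you build your own segment names, one common way to stay under such a limit is to derive a fixed-length name from a hash of whatever identifies the segment. This is a hypothetical sketch; softs's actual naming scheme is not documented here beyond the `sl_` prefix. It caps names at 30 characters, assuming the leading `/` that POSIX shared memory names carry counts toward the 31-character limit.

```python
import hashlib

MAX_NAME_LEN = 30  # assumption: leaves room for the leading "/" within 31 characters
PREFIX = "sl_"


def short_shm_name(run_id: str, rank: int) -> str:
    """Hypothetical: derive a stable, short segment name from a long identifier."""
    digest = hashlib.sha1(f"{run_id}:{rank}".encode()).hexdigest()  # 40 hex chars
    name = PREFIX + digest[: MAX_NAME_LEN - len(PREFIX)]
    assert len(name) <= MAX_NAME_LEN
    return name


long_run_id = "distill-llama-layerwise-experiment-with-a-very-long-name"
print(short_shm_name(long_run_id, rank=0))
print(len(short_shm_name(long_run_id, rank=0)))  # 30
```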

### Stale samples after model switch

The generation counter ensures stale samples are discarded. If you see stale data, ensure that:
1. The dataset checks `generation_value` before yielding
2. You're using `client.generation_value` (shared with the dataset)
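The check in item 1 can be pictured with a `multiprocessing.Value`, the kind of object `client.generation_value` is documented to be. In this hypothetical sketch each sample carries the generation it was produced under, and the consumer drops any sample whose tag no longer matches the shared counter; the `(generation, payload)` tuples and `fresh_samples` are assumptions, not the softs wire format.

```python
from multiprocessing import Value
from typing import Iterable, Iterator


def fresh_samples(samples: Iterable[tuple[int, bytes]], generation_value) -> Iterator[bytes]:
    """Yield only payloads produced under the current generation."""
    for generation, payload in samples:
        with generation_value.get_lock():
            current = generation_value.value
        if generation == current:
            yield payload
        # Otherwise the sample belongs to a previous model and is dropped.


generation_value = Value("i", 1)  # shared int, readable from other processes
incoming = [(1, b"a"), (1, b"b"), (2, b"c")]

stream = fresh_samples(incoming, generation_value)
kept = [next(stream)]       # b"a" is consumed under generation 1
generation_value.value = 2  # the leader switches models
kept.extend(stream)         # b"b" is now stale; b"c" is kept
print(kept)                 # [b'a', b'c']
```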

### Worker not receiving work

Check that:
1. The broker is running
2. The worker has registered with the broker (`HELLO`) and is in its main loop (`worker.run()` or `worker.start()`)
3. The leader student has called `set_model()` (workers wait for a model)

### Memory not released

Call `client.close()` to unlink shared memory properly, or use `StudentClient` as a context manager:

```python
with StudentClient(...) as client:
    ...  # training
# Automatically closes and unlinks
```
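The cleanup guarantee comes from the context-manager protocol: `__exit__` runs even when the body raises. The sketch below shows that pattern with the standard library's `SharedMemory`; `OwnedSegment` is a hypothetical stand-in, not the `StudentClient` implementation.

```python
from multiprocessing import shared_memory


class OwnedSegment:
    """Hypothetical owner of a shared memory segment that always cleans up."""

    def __init__(self, size: int):
        self.shm = shared_memory.SharedMemory(create=True, size=size)

    def __enter__(self) -> "OwnedSegment":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.shm.close()   # release this process's mapping
        self.shm.unlink()  # remove the segment so it does not outlive the run
        # Returning None lets any exception from the body propagate.


try:
    with OwnedSegment(size=64) as seg:
        name = seg.shm.name
        raise RuntimeError("training crashed")
except RuntimeError:
    pass

# The segment is gone even though the body raised.
try:
    shared_memory.SharedMemory(name=name)
    still_exists = True
except FileNotFoundError:
    still_exists = False
print(still_exists)  # False
```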

## License

MIT