softs 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
softs/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ """Soft labels: async on-the-fly training data generation for PyTorch."""
2
+
3
+ import logging
4
+
5
+ __version__ = "0.6.0"
6
+
7
+
8
+ def setup_logging(level: int | str = logging.INFO) -> None:
9
+ if isinstance(level, str):
10
+ level = getattr(logging, level.upper(), logging.INFO)
11
+ logging.basicConfig(
12
+ level=level,
13
+ format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
14
+ datefmt="%H:%M:%S",
15
+ )
16
+
17
+
18
+ # User-facing
19
+ from .configs import BatchConfig, TensorSpec, make_xy_config
20
+ from .market import EndpointConfig
21
+ from .dataset import SoftIterableDataset, SoftDataLoader, Batch, make_collate_fn
22
+
23
+ # Market internals (for advanced use)
24
+ from .market import Broker, Supplier, Client, BrokerStats
25
+ from .market import ShmMedium, FilesystemMedium, TCPMedium
softs/configs.py ADDED
@@ -0,0 +1,196 @@
1
+ """Batch config, tensor specs, and encoding utilities."""
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+ import yaml
8
+
9
# Names of the supported dtypes; each one is also the attribute name of the
# matching dtype object on the ``torch`` module.
_DTYPE_MAP: dict[str, torch.dtype] = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in (
        "float64",
        "float32",
        "float16",
        "bfloat16",
        "int64",
        "int32",
        "int16",
        "int8",
        "uint8",
        "bool",
    )
}
# Reverse lookup: torch dtype -> its string name.
_TORCH_DTYPE_TO_STR: dict[torch.dtype, str] = dict(
    zip(_DTYPE_MAP.values(), _DTYPE_MAP.keys())
)
# Size in bytes of a single element of each supported dtype.
_DTYPE_SIZES: dict[torch.dtype, int] = {
    torch_dtype: torch.zeros((), dtype=torch_dtype).element_size()
    for torch_dtype in _TORCH_DTYPE_TO_STR
}
27
+
28
+
29
def dtype_to_torch(dtype_str: str) -> torch.dtype:
    """Translate a dtype name into the corresponding ``torch.dtype``.

    Raises:
        ValueError: If ``dtype_str`` is not a key of ``_DTYPE_MAP``.
    """
    resolved = _DTYPE_MAP.get(dtype_str)
    if resolved is None:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return resolved
34
+
35
+
36
def torch_dtype_to_str(dtype: torch.dtype) -> str:
    """Translate a ``torch.dtype`` into its string name.

    Raises:
        ValueError: If ``dtype`` is not a key of ``_TORCH_DTYPE_TO_STR``.
    """
    label = _TORCH_DTYPE_TO_STR.get(dtype)
    if label is None:
        raise ValueError(f"Unsupported torch dtype: {dtype}")
    return label
41
+
42
+
43
@dataclass
class TensorSpec:
    """Name, shape and dtype of one tensor, with raw-byte (de)serialization.

    ``shape`` is normalized to a tuple of ints on construction, so a spec
    built from a list (for example one loaded from YAML or produced by a
    ``to_dict`` round trip) compares correctly against
    ``tuple(tensor.shape)``.
    """

    name: str
    shape: tuple[int, ...]
    dtype: str

    def __post_init__(self) -> None:
        # Without this, a list shape never equals tuple(tensor.shape) and
        # encode() would reject every tensor.
        self.shape = tuple(int(dim) for dim in self.shape)

    @property
    def torch_dtype(self) -> torch.dtype:
        """The ``torch.dtype`` named by ``dtype``."""
        return dtype_to_torch(self.dtype)

    @property
    def numel(self) -> int:
        """Number of elements (product of ``shape``; 1 for an empty shape)."""
        return math.prod(self.shape)

    @property
    def nbytes(self) -> int:
        """Size in bytes of one encoded tensor of this spec."""
        return self.numel * _DTYPE_SIZES[self.torch_dtype]

    def encode(self, tensor: torch.Tensor) -> bytes:
        """Serialize ``tensor`` to its raw contiguous bytes.

        Raises:
            ValueError: If the tensor's shape or dtype differs from the spec.
        """
        if tuple(tensor.shape) != self.shape:
            raise ValueError(
                f"Shape mismatch: {tuple(tensor.shape)} != {self.shape}"
            )
        if tensor.dtype != self.torch_dtype:
            raise ValueError(
                f"Dtype mismatch: {tensor.dtype} != {self.torch_dtype}"
            )
        # .numpy() rejects tensors that require grad or live off-CPU.
        tensor = tensor.detach().cpu().contiguous()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16; reinterpret the same bits as 16-bit ints.
            tensor = tensor.view(torch.int16)
        return tensor.numpy().tobytes()

    def decode(self, data: bytes) -> torch.Tensor:
        """Rebuild a tensor of this spec from bytes produced by ``encode``.

        Raises:
            ValueError: If ``len(data)`` differs from ``nbytes``.
        """
        if len(data) != self.nbytes:
            raise ValueError(
                f"Data length {len(data)} != expected {self.nbytes}"
            )
        if self.nbytes == 0:
            # torch.frombuffer refuses empty buffers; there is nothing to read.
            return torch.empty(self.shape, dtype=self.torch_dtype)
        is_bf16 = self.dtype == "bfloat16"
        storage_dtype = torch.int16 if is_bf16 else self.torch_dtype
        # bytearray gives frombuffer a writable buffer.
        flat = torch.frombuffer(bytearray(data), dtype=storage_dtype)
        if is_bf16:
            flat = flat.view(torch.bfloat16)
        # clone() so the result owns regular torch storage.
        return flat.reshape(self.shape).clone()
87
+
88
+
89
@dataclass
class BatchConfig:
    """Ordered list of tensor specs describing one flat byte record.

    Tensors are laid out back to back in ``specs`` order; the byte offset of
    each tensor is computed once at construction. Entries of ``specs`` may be
    ``TensorSpec`` instances or dict-like objects (anything with ``items()``)
    holding the ``TensorSpec`` fields; keys starting with ``_`` are ignored.
    """

    specs: list[TensorSpec]

    def __post_init__(self):
        self.specs = [self._coerce_spec(spec) for spec in self.specs]

        names = [s.name for s in self.specs]
        if len(names) != len(set(names)):
            raise ValueError(f"Duplicate tensor names: {names}")

        self._name_to_idx = {s.name: i for i, s in enumerate(self.specs)}
        self._offsets = []
        off = 0
        for s in self.specs:
            self._offsets.append(off)
            off += s.nbytes

    @staticmethod
    def _coerce_spec(spec) -> TensorSpec:
        """Return ``spec`` as a ``TensorSpec``, converting dict-likes."""
        if isinstance(spec, TensorSpec):
            return spec
        if hasattr(spec, "items"):
            fields = {
                k: v for k, v in spec.items() if not str(k).startswith("_")
            }
            shape = fields.get("shape")
            # Config files yield list shapes; the spec compares its shape
            # against tuple(tensor.shape), so hand it a tuple.
            if hasattr(shape, "__iter__") and not isinstance(shape, (tuple, str)):
                fields["shape"] = tuple(shape)
            return TensorSpec(**fields)
        raise TypeError(f"Expected TensorSpec or dict-like, got {type(spec)}")

    @property
    def tensor_names(self) -> list[str]:
        """Tensor names in layout order."""
        return [s.name for s in self.specs]

    def nbytes(self) -> int:
        """Total size in bytes of one encoded record."""
        return sum(s.nbytes for s in self.specs)

    def get_spec(self, name: str) -> TensorSpec:
        """Return the spec called ``name``; raise ``KeyError`` if unknown."""
        idx = self._name_to_idx.get(name)
        if idx is None:
            raise KeyError(
                f"Unknown tensor '{name}'. Available: {self.tensor_names}"
            )
        return self.specs[idx]

    def get_offset(self, name: str) -> int:
        """Return the byte offset of tensor ``name`` within a record.

        Raises:
            KeyError: If ``name`` is not a known tensor.
        """
        idx = self._name_to_idx.get(name)
        if idx is None:
            raise KeyError(f"Unknown tensor '{name}'")
        return self._offsets[idx]

    def encode(self, **tensors: torch.Tensor) -> bytes:
        """Encode exactly one tensor per spec into a single byte record.

        Raises:
            ValueError: If the provided names differ from the spec names, or
                an individual tensor does not match its spec.
        """
        provided = set(tensors.keys())
        expected = set(self.tensor_names)
        if provided != expected:
            raise ValueError(
                f"Tensor mismatch: missing={expected - provided}, extra={provided - expected}"
            )
        return b"".join(spec.encode(tensors[spec.name]) for spec in self.specs)

    def decode(self, data: bytes) -> dict[str, torch.Tensor]:
        """Decode a full record into a ``{name: tensor}`` dict.

        Raises:
            ValueError: If ``len(data)`` differs from ``nbytes()``.
        """
        if len(data) != self.nbytes():
            raise ValueError(
                f"Data length {len(data)} != expected {self.nbytes()}"
            )
        # Slice a memoryview so each tensor's bytes are copied once (inside
        # the spec's decode) instead of once per slice and again there.
        view = memoryview(data)
        return {
            spec.name: spec.decode(view[off : off + spec.nbytes])
            for spec, off in zip(self.specs, self._offsets)
        }

    def decode_single(self, data: bytes, name: str) -> torch.Tensor:
        """Decode only tensor ``name`` out of a record."""
        spec = self.get_spec(name)
        off = self.get_offset(name)
        return spec.decode(memoryview(data)[off : off + spec.nbytes])

    @classmethod
    def from_dict(cls, config: dict) -> "BatchConfig":
        """Build a config from a mapping with an optional ``"specs"`` list.

        ``None`` (what ``yaml.safe_load`` returns for an empty document) and
        a missing or null ``"specs"`` entry both yield a config with no specs.
        """
        return cls((config or {}).get("specs") or [])

    @classmethod
    def from_yaml(cls, path: str) -> "BatchConfig":
        """Load a config from the YAML file at ``path``."""
        # Explicit encoding: the platform default is not always UTF-8.
        with open(path, encoding="utf-8") as f:
            return cls.from_dict(yaml.safe_load(f))

    def to_dict(self) -> dict:
        """Return a plain-dict form that ``from_dict`` accepts."""
        return {
            "specs": [
                {"name": s.name, "shape": list(s.shape), "dtype": s.dtype}
                for s in self.specs
            ]
        }
186
+
187
+
188
def make_xy_config(
    x_shape: tuple[int, ...],
    x_dtype: str,
    y_shape: tuple[int, ...],
    y_dtype: str,
) -> BatchConfig:
    """Build a two-tensor config whose tensors are named ``"x"`` and ``"y"``."""
    x_spec = TensorSpec(name="x", shape=x_shape, dtype=x_dtype)
    y_spec = TensorSpec(name="y", shape=y_shape, dtype=y_dtype)
    return BatchConfig([x_spec, y_spec])
softs/dataset.py ADDED
@@ -0,0 +1,201 @@
1
+ """PyTorch Dataset and DataLoader for soft label generation."""
2
+
3
+ import ctypes
4
+ import multiprocessing
5
+ import time
6
+ from typing import Callable, Iterator
7
+
8
+ import torch
9
+ from torch.utils.data import DataLoader, IterableDataset
10
+
11
+ from .configs import BatchConfig
12
+ from .market import Client
13
+
14
+
15
class _SharedStr:
    """Mutable string held in ``multiprocessing`` shared memory.

    The value is stored as UTF-8 in a fixed-size raw byte array plus a raw
    length counter. Values longer than ``max_len`` bytes are truncated.

    NOTE(review): the buffer and the length are raw (lock-free) objects and
    are written in two separate steps, so a reader running concurrently with
    ``set`` may observe a mix of old and new state - confirm this is
    acceptable for callers.
    """

    def __init__(self, value: str = "", max_len: int = 256):
        self._buf = multiprocessing.RawArray(ctypes.c_char, max_len)
        self._len = multiprocessing.RawValue(ctypes.c_int, 0)
        self._max = max_len
        if value:
            self.set(value)

    def set(self, value: str) -> None:
        """Store ``value``, truncated to at most ``max_len`` UTF-8 bytes."""
        encoded = value.encode()
        if len(encoded) > self._max:
            # Cutting raw UTF-8 bytes can split a multi-byte character and
            # make get() raise UnicodeDecodeError; drop the partial tail so
            # the stored bytes always end on a character boundary.
            encoded = (
                encoded[: self._max].decode("utf-8", errors="ignore").encode()
            )
        self._buf[: len(encoded)] = encoded
        self._len.value = len(encoded)

    def get(self) -> str:
        """Return the currently stored string."""
        return bytes(self._buf[: self._len.value]).decode()
32
+
33
+
34
class SoftIterableDataset(IterableDataset[dict[str, torch.Tensor]]):
    """Infinite dataset yielding decoded tensor dicts.

    Call ``set_model(model_id)`` to switch models at any time. The model id
    lives in shared memory; each iterator re-reads it before every request
    and calls ``client.discard()`` when it has changed.

    Args:
        model_id: Initial model id passed to ``client.request_sample``.
        endpoint: Forwarded to ``Client``.
        batch_config: Layout used to decode each slot; its ``nbytes()`` is
            used as the client's ``slot_size``.
        medium_cls: Forwarded to ``Client``.
        num_slots: Forwarded to ``Client``.
        max_retries: Number of consecutive empty responses after which the
            iterator sleeps before each further attempt.
        retry_delay: Seconds used both as the request timeout and as the
            sleep between attempts once ``max_retries`` is reached.
    """

    def __init__(
        self,
        model_id: str,
        endpoint: str,
        batch_config: BatchConfig,
        medium_cls,
        num_slots: int = 8,
        max_retries: int = 10,
        retry_delay: float = 0.01,
    ):
        self._model = _SharedStr(model_id)
        self.endpoint = endpoint
        self.num_slots = num_slots
        self.batch_config = batch_config
        self.medium_cls = medium_cls
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._client: Client | None = None

    @property
    def model_id(self) -> str:
        """The model id currently requested."""
        return self._model.get()

    def set_model(self, model_id: str) -> None:
        """Switch the model id that iterators request samples for."""
        self._model.set(model_id)

    def _ensure_client(self) -> Client:
        """Create the client (and send ``hello``) on first use, then reuse it."""
        if self._client is None:
            self._client = Client(
                endpoint=self.endpoint,
                slot_size=self.batch_config.nbytes(),
                medium_cls=self.medium_cls,
                num_slots=self.num_slots,
            )
            self._client.hello()
        return self._client

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        client = self._ensure_client()
        retries = 0
        current = self.model_id
        while True:
            wanted = self.model_id
            if wanted != current:
                # The model was switched: drop pending work for the old one.
                client.discard()
                current = wanted
                retries = 0
            slot_id = client.request_sample(
                current, timeout_ms=int(self.retry_delay * 1000)
            )
            if slot_id is None:
                retries += 1
                # Back off only after max_retries consecutive empty replies.
                if retries >= self.max_retries:
                    time.sleep(self.retry_delay)
                continue
            retries = 0
            try:
                tensors = self.batch_config.decode(client.medium.read(slot_id))
            finally:
                # Release even when read/decode raises, so a bad sample does
                # not permanently consume one of the client's slots.
                client.release_slot(slot_id)
            yield tensors

    def __del__(self):
        # getattr: __del__ also runs when __init__ raised before _client was
        # assigned.
        client = getattr(self, "_client", None)
        if client is not None:
            self._client = None
            client.close()
104
+
105
+
106
class SoftDataLoader(DataLoader):
    """DataLoader with model switching support.

    Builds a ``SoftIterableDataset`` from the named arguments and forwards
    every remaining keyword argument to ``torch.utils.data.DataLoader``.

    Usage::

        loader = SoftDataLoader(model_id="teacher_v1", endpoint=ep,
                                batch_config=config, medium_cls=ShmMedium,
                                num_slots=8, batch_size=4)
        for batch in loader:
            train(batch)

        loader.set_model("teacher_v2")  # all workers switch automatically
    """

    def __init__(
        self,
        model_id: str,
        endpoint: str,
        batch_config: BatchConfig,
        medium_cls,
        num_slots: int = 8,
        max_retries: int = 10,
        retry_delay: float = 0.01,
        **dataloader_kwargs,
    ):
        dataset = SoftIterableDataset(
            model_id=model_id,
            endpoint=endpoint,
            batch_config=batch_config,
            medium_cls=medium_cls,
            num_slots=num_slots,
            max_retries=max_retries,
            retry_delay=retry_delay,
        )
        # DataLoader.__init__ stores the dataset as ``self.dataset``.
        super().__init__(dataset, **dataloader_kwargs)

    def set_model(self, model_id: str) -> None:
        """Switch the model id requested by the underlying dataset."""
        self.dataset.set_model(model_id)

    @property
    def model_id(self) -> str:
        """The model id currently requested by the underlying dataset."""
        return self.dataset.model_id
148
+
149
+
150
class Batch:
    """Dict-like bundle of batched tensors plus the slots they came from.

    When a client is attached, ``release()`` (also invoked from ``__del__``)
    hands every slot id back to it exactly once.
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        slot_ids: list[int] | None = None,
        client: Client | None = None,
    ):
        self._released = False
        self._client = client
        self.slot_ids = slot_ids if slot_ids else []
        self.tensors = tensors

    def __getitem__(self, key: str) -> torch.Tensor:
        """Return the tensor stored under ``key``."""
        return self.tensors[key]

    def __contains__(self, key: str) -> bool:
        """Report whether a tensor named ``key`` is present."""
        return key in self.tensors

    def keys(self):
        """Return the view of tensor names."""
        return self.tensors.keys()

    def release(self) -> None:
        """Release all slots through the client; no-op if already done or unowned."""
        owns_slots = self._client is not None and not self._released
        if owns_slots:
            for sid in self.slot_ids:
                self._client.release_slot(sid)
            self._released = True

    def __del__(self):
        self.release()
180
+
181
+
182
def make_collate_fn(
    client: Client, batch_config: BatchConfig, auto_release: bool = True
) -> Callable[[list[int]], Batch]:
    """Build a collate function that turns slot ids into a stacked ``Batch``.

    Args:
        client: Client whose medium is read and whose slots are released.
        batch_config: Layout used to decode each slot's bytes.
        auto_release: If true, slots are released as soon as the batch has
            been assembled and the returned ``Batch`` carries no client.
            Otherwise the ``Batch`` keeps the client and releases the slots
            itself.

    Returns:
        A function mapping a list of slot ids to a ``Batch`` whose tensors
        are the per-slot tensors stacked along a new leading dimension.
    """

    def release_all(slot_ids: list[int]) -> None:
        for slot_id in slot_ids:
            client.release_slot(slot_id)

    def collate_fn(slot_ids: list[int]) -> Batch:
        try:
            decoded = [
                batch_config.decode(client.medium.read(slot_id))
                for slot_id in slot_ids
            ]
            batched = {
                name: torch.stack([sample[name] for sample in decoded])
                for name in batch_config.tensor_names
            }
        except Exception:
            # No Batch exists yet to own the slots, so nobody else would
            # ever release them; hand them back before propagating.
            release_all(slot_ids)
            raise
        if auto_release:
            release_all(slot_ids)
            return Batch(batched, slot_ids, None)
        return Batch(batched, slot_ids, client)

    return collate_fn
softs/market/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Marketplace: broker, supplier, client, and transfer mediums."""
2
+
3
+ from .broker import Broker
4
+ from .supplier import Supplier
5
+ from .client import Client
6
+ from .protocol import EndpointConfig, BrokerStats
7
+ from .mediums import Medium, ShmMedium, FilesystemMedium, TCPMedium