softs 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- softs/__init__.py +25 -0
- softs/configs.py +196 -0
- softs/dataset.py +201 -0
- softs/market/__init__.py +7 -0
- softs/market/broker.py +446 -0
- softs/market/client.py +241 -0
- softs/market/mediums/__init__.py +8 -0
- softs/market/mediums/base.py +48 -0
- softs/market/mediums/filesystem.py +76 -0
- softs/market/mediums/shm.py +50 -0
- softs/market/mediums/tcp.py +135 -0
- softs/market/protocol.py +146 -0
- softs/market/supplier.py +159 -0
- softs-0.1.0.dist-info/METADATA +502 -0
- softs-0.1.0.dist-info/RECORD +16 -0
- softs-0.1.0.dist-info/WHEEL +4 -0
softs/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Soft labels: async on-the-fly training data generation for PyTorch."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
# Must match the distribution version (the published wheel is softs-0.1.0).
__version__ = "0.1.0"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def setup_logging(level: int | str = logging.INFO) -> None:
|
|
9
|
+
if isinstance(level, str):
|
|
10
|
+
level = getattr(logging, level.upper(), logging.INFO)
|
|
11
|
+
logging.basicConfig(
|
|
12
|
+
level=level,
|
|
13
|
+
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
|
14
|
+
datefmt="%H:%M:%S",
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# User-facing
|
|
19
|
+
from .configs import BatchConfig, TensorSpec, make_xy_config
|
|
20
|
+
from .market import EndpointConfig
|
|
21
|
+
from .dataset import SoftIterableDataset, SoftDataLoader, Batch, make_collate_fn
|
|
22
|
+
|
|
23
|
+
# Market internals (for advanced use)
|
|
24
|
+
from .market import Broker, Supplier, Client, BrokerStats
|
|
25
|
+
from .market import ShmMedium, FilesystemMedium, TCPMedium
|
softs/configs.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""Batch config, tensor specs, and encoding utilities."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
# Canonical dtype names accepted in configs, mapped to their torch dtypes.
_DTYPE_MAP: dict[str, torch.dtype] = {
    "float64": torch.float64,
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "int64": torch.int64,
    "int32": torch.int32,
    "int16": torch.int16,
    "int8": torch.int8,
    "uint8": torch.uint8,
    "bool": torch.bool,
}
# Inverse lookup (torch dtype -> canonical name); the forward map is 1:1.
_TORCH_DTYPE_TO_STR: dict[torch.dtype, str] = dict(
    zip(_DTYPE_MAP.values(), _DTYPE_MAP.keys())
)
# Bytes per element for every supported dtype, measured on a 1-element tensor.
_DTYPE_SIZES: dict[torch.dtype, int] = {
    torch_dt: torch.zeros(1, dtype=torch_dt).element_size()
    for torch_dt in _DTYPE_MAP.values()
}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def dtype_to_torch(dtype_str: str) -> torch.dtype:
    """Return the ``torch.dtype`` for a canonical dtype name like ``"float32"``.

    Raises:
        ValueError: if ``dtype_str`` is not a supported dtype name.
    """
    resolved = _DTYPE_MAP.get(dtype_str)
    if resolved is None:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return resolved
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def torch_dtype_to_str(dtype: torch.dtype) -> str:
    """Return the canonical name for a supported ``torch.dtype``.

    Raises:
        ValueError: if ``dtype`` has no canonical name in this module.
    """
    name = _TORCH_DTYPE_TO_STR.get(dtype)
    if name is None:
        raise ValueError(f"Unsupported torch dtype: {dtype}")
    return name
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class TensorSpec:
    """Fixed-shape, fixed-dtype description of one tensor in a sample payload.

    Attributes:
        name: key under which the tensor appears in encoded/decoded dicts.
        shape: exact tensor shape. Any sequence of ints is accepted and is
            normalised to a ``tuple`` in ``__post_init__``.
        dtype: canonical dtype name, e.g. ``"float32"`` or ``"bfloat16"``.
    """

    name: str
    shape: tuple[int, ...]
    dtype: str

    def __post_init__(self) -> None:
        # YAML/JSON configs (and to_dict round-trips) deliver shapes as lists.
        # encode() compares against tuple(tensor.shape), so a list shape would
        # otherwise make every encode fail with "Shape mismatch".
        self.shape = tuple(int(d) for d in self.shape)

    @property
    def torch_dtype(self) -> torch.dtype:
        """The ``torch.dtype`` corresponding to ``dtype``."""
        return dtype_to_torch(self.dtype)

    @property
    def numel(self) -> int:
        """Number of elements in a tensor of this shape."""
        return math.prod(self.shape)

    @property
    def nbytes(self) -> int:
        """Size in bytes of the raw encoded tensor."""
        return self.numel * _DTYPE_SIZES[self.torch_dtype]

    def encode(self, tensor: torch.Tensor) -> bytes:
        """Serialise ``tensor`` to its raw element bytes.

        The tensor must match ``shape`` and ``dtype`` exactly. It is detached
        and copied to CPU first, so tensors that require grad or live on an
        accelerator can be encoded directly.

        Raises:
            ValueError: on shape or dtype mismatch.
        """
        if tuple(tensor.shape) != self.shape:
            raise ValueError(
                f"Shape mismatch: {tuple(tensor.shape)} != {self.shape}"
            )
        if tensor.dtype != self.torch_dtype:
            raise ValueError(
                f"Dtype mismatch: {tensor.dtype} != {self.torch_dtype}"
            )
        # Tensor.numpy() refuses tensors that require grad or are not on CPU.
        tensor = tensor.detach().cpu().contiguous()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16: reinterpret the 16-bit payload as int16.
            # Same bytes as a uint16 view, but int16 exists in every torch.
            tensor = tensor.view(torch.int16)
        return tensor.numpy().tobytes()

    def decode(self, data: bytes) -> torch.Tensor:
        """Rebuild a tensor from bytes produced by ``encode``.

        The returned tensor owns its memory and does not alias ``data``.

        Raises:
            ValueError: if ``len(data)`` differs from ``nbytes``.
        """
        if len(data) != self.nbytes:
            raise ValueError(
                f"Data length {len(data)} != expected {self.nbytes}"
            )
        if self.nbytes == 0:
            # torch.frombuffer rejects empty buffers, but zero-size shapes
            # such as (0, 3) are legitimate specs.
            return torch.empty(self.shape, dtype=self.torch_dtype)
        is_bf16 = self.dtype == "bfloat16"
        raw_dtype = torch.int16 if is_bf16 else self.torch_dtype
        tensor = torch.frombuffer(bytearray(data), dtype=raw_dtype)
        if is_bf16:
            tensor = tensor.view(torch.bfloat16)
        return tensor.reshape(self.shape).clone()
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@dataclass
class BatchConfig:
    """Ordered set of tensor specs describing one sample's byte layout.

    Tensors are packed back-to-back in ``specs`` order; ``encode``/``decode``
    convert between that flat layout and a ``{name: tensor}`` dict.

    ``specs`` entries may be ``TensorSpec`` instances or dict-like mappings
    (e.g. parsed YAML); mapping keys beginning with ``_`` are ignored.
    Shapes are normalised to tuples so configs loaded from files can encode.

    Raises:
        TypeError: if a spec is neither a TensorSpec nor dict-like.
        ValueError: if two specs share a name.
    """

    specs: list[TensorSpec]

    def __post_init__(self):
        converted: list[TensorSpec] = []
        for spec in self.specs:
            if not isinstance(spec, TensorSpec):
                if not hasattr(spec, "items"):
                    raise TypeError(
                        f"Expected TensorSpec or dict-like, got {type(spec)}"
                    )
                spec = TensorSpec(
                    **{k: v for k, v in spec.items() if not str(k).startswith("_")}
                )
            if not isinstance(spec.shape, tuple):
                # YAML/JSON (and to_dict) represent shapes as lists, but
                # TensorSpec.encode compares against tuple(tensor.shape), so a
                # list shape would make every encode fail. Build a new spec
                # rather than mutating the caller's object.
                spec = TensorSpec(spec.name, tuple(spec.shape), spec.dtype)
            converted.append(spec)
        self.specs = converted

        names = [s.name for s in self.specs]
        if len(names) != len(set(names)):
            raise ValueError(f"Duplicate tensor names: {names}")

        # Precompute name -> position and each tensor's byte offset.
        self._name_to_idx = {s.name: i for i, s in enumerate(self.specs)}
        self._offsets = []
        off = 0
        for s in self.specs:
            self._offsets.append(off)
            off += s.nbytes

    @property
    def tensor_names(self) -> list[str]:
        """Tensor names in layout order."""
        return [s.name for s in self.specs]

    def nbytes(self) -> int:
        """Total payload size in bytes (sum of every spec's ``nbytes``)."""
        return sum(s.nbytes for s in self.specs)

    def get_spec(self, name: str) -> TensorSpec:
        """Return the spec called ``name``.

        Raises:
            KeyError: if no spec has that name.
        """
        idx = self._name_to_idx.get(name)
        if idx is None:
            raise KeyError(
                f"Unknown tensor '{name}'. Available: {self.tensor_names}"
            )
        return self.specs[idx]

    def get_offset(self, name: str) -> int:
        """Return the byte offset of tensor ``name`` within the payload.

        Raises:
            KeyError: if no spec has that name.
        """
        idx = self._name_to_idx.get(name)
        if idx is None:
            raise KeyError(f"Unknown tensor '{name}'")
        return self._offsets[idx]

    def encode(self, **tensors: torch.Tensor) -> bytes:
        """Pack exactly one tensor per spec (passed by name) into bytes.

        Raises:
            ValueError: if names are missing or extra, or a tensor fails its
                spec's shape/dtype check.
        """
        provided = set(tensors.keys())
        expected = set(self.tensor_names)
        if provided != expected:
            raise ValueError(
                f"Tensor mismatch: missing={expected - provided}, extra={provided - expected}"
            )
        return b"".join(spec.encode(tensors[spec.name]) for spec in self.specs)

    def decode(self, data: bytes) -> dict[str, torch.Tensor]:
        """Unpack a full payload into a ``{name: tensor}`` dict.

        Raises:
            ValueError: if ``len(data)`` differs from ``nbytes()``.
        """
        if len(data) != self.nbytes():
            raise ValueError(
                f"Data length {len(data)} != expected {self.nbytes()}"
            )
        # memoryview slices are zero-copy; TensorSpec.decode makes its own copy.
        view = memoryview(data)
        return {
            spec.name: spec.decode(view[off : off + spec.nbytes])
            for spec, off in zip(self.specs, self._offsets)
        }

    def decode_single(self, data: bytes, name: str) -> torch.Tensor:
        """Decode only tensor ``name`` from a full payload."""
        spec = self.get_spec(name)
        off = self.get_offset(name)
        return spec.decode(memoryview(data)[off : off + spec.nbytes])

    @classmethod
    def from_dict(cls, config: dict) -> "BatchConfig":
        """Build from a mapping with a ``"specs"`` list (missing -> empty)."""
        return cls(config.get("specs", []))

    @classmethod
    def from_yaml(cls, path: str) -> "BatchConfig":
        """Load a config from a YAML file (parsed with ``yaml.safe_load``)."""
        with open(path, encoding="utf-8") as f:
            return cls.from_dict(yaml.safe_load(f))

    def to_dict(self) -> dict:
        """Serialise to plain data (shapes as lists) for YAML/JSON output."""
        return {
            "specs": [
                {"name": s.name, "shape": list(s.shape), "dtype": s.dtype}
                for s in self.specs
            ]
        }
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def make_xy_config(
    x_shape: tuple[int, ...],
    x_dtype: str,
    y_shape: tuple[int, ...],
    y_dtype: str,
) -> BatchConfig:
    """Build the common two-tensor layout: ``"x"`` first, then ``"y"``."""
    x_spec = TensorSpec(name="x", shape=x_shape, dtype=x_dtype)
    y_spec = TensorSpec(name="y", shape=y_shape, dtype=y_dtype)
    return BatchConfig([x_spec, y_spec])
|
softs/dataset.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""PyTorch Dataset and DataLoader for soft label generation."""
|
|
2
|
+
|
|
3
|
+
import ctypes
|
|
4
|
+
import multiprocessing
|
|
5
|
+
import time
|
|
6
|
+
from typing import Callable, Iterator
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch.utils.data import DataLoader, IterableDataset
|
|
10
|
+
|
|
11
|
+
from .configs import BatchConfig
|
|
12
|
+
from .market import Client
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class _SharedStr:
    """Process-safe mutable string via shared memory.

    The value lives in a shared ctypes buffer (at most ``max_len`` UTF-8
    bytes) plus a shared length. A multiprocessing lock guards both, so a
    reader in another process never sees a new buffer paired with an old
    length, or vice versa.
    """

    def __init__(self, value: str = "", max_len: int = 256):
        self._buf = multiprocessing.RawArray(ctypes.c_char, max_len)
        self._len = multiprocessing.RawValue(ctypes.c_int, 0)
        # Buffer and length are two separate writes; without a lock a
        # concurrent get() could combine them into a value never set.
        self._lock = multiprocessing.Lock()
        self._max = max_len
        if value:
            self.set(value)

    def set(self, value: str) -> None:
        """Store ``value``, truncated to at most ``max_len`` UTF-8 bytes.

        Truncation never splits a multi-byte character, so ``get()`` always
        decodes cleanly.
        """
        encoded = value.encode()
        if len(encoded) > self._max:
            # A raw byte cut can leave half a character; drop that remnant.
            encoded = encoded[: self._max].decode("utf-8", "ignore").encode()
        with self._lock:
            self._buf[: len(encoded)] = encoded
            self._len.value = len(encoded)

    def get(self) -> str:
        """Return the most recently stored value."""
        with self._lock:
            raw = bytes(self._buf[: self._len.value])
        return raw.decode()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class SoftIterableDataset(IterableDataset[dict[str, torch.Tensor]]):
    """Infinite dataset yielding decoded tensor dicts.

    Call ``set_model(model_id)`` to switch models at any time.
    DataLoader workers detect the change and discard pending work.

    The market ``Client`` is created lazily on the first ``__iter__`` call of
    this instance (see ``_ensure_client``).
    """

    def __init__(
        self,
        model_id: str,
        endpoint: str,
        batch_config: BatchConfig,
        medium_cls,
        num_slots: int = 8,
        max_retries: int = 10,
        retry_delay: float = 0.01,
    ):
        # Shared-memory string so set_model() in the parent is visible to
        # DataLoader worker copies of this dataset.
        self._model = _SharedStr(model_id)
        self.endpoint = endpoint
        self.num_slots = num_slots
        self.batch_config = batch_config
        self.medium_cls = medium_cls
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._client: Client | None = None

    @property
    def model_id(self) -> str:
        """The model id currently requested."""
        return self._model.get()

    def set_model(self, model_id: str) -> None:
        """Switch the requested model; running iterators pick it up."""
        self._model.set(model_id)

    def _ensure_client(self) -> Client:
        """Create, handshake (``hello()``), and cache the client on first use."""
        if self._client is None:
            self._client = Client(
                endpoint=self.endpoint,
                slot_size=self.batch_config.nbytes(),
                medium_cls=self.medium_cls,
                num_slots=self.num_slots,
            )
            self._client.hello()
        return self._client

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        client = self._ensure_client()
        retries = 0
        current = self.model_id
        while True:
            wanted = self.model_id
            if wanted != current:
                # Model switched: drop work pending for the old model.
                client.discard()
                current = wanted
                retries = 0
            slot_id = client.request_sample(
                current, timeout_ms=int(self.retry_delay * 1000)
            )
            if slot_id is None:
                retries += 1
                if retries >= self.max_retries:
                    # Past the polling budget: also sleep between attempts.
                    time.sleep(self.retry_delay)
                continue
            retries = 0
            try:
                tensors = self.batch_config.decode(client.medium.read(slot_id))
            finally:
                # Release even when read/decode raises, so the slot is not
                # leaked by a malformed payload.
                client.release_slot(slot_id)
            yield tensors

    def __del__(self):
        # getattr: __init__ may have raised before _client was assigned.
        client = getattr(self, "_client", None)
        if client is not None:
            # Clear first so a repeated __del__/close path cannot close twice.
            self._client = None
            client.close()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class SoftDataLoader(DataLoader):
    """DataLoader over a ``SoftIterableDataset``, with model switching support.

    Usage::

        loader = SoftDataLoader(model_id="teacher_v1", endpoint=ep,
                                batch_config=config, medium_cls=ShmMedium,
                                num_slots=8, batch_size=4)
        for batch in loader:
            train(batch)

        loader.set_model("teacher_v2")  # all workers switch automatically

    Args:
        model_id: model to request samples for initially.
        endpoint, batch_config, medium_cls, num_slots, max_retries,
        retry_delay: forwarded unchanged to ``SoftIterableDataset``.
        **dataloader_kwargs: forwarded to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``num_workers``).
    """

    def __init__(
        self,
        model_id: str,
        endpoint: str,
        batch_config: BatchConfig,
        medium_cls,
        num_slots: int = 8,
        max_retries: int = 10,
        retry_delay: float = 0.01,
        **dataloader_kwargs,
    ):
        # Assigned before DataLoader.__init__, which forbids re-setting
        # ``dataset`` once the loader is initialised.
        self.dataset = SoftIterableDataset(
            model_id=model_id,
            endpoint=endpoint,
            batch_config=batch_config,
            medium_cls=medium_cls,
            num_slots=num_slots,
            max_retries=max_retries,
            retry_delay=retry_delay,
        )
        super().__init__(self.dataset, **dataloader_kwargs)

    def set_model(self, model_id: str) -> None:
        """Switch the model requested by the dataset (and its workers)."""
        self.dataset.set_model(model_id)

    @property
    def model_id(self) -> str:
        """The model id currently requested by the dataset."""
        return self.dataset.model_id
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class Batch:
    """Dict-like view over collated tensors plus the slots that produced them.

    When a ``client`` is attached, ``release()`` (called at most once, and
    also from ``__del__``) returns every id in ``slot_ids`` via
    ``client.release_slot``. Without a client, ``release()`` does nothing.
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        slot_ids: list[int] | None = None,
        client: Client | None = None,
    ):
        self.tensors = tensors
        self.slot_ids = slot_ids if slot_ids else []
        self._client = client
        self._released = False

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]

    def __contains__(self, key: str) -> bool:
        return key in self.tensors

    def keys(self):
        return self.tensors.keys()

    def release(self) -> None:
        """Release the backing slots once; later calls are no-ops."""
        if not self._released and self._client is not None:
            for sid in self.slot_ids:
                self._client.release_slot(sid)
            self._released = True

    def __del__(self):
        self.release()
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def make_collate_fn(
    client: Client, batch_config: BatchConfig, auto_release: bool = True
) -> Callable[[list[int]], Batch]:
    """Build a ``collate_fn`` that turns a list of slot ids into a ``Batch``.

    Each slot is read via ``client.medium.read`` and decoded with
    ``batch_config``. Per-name tensors are then stacked along a new leading
    dimension.

    Args:
        client: used to read slot payloads and to release slots.
        batch_config: layout used to decode each slot's payload.
        auto_release: if True, slots are released right after decoding and
            the returned ``Batch`` carries no client. If False, the ``Batch``
            holds the client and releases the slots on ``release()``/``__del__``.

    If reading, decoding or stacking raises, every slot in the batch is
    released before the exception propagates, since no ``Batch`` will exist
    to release them later.
    """

    def _release_all(slot_ids: list[int]) -> None:
        # Hand every slot in the batch back through the client.
        for slot_id in slot_ids:
            client.release_slot(slot_id)

    def collate_fn(slot_ids: list[int]) -> Batch:
        try:
            tensor_lists: dict[str, list[torch.Tensor]] = {
                name: [] for name in batch_config.tensor_names
            }
            for slot_id in slot_ids:
                decoded = batch_config.decode(client.medium.read(slot_id))
                for name, tensor in decoded.items():
                    tensor_lists[name].append(tensor)
            batched = {name: torch.stack(ts) for name, ts in tensor_lists.items()}
        except BaseException:
            # No Batch is returned on failure, so nobody else can free these.
            _release_all(slot_ids)
            raise
        if auto_release:
            _release_all(slot_ids)
            return Batch(batched, slot_ids, None)
        return Batch(batched, slot_ids, client)

    return collate_fn
|
softs/market/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Marketplace: broker, supplier, client, and transfer mediums."""
|
|
2
|
+
|
|
3
|
+
from .broker import Broker
|
|
4
|
+
from .supplier import Supplier
|
|
5
|
+
from .client import Client
|
|
6
|
+
from .protocol import EndpointConfig, BrokerStats
|
|
7
|
+
from .mediums import Medium, ShmMedium, FilesystemMedium, TCPMedium
|