edgeml-sdk 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edgeml_sdk-0.1.2/PKG-INFO +47 -0
- edgeml_sdk-0.1.2/README.md +39 -0
- edgeml_sdk-0.1.2/edgeml/__init__.py +3 -0
- edgeml_sdk-0.1.2/edgeml/client.py +499 -0
- edgeml_sdk-0.1.2/edgeml_sdk.egg-info/PKG-INFO +47 -0
- edgeml_sdk-0.1.2/edgeml_sdk.egg-info/SOURCES.txt +9 -0
- edgeml_sdk-0.1.2/edgeml_sdk.egg-info/dependency_links.txt +1 -0
- edgeml_sdk-0.1.2/edgeml_sdk.egg-info/requires.txt +1 -0
- edgeml_sdk-0.1.2/edgeml_sdk.egg-info/top_level.txt +1 -0
- edgeml_sdk-0.1.2/pyproject.toml +17 -0
- edgeml_sdk-0.1.2/setup.cfg +4 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: edgeml-sdk
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: EdgeML Python SDK
|
|
5
|
+
Requires-Python: >=3.9
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: httpx>=0.25.0
|
|
8
|
+
|
|
9
|
+
# EdgeML Python SDK
|
|
10
|
+
|
|
11
|
+
Minimal Python SDK wrapper for the EdgeML REST API.
|
|
12
|
+
|
|
13
|
+
## Quickstart
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
from edgeml import Federation, FederatedClient
|
|
17
|
+
|
|
18
|
+
# Admin / coordinator
|
|
19
|
+
fed = Federation(api_key="ek_live_...", org_id="default")
|
|
20
|
+
fed.invite(org_ids=["org_hospital_a", "org_hospital_b"])
|
|
21
|
+
fed.train("tumor_detection", rounds=10)
|
|
22
|
+
fed.deploy()
|
|
23
|
+
|
|
24
|
+
# Client device
|
|
25
|
+
client = FederatedClient(api_key="ek_live_...", org_id="org_hospital_a")
|
|
26
|
+
client.join_federation("cancer_research_consortium")
|
|
27
|
+
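client.train("tumor_detection", local_data, rounds=10)  # local_data: bytes, a torch state_dict / nn.Module, or a callable returning (weights, sample_count, metrics)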
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
## Install (local dev)
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install -e .
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Publish (automated)
|
|
37
|
+
|
|
38
|
+
Tag a release with `edgeml-vX.Y.Z` and push the tag:
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
git tag edgeml-v0.1.1
|
|
42
|
+
git push origin edgeml-v0.1.1
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
The GitHub Action `Publish Python SDK` will build and publish to PyPI. Configure
|
|
46
|
+
PyPI trusted publishing for this repo (or set the `PYPI_API_TOKEN` secret and
|
|
47
|
+
update the workflow to use it).
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# EdgeML Python SDK
|
|
2
|
+
|
|
3
|
+
Minimal Python SDK wrapper for the EdgeML REST API.
|
|
4
|
+
|
|
5
|
+
## Quickstart
|
|
6
|
+
|
|
7
|
+
```python
|
|
8
|
+
from edgeml import Federation, FederatedClient
|
|
9
|
+
|
|
10
|
+
# Admin / coordinator
|
|
11
|
+
fed = Federation(api_key="ek_live_...", org_id="default")
|
|
12
|
+
fed.invite(org_ids=["org_hospital_a", "org_hospital_b"])
|
|
13
|
+
fed.train("tumor_detection", rounds=10)
|
|
14
|
+
fed.deploy()
|
|
15
|
+
|
|
16
|
+
# Client device
|
|
17
|
+
client = FederatedClient(api_key="ek_live_...", org_id="org_hospital_a")
|
|
18
|
+
client.join_federation("cancer_research_consortium")
|
|
19
|
+
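client.train("tumor_detection", local_data, rounds=10)  # local_data: bytes, a torch state_dict / nn.Module, or a callable returning (weights, sample_count, metrics)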
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## Install (local dev)
|
|
23
|
+
|
|
24
|
+
```bash
|
|
25
|
+
pip install -e .
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Publish (automated)
|
|
29
|
+
|
|
30
|
+
Tag a release with `edgeml-vX.Y.Z` and push the tag:
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
git tag edgeml-v0.1.1
|
|
34
|
+
git push origin edgeml-v0.1.1
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
The GitHub Action `Publish Python SDK` will build and publish to PyPI. Configure
|
|
38
|
+
PyPI trusted publishing for this repo (or set the `PYPI_API_TOKEN` secret and
|
|
39
|
+
update the workflow to use it).
|
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import contextlib
|
|
5
|
+
import io
|
|
6
|
+
import uuid
|
|
7
|
+
from typing import Any, Callable, Iterable, Optional
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class EdgeMLClientError(RuntimeError):
    """Error raised by the EdgeML SDK.

    Raised for HTTP error responses from the EdgeML API (the message is the
    response body) and for client-side failures such as a missing model
    version or an unsupported argument.
    """
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class _ApiClient:
|
|
17
|
+
def __init__(self, api_key: str, api_base: str, timeout: float = 20.0):
|
|
18
|
+
self.api_key = api_key
|
|
19
|
+
self.api_base = api_base.rstrip("/")
|
|
20
|
+
self.timeout = timeout
|
|
21
|
+
|
|
22
|
+
def _headers(self) -> dict[str, str]:
|
|
23
|
+
return {"Authorization": f"Bearer {self.api_key}"}
|
|
24
|
+
|
|
25
|
+
def get(self, path: str, params: Optional[dict[str, Any]] = None) -> Any:
|
|
26
|
+
with httpx.Client(timeout=self.timeout) as client:
|
|
27
|
+
res = client.get(f"{self.api_base}{path}", params=params, headers=self._headers())
|
|
28
|
+
if res.status_code >= 400:
|
|
29
|
+
raise EdgeMLClientError(res.text)
|
|
30
|
+
return res.json()
|
|
31
|
+
|
|
32
|
+
def post(self, path: str, payload: dict[str, Any]) -> Any:
|
|
33
|
+
with httpx.Client(timeout=self.timeout) as client:
|
|
34
|
+
res = client.post(f"{self.api_base}{path}", json=payload, headers=self._headers())
|
|
35
|
+
if res.status_code >= 400:
|
|
36
|
+
raise EdgeMLClientError(res.text)
|
|
37
|
+
return res.json()
|
|
38
|
+
|
|
39
|
+
def get_bytes(self, path: str, params: Optional[dict[str, Any]] = None) -> bytes:
|
|
40
|
+
with httpx.Client(timeout=self.timeout) as client:
|
|
41
|
+
res = client.get(f"{self.api_base}{path}", params=params, headers=self._headers())
|
|
42
|
+
if res.status_code >= 400:
|
|
43
|
+
raise EdgeMLClientError(res.text)
|
|
44
|
+
return res.content
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Federation:
    """Coordinator-side handle for an EdgeML federation.

    Creating an instance performs network I/O: it looks up the federation named
    ``name`` inside ``org_id`` and creates it when the lookup returns nothing.
    """

    def __init__(
        self,
        api_key: str,
        name: str | None = None,
        org_id: str = "default",
        api_base: str = "https://api.edgeml.io/api/v1",
    ):
        self.api = _ApiClient(api_key=api_key, api_base=api_base)
        self.org_id = org_id
        self.name = name or "default"
        # Remembered by train() so that deploy() can be called without arguments.
        self.last_model_id: Optional[str] = None
        self.last_version: Optional[str] = None
        self.federation_id = self._resolve_or_create_federation()

    def _resolve_or_create_federation(self) -> str:
        """Return the id of the first matching federation, creating one if none match."""
        query = {"org_id": self.org_id, "name": self.name}
        matches = self.api.get("/federations", params=query)
        if not matches:
            return self.api.post("/federations", dict(query))["id"]
        return matches[0]["id"]

    def invite(self, org_ids: Iterable[str]) -> list[dict[str, Any]]:
        """Invite the given organisations into this federation."""
        return self.api.post(
            f"/federations/{self.federation_id}/invite",
            {"org_ids": list(org_ids)},
        )

    def _resolve_model_id(self, model: str) -> str:
        """Map a model name to its id; a name that is not found is assumed to be an id."""
        listing = self.api.get("/models", params={"org_id": self.org_id})
        return next(
            (entry["id"] for entry in listing.get("models", []) if entry.get("name") == model),
            model,
        )

    def train(
        self,
        model: str,
        algorithm: str = "fedavg",
        rounds: int = 1,
        min_updates: int = 1,
        base_version: Optional[str] = None,
        new_version: Optional[str] = None,
        publish: bool = True,
        strategy: str = "metrics",
        update_format: str = "delta",
        architecture: Optional[str] = None,
        input_dim: int = 16,
        hidden_dim: int = 8,
        output_dim: int = 4,
    ) -> dict[str, Any]:
        """Trigger ``rounds`` server-side aggregation rounds for ``model``.

        Each round POSTs to ``/training/aggregate``; the ``new_version`` returned
        by one round becomes the ``base_version`` of the next. The remaining
        keyword arguments are forwarded verbatim in the request payload.

        Returns:
            The response of the last round, or ``{}`` when no round ran.

        Raises:
            EdgeMLClientError: if ``algorithm`` is not ``"fedavg"`` or a request fails.
        """
        if algorithm.lower() != "fedavg":
            raise EdgeMLClientError(f"Unsupported algorithm: {algorithm}")

        model_id = self._resolve_model_id(model)
        self.last_model_id = model_id

        # Fields that do not change between rounds.
        fixed_fields = {
            "min_updates": min_updates,
            "publish": publish,
            "strategy": strategy,
            "update_format": update_format,
            "architecture": architecture,
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
        }
        last_response: Optional[dict[str, Any]] = None
        base = base_version
        target = new_version
        for _ in range(rounds):
            request_body = {
                "model_id": model_id,
                "base_version": base,
                "new_version": target,
                **fixed_fields,
            }
            last_response = self.api.post("/training/aggregate", request_body)
            base = last_response.get("new_version")
            self.last_version = base
            # An explicit new_version only applies to the first round.
            target = None

        return last_response or {}

    def deploy(
        self,
        model_id: Optional[str] = None,
        version: Optional[str] = None,
        rollout_percentage: int = 10,
        target_percentage: int = 100,
        increment_step: int = 10,
        start_immediately: bool = True,
    ) -> dict[str, Any]:
        """Start a rollout of ``model_id`` at ``version``.

        Falls back to the model and version produced by the last ``train()``
        call. If no version is known, the model's latest version is used.

        Raises:
            EdgeMLClientError: if no model id or version can be determined.
        """
        target_model = model_id or self.last_model_id
        if not target_model:
            raise EdgeMLClientError("model_id is required for deploy()")

        target_version = version or self.last_version
        if not target_version:
            latest = self.api.get(f"/models/{target_model}/versions/latest")
            target_version = latest.get("version")
            if not target_version:
                raise EdgeMLClientError("version is required for deploy()")

        rollout = {
            "version": target_version,
            "rollout_percentage": rollout_percentage,
            "target_percentage": target_percentage,
            "increment_step": increment_step,
            "start_immediately": start_immediately,
        }
        return self.api.post(f"/models/{target_model}/rollouts", rollout)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class FederatedClient:
    """Device-side participant in EdgeML federated training.

    Methods that submit training updates call :meth:`register` first, so the
    device is registered lazily on first use.
    """

    def __init__(
        self,
        api_key: str,
        org_id: str = "default",
        api_base: str = "https://api.edgeml.io/api/v1",
        device_identifier: Optional[str] = None,
        platform: str = "python",
    ):
        self.api = _ApiClient(api_key=api_key, api_base=api_base)
        self.org_id = org_id
        # A random identifier is generated per instance unless the caller passes a stable one.
        self.device_identifier = device_identifier or f"client-{uuid.uuid4().hex[:10]}"
        self.platform = platform
        self.device_id: Optional[str] = None

    @staticmethod
    def _host_os_version() -> str:
        """Describe the host operating system via ``platform.platform(terse=True)``."""
        # Aliased because ``platform`` is also the name of the SDK platform argument.
        import platform as host_platform

        return host_platform.platform(terse=True) or "unknown"

    @staticmethod
    def _sdk_version() -> str:
        """Return the installed ``edgeml-sdk`` version, or ``"0.1.0"`` if it is not installed."""
        from importlib import metadata

        try:
            return metadata.version("edgeml-sdk")
        except metadata.PackageNotFoundError:
            return "0.1.0"

    @staticmethod
    def _is_torch_module(obj: Any) -> bool:
        """Return True if ``obj`` is a ``torch.nn.Module`` (False when torch is unavailable)."""
        try:
            import torch  # type: ignore
        except ImportError:
            return False
        return isinstance(obj, torch.nn.Module)

    def register(self) -> str:
        """Register this device once and return the server-assigned device id.

        Raises:
            EdgeMLClientError: if the registration response carries no ``id``.
        """
        if self.device_id:
            return self.device_id
        payload = {
            "device_identifier": self.device_identifier,
            "org_id": self.org_id,
            "platform": self.platform,
            # Report the actual host OS and installed SDK version rather than fixed strings.
            "os_version": self._host_os_version(),
            "sdk_version": self._sdk_version(),
            "app_version": "0.1.0",
            "metadata": {"client": "python-sdk"},
            "capabilities": {"training": True},
        }
        response = self.api.post("/devices/register", payload)
        self.device_id = response.get("id")
        if not self.device_id:
            raise EdgeMLClientError("Device registration failed: missing device ID")
        return self.device_id

    def join_federation(self, federation_name: str) -> dict[str, Any]:
        """Join the federation called ``federation_name``.

        The lookup filters by ``name`` only (no ``org_id``). When nothing matches,
        a federation owned by this client's org is created and then joined.
        """
        self.register()
        existing = self.api.get("/federations", params={"name": federation_name})
        if existing:
            federation_id = existing[0]["id"]
        else:
            created = self.api.post(
                "/federations",
                {"org_id": self.org_id, "name": federation_name},
            )
            federation_id = created["id"]
        return self.api.post(
            f"/federations/{federation_id}/join",
            {"org_id": self.org_id},
        )

    def train(
        self,
        model: str,
        local_data: Any,
        rounds: int = 1,
        version: Optional[str] = None,
        sample_count: int = 0,
        metrics: Optional[dict[str, float]] = None,
        update_format: str = "delta",
    ) -> list[dict[str, Any]]:
        """Upload locally produced weights for ``rounds`` rounds.

        Args:
            model: Model name or id.
            local_data: ``bytes``/``bytearray``, a ``torch.nn.Module``, a state_dict
                ``dict``, or a zero-argument callable returning
                ``(weights, sample_count, metrics)``. A callable is invoked once per
                round, and its ``sample_count``/``metrics`` replace the arguments.
            rounds: Number of uploads.
            version: Base model version; defaults to the model's latest version.
            sample_count: Number of local samples reported with the update.
            metrics: Local metrics reported with the update.
            update_format: Forwarded verbatim to the server.

        Returns:
            One server response per round.
        """
        self.register()
        model_id = self._resolve_model_id(model)
        version = self._resolve_version(model_id, version)

        # nn.Module defines __call__, so it must not be mistaken for a data provider.
        provider = (
            local_data
            if callable(local_data) and not self._is_torch_module(local_data)
            else None
        )

        results: list[dict[str, Any]] = []
        for _ in range(rounds):
            if provider is not None:
                weights, sample_count, metrics = provider()
            else:
                weights = local_data
            payload = {
                "model_id": model_id,
                "version": version,
                "device_id": self.device_id,
                "sample_count": sample_count or 0,
                "metrics": metrics or {},
                "update_format": update_format,
                "weights_data": self._encode_weights(weights),
            }
            results.append(self.api.post("/training/weights", payload))
        return results

    def pull_model(
        self,
        model: str,
        version: Optional[str] = None,
        format: str = "pytorch",
    ) -> bytes:
        """Download ``model`` at ``version`` (default: latest) in the given ``format``."""
        model_id = self._resolve_model_id(model)
        version = self._resolve_version(model_id, version)
        return self._download(model_id, version, format)

    def train_from_remote(
        self,
        model: str,
        local_train_fn: Callable[[dict], Any],
        rounds: int = 1,
        version: Optional[str] = None,
        update_format: str = "weights",
        format: str = "pytorch",
    ) -> list[dict[str, Any]]:
        """Download the base weights, train locally and upload the result each round.

        ``local_train_fn`` receives the base state_dict and must return
        ``(updated_state, sample_count, metrics)``. With ``update_format="delta"``
        the uploaded state is ``updated - base``.

        Returns:
            One server response per round.
        """
        self.register()
        model_id = self._resolve_model_id(model)
        version = self._resolve_version(model_id, version)

        results: list[dict[str, Any]] = []
        if rounds < 1:
            return results

        # Every round starts from the same base version, so download it only once.
        base_bytes = self._download(model_id, version, format)
        for _ in range(rounds):
            # Deserialize afresh each round so in-place edits made by
            # local_train_fn do not leak into the next round.
            base_state = self._deserialize_weights(base_bytes)
            updated_state, sample_count, metrics = local_train_fn(base_state)
            if update_format == "delta":
                # Diff against a pristine copy: local_train_fn may have modified
                # base_state's tensors in place, which would zero out the delta.
                pristine = self._deserialize_weights(base_bytes)
                updated_state = compute_state_dict_delta(pristine, updated_state)
            payload = {
                "model_id": model_id,
                "version": version,
                "device_id": self.device_id,
                "sample_count": sample_count or 0,
                "metrics": metrics or {},
                "update_format": update_format,
                "weights_data": self._encode_weights(updated_state),
            }
            results.append(self.api.post("/training/weights", payload))
        return results

    def deploy(
        self,
        model_id: str,
        version: Optional[str] = None,
        rollout_percentage: int = 10,
        target_percentage: int = 100,
        increment_step: int = 10,
        start_immediately: bool = True,
    ) -> dict[str, Any]:
        """Start a rollout of ``model_id`` at ``version`` (default: latest version)."""
        version = self._resolve_version(model_id, version)
        payload = {
            "version": version,
            "rollout_percentage": rollout_percentage,
            "target_percentage": target_percentage,
            "increment_step": increment_step,
            "start_immediately": start_immediately,
        }
        return self.api.post(f"/models/{model_id}/rollouts", payload)

    def _resolve_model_id(self, model: str) -> str:
        """Map a model name to its id; a name that is not found is assumed to be an id."""
        data = self.api.get("/models", params={"org_id": self.org_id})
        for item in data.get("models", []):
            if item.get("name") == model:
                return item["id"]
        return model

    def _resolve_version(self, model_id: str, version: Optional[str]) -> str:
        """Return ``version``, or the model's latest version when it is empty.

        Raises:
            EdgeMLClientError: if the latest version cannot be determined.
        """
        if version:
            return version
        latest = self.api.get(f"/models/{model_id}/versions/latest")
        resolved = latest.get("version")
        if not resolved:
            raise EdgeMLClientError("Failed to resolve model version")
        return resolved

    def _download(self, model_id: str, version: str, format: str) -> bytes:
        """Fetch the serialized weights of ``model_id`` at ``version``."""
        return self.api.get_bytes(
            f"/models/{model_id}/versions/{version}/download",
            params={"format": format},
        )

    def _encode_weights(self, weights: Any) -> str:
        """Serialize ``weights`` and base64-encode them for the JSON upload payload."""
        return base64.b64encode(self._serialize_weights(weights)).decode("ascii")

    def _serialize_weights(self, weights: Any) -> bytes:
        """Convert bytes, an ``nn.Module`` or a state_dict into ``torch.save`` bytes.

        Raises:
            EdgeMLClientError: for any other input, or a module/dict without torch.
        """
        if isinstance(weights, (bytes, bytearray)):
            return bytes(weights)

        try:
            import torch  # type: ignore
        except ImportError:
            torch = None

        if torch is not None and isinstance(weights, (torch.nn.Module, dict)):
            state = weights.state_dict() if isinstance(weights, torch.nn.Module) else weights
            buffer = io.BytesIO()
            torch.save(state, buffer)
            return buffer.getvalue()

        raise EdgeMLClientError(
            "local_data must be bytes, a torch.nn.Module, a state_dict dict, "
            "or a callable returning (weights, sample_count, metrics)"
        )

    def _deserialize_weights(self, payload: bytes) -> dict:
        """Load a state_dict from ``payload`` onto the CPU.

        Raises:
            EdgeMLClientError: if torch is missing or the payload is not a dict.
        """
        try:
            import torch  # type: ignore
        except ImportError as exc:
            raise EdgeMLClientError("torch is required to load remote weights") from exc
        # SECURITY: torch.load unpickles its input, and this payload comes over the
        # network. weights_only=True (torch >= 1.13) restricts loading to tensors
        # and plain containers, so a tampered payload cannot execute code.
        state = torch.load(io.BytesIO(payload), map_location="cpu", weights_only=True)
        if not isinstance(state, dict):
            raise EdgeMLClientError("Remote payload was not a state_dict")
        return state


def compute_state_dict_delta(base_state: dict, updated_state: dict) -> dict:
    """Compute a delta state_dict ``updated - base`` (tensors moved to CPU).

    Only keys present in both dicts whose values are both tensors appear in the
    result; all other keys are skipped. Intended for small demo models that fit
    in memory.

    Raises:
        EdgeMLClientError: if torch is not installed.
    """
    try:
        import torch  # type: ignore
    except ImportError as exc:
        raise EdgeMLClientError("torch is required to compute state_dict deltas") from exc

    delta: dict = {}
    for key, base_tensor in base_state.items():
        updated_tensor = updated_state.get(key)
        # torch.is_tensor(None) is False, so keys missing from updated_state are skipped.
        if torch.is_tensor(base_tensor) and torch.is_tensor(updated_tensor):
            delta[key] = updated_tensor.detach().cpu() - base_tensor.detach().cpu()
    return delta


# ``deploy`` used to be defined outside FederatedClient; keep the module-level name
# so existing ``edgeml.client.deploy`` references still resolve.
deploy = FederatedClient.deploy
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
class ModelRegistry:
    """Manage models, model versions and rollouts in the EdgeML registry.

    The default request timeout is 60 s, and it also applies to artifact uploads.
    """

    def __init__(
        self,
        api_key: str,
        org_id: str = "default",
        api_base: str = "https://api.edgeml.io/api/v1",
        timeout: float = 60.0,
    ):
        self.api = _ApiClient(api_key=api_key, api_base=api_base, timeout=timeout)
        self.org_id = org_id

    def _find_model(self, name: str) -> Optional[dict[str, Any]]:
        """Return the first model in this org whose name equals ``name``, else None."""
        listing = self.api.get("/models", params={"org_id": self.org_id})
        return next(
            (entry for entry in listing.get("models", []) if entry.get("name") == name),
            None,
        )

    def resolve_model_id(self, model: str) -> str:
        """Map a model name to its id; a name that is not found is returned unchanged."""
        match = self._find_model(model)
        return model if match is None else match["id"]

    def ensure_model(
        self,
        name: str,
        framework: str,
        use_case: str,
        description: str | None = None,
    ) -> dict[str, Any]:
        """Return the model called ``name``, creating it if it does not exist yet.

        An existing model is returned as-is, even if its framework or use case
        differs from the arguments.
        """
        existing = self._find_model(name)
        if existing is not None:
            return existing
        new_model = {
            "name": name,
            "description": description or "",
            "framework": framework,
            "use_case": use_case,
            "org_id": self.org_id,
        }
        return self.api.post("/models", new_model)

    def upload_version_from_path(
        self,
        model_id: str,
        file_path: str,
        version: str,
        description: str | None = None,
        formats: str | None = None,
        onnx_data_path: str | None = None,
        architecture: str | None = None,
        input_dim: int | None = None,
        hidden_dim: int | None = None,
        output_dim: int | None = None,
    ) -> dict[str, Any]:
        """Upload the artifact at ``file_path`` as ``version`` of ``model_id``.

        The request is multipart: the model file is sent as ``file`` and the
        optional ONNX external-data file as ``onnx_data``. Empty text fields are
        omitted, and dimensions are sent as strings when not None.

        Raises:
            EdgeMLClientError: if the server answers with status >= 400.
        """
        text_fields = {
            "description": description,
            "formats": formats,
            "architecture": architecture,
        }
        dim_fields = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
        }
        form: dict[str, Any] = {"version": version}
        form.update({key: value for key, value in text_fields.items() if value})
        form.update({key: str(value) for key, value in dim_fields.items() if value is not None})

        upload_url = f"{self.api.api_base}/models/{model_id}/versions/upload"
        with contextlib.ExitStack() as stack:
            attachments: dict[str, Any] = {
                "file": stack.enter_context(open(file_path, "rb")),
            }
            if onnx_data_path:
                attachments["onnx_data"] = stack.enter_context(open(onnx_data_path, "rb"))
            with httpx.Client(timeout=self.api.timeout) as http:
                response = http.post(
                    upload_url,
                    data=form,
                    files=attachments,
                    headers=self.api._headers(),
                )
                if response.status_code >= 400:
                    raise EdgeMLClientError(response.text)
                return response.json()

    def publish_version(self, model_id: str, version: str) -> dict[str, Any]:
        """Mark ``version`` of ``model_id`` as published."""
        endpoint = f"/models/{model_id}/versions/{version}/publish"
        return self.api.post(endpoint, {})

    def create_rollout(
        self,
        model_id: str,
        version: str,
        rollout_percentage: int = 10,
        target_percentage: int = 100,
        increment_step: int = 10,
        start_immediately: bool = True,
    ) -> dict[str, Any]:
        """Create a rollout of ``version`` for ``model_id`` with the given schedule."""
        rollout = dict(
            version=version,
            rollout_percentage=rollout_percentage,
            target_percentage=target_percentage,
            increment_step=increment_step,
            start_immediately=start_immediately,
        )
        return self.api.post(f"/models/{model_id}/rollouts", rollout)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: edgeml-sdk
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: EdgeML Python SDK
|
|
5
|
+
Requires-Python: >=3.9
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: httpx>=0.25.0
|
|
8
|
+
|
|
9
|
+
# EdgeML Python SDK
|
|
10
|
+
|
|
11
|
+
Minimal Python SDK wrapper for the EdgeML REST API.
|
|
12
|
+
|
|
13
|
+
## Quickstart
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
from edgeml import Federation, FederatedClient
|
|
17
|
+
|
|
18
|
+
# Admin / coordinator
|
|
19
|
+
fed = Federation(api_key="ek_live_...", org_id="default")
|
|
20
|
+
fed.invite(org_ids=["org_hospital_a", "org_hospital_b"])
|
|
21
|
+
fed.train("tumor_detection", rounds=10)
|
|
22
|
+
fed.deploy()
|
|
23
|
+
|
|
24
|
+
# Client device
|
|
25
|
+
client = FederatedClient(api_key="ek_live_...", org_id="org_hospital_a")
|
|
26
|
+
client.join_federation("cancer_research_consortium")
|
|
27
|
+
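client.train("tumor_detection", local_data, rounds=10)  # local_data: bytes, a torch state_dict / nn.Module, or a callable returning (weights, sample_count, metrics)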
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
## Install (local dev)
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install -e .
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Publish (automated)
|
|
37
|
+
|
|
38
|
+
Tag a release with `edgeml-vX.Y.Z` and push the tag:
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
git tag edgeml-v0.1.1
|
|
42
|
+
git push origin edgeml-v0.1.1
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
The GitHub Action `Publish Python SDK` will build and publish to PyPI. Configure
|
|
46
|
+
PyPI trusted publishing for this repo (or set the `PYPI_API_TOKEN` secret and
|
|
47
|
+
update the workflow to use it).
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
httpx>=0.25.0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
edgeml
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "edgeml-sdk"
|
|
7
|
+
version = "0.1.2"
|
|
8
|
+
description = "EdgeML Python SDK"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"httpx>=0.25.0",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
[tool.setuptools.packages.find]
|
|
16
|
+
where = ["."]
|
|
17
|
+
include = ["edgeml*"]
|