flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +19 -14
- flwr/cli/new/new.py +51 -22
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +3 -1
- flwr/client/app.py +20 -142
- flwr/client/grpc_client/connection.py +8 -2
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +33 -4
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +92 -169
- flwr/client/supernode/__init__.py +24 -0
- flwr/client/supernode/app.py +281 -0
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/common/telemetry.py +4 -0
- flwr/server/app.py +116 -6
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -70
- flwr/server/driver/__init__.py +2 -1
- flwr/server/driver/driver.py +12 -139
- flwr/server/driver/grpc_driver.py +199 -13
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +89 -12
- flwr/server/superlink/state/sqlite_state.py +133 -16
- flwr/server/superlink/state/state.py +56 -6
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
|
@@ -293,7 +293,7 @@ def start_vce(
|
|
|
293
293
|
node_states[node_id] = NodeState()
|
|
294
294
|
|
|
295
295
|
# Load backend config
|
|
296
|
-
log(
|
|
296
|
+
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
|
|
297
297
|
backend_config = json.loads(backend_config_json_stream)
|
|
298
298
|
|
|
299
299
|
try:
|
|
@@ -30,15 +30,24 @@ from flwr.server.utils import validate_task_ins_or_res
|
|
|
30
30
|
from .utils import make_node_unavailable_taskres
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
class InMemoryState(State):
|
|
33
|
+
class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
34
34
|
"""In-memory State implementation."""
|
|
35
35
|
|
|
36
36
|
def __init__(self) -> None:
|
|
37
|
+
|
|
37
38
|
# Map node_id to (online_until, ping_interval)
|
|
38
39
|
self.node_ids: Dict[int, Tuple[float, float]] = {}
|
|
39
|
-
self.
|
|
40
|
+
self.public_key_to_node_id: Dict[bytes, int] = {}
|
|
41
|
+
|
|
42
|
+
# Map run_id to (fab_id, fab_version)
|
|
43
|
+
self.run_ids: Dict[int, Tuple[str, str]] = {}
|
|
40
44
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
41
45
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
46
|
+
|
|
47
|
+
self.client_public_keys: Set[bytes] = set()
|
|
48
|
+
self.server_public_key: Optional[bytes] = None
|
|
49
|
+
self.server_private_key: Optional[bytes] = None
|
|
50
|
+
|
|
42
51
|
self.lock = threading.Lock()
|
|
43
52
|
|
|
44
53
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
@@ -201,23 +210,46 @@ class InMemoryState(State):
|
|
|
201
210
|
"""
|
|
202
211
|
return len(self.task_res_store)
|
|
203
212
|
|
|
204
|
-
def create_node(
|
|
213
|
+
def create_node(
|
|
214
|
+
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
215
|
+
) -> int:
|
|
205
216
|
"""Create, store in state, and return `node_id`."""
|
|
206
217
|
# Sample a random int64 as node_id
|
|
207
218
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
208
219
|
|
|
209
220
|
with self.lock:
|
|
210
|
-
if node_id
|
|
211
|
-
|
|
212
|
-
return
|
|
213
|
-
|
|
214
|
-
|
|
221
|
+
if node_id in self.node_ids:
|
|
222
|
+
log(ERROR, "Unexpected node registration failure.")
|
|
223
|
+
return 0
|
|
224
|
+
|
|
225
|
+
if public_key is not None:
|
|
226
|
+
if (
|
|
227
|
+
public_key in self.public_key_to_node_id
|
|
228
|
+
or node_id in self.public_key_to_node_id.values()
|
|
229
|
+
):
|
|
230
|
+
log(ERROR, "Unexpected node registration failure.")
|
|
231
|
+
return 0
|
|
215
232
|
|
|
216
|
-
|
|
233
|
+
self.public_key_to_node_id[public_key] = node_id
|
|
234
|
+
|
|
235
|
+
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
236
|
+
return node_id
|
|
237
|
+
|
|
238
|
+
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
|
|
217
239
|
"""Delete a client node."""
|
|
218
240
|
with self.lock:
|
|
219
241
|
if node_id not in self.node_ids:
|
|
220
242
|
raise ValueError(f"Node {node_id} not found")
|
|
243
|
+
|
|
244
|
+
if public_key is not None:
|
|
245
|
+
if (
|
|
246
|
+
public_key not in self.public_key_to_node_id
|
|
247
|
+
or node_id not in self.public_key_to_node_id.values()
|
|
248
|
+
):
|
|
249
|
+
raise ValueError("Public key or node_id not found")
|
|
250
|
+
|
|
251
|
+
del self.public_key_to_node_id[public_key]
|
|
252
|
+
|
|
221
253
|
del self.node_ids[node_id]
|
|
222
254
|
|
|
223
255
|
def get_nodes(self, run_id: int) -> Set[int]:
|
|
@@ -238,18 +270,63 @@ class InMemoryState(State):
|
|
|
238
270
|
if online_until > current_time
|
|
239
271
|
}
|
|
240
272
|
|
|
241
|
-
def
|
|
242
|
-
"""
|
|
273
|
+
def get_node_id(self, client_public_key: bytes) -> Optional[int]:
|
|
274
|
+
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
275
|
+
return self.public_key_to_node_id.get(client_public_key)
|
|
276
|
+
|
|
277
|
+
def create_run(self, fab_id: str, fab_version: str) -> int:
|
|
278
|
+
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
243
279
|
# Sample a random int64 as run_id
|
|
244
280
|
with self.lock:
|
|
245
281
|
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
246
282
|
|
|
247
283
|
if run_id not in self.run_ids:
|
|
248
|
-
self.run_ids
|
|
284
|
+
self.run_ids[run_id] = (fab_id, fab_version)
|
|
249
285
|
return run_id
|
|
250
286
|
log(ERROR, "Unexpected run creation failure.")
|
|
251
287
|
return 0
|
|
252
288
|
|
|
289
|
+
def store_server_private_public_key(
|
|
290
|
+
self, private_key: bytes, public_key: bytes
|
|
291
|
+
) -> None:
|
|
292
|
+
"""Store `server_private_key` and `server_public_key` in state."""
|
|
293
|
+
with self.lock:
|
|
294
|
+
if self.server_private_key is None and self.server_public_key is None:
|
|
295
|
+
self.server_private_key = private_key
|
|
296
|
+
self.server_public_key = public_key
|
|
297
|
+
else:
|
|
298
|
+
raise RuntimeError("Server private and public key already set")
|
|
299
|
+
|
|
300
|
+
def get_server_private_key(self) -> Optional[bytes]:
|
|
301
|
+
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
302
|
+
return self.server_private_key
|
|
303
|
+
|
|
304
|
+
def get_server_public_key(self) -> Optional[bytes]:
|
|
305
|
+
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
306
|
+
return self.server_public_key
|
|
307
|
+
|
|
308
|
+
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
|
|
309
|
+
"""Store a set of `client_public_keys` in state."""
|
|
310
|
+
with self.lock:
|
|
311
|
+
self.client_public_keys = public_keys
|
|
312
|
+
|
|
313
|
+
def store_client_public_key(self, public_key: bytes) -> None:
|
|
314
|
+
"""Store a `client_public_key` in state."""
|
|
315
|
+
with self.lock:
|
|
316
|
+
self.client_public_keys.add(public_key)
|
|
317
|
+
|
|
318
|
+
def get_client_public_keys(self) -> Set[bytes]:
|
|
319
|
+
"""Retrieve all currently stored `client_public_keys` as a set."""
|
|
320
|
+
return self.client_public_keys
|
|
321
|
+
|
|
322
|
+
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
323
|
+
"""Retrieve information about the run with the specified `run_id`."""
|
|
324
|
+
with self.lock:
|
|
325
|
+
if run_id not in self.run_ids:
|
|
326
|
+
log(ERROR, "`run_id` is invalid")
|
|
327
|
+
return 0, "", ""
|
|
328
|
+
return run_id, *self.run_ids[run_id]
|
|
329
|
+
|
|
253
330
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
254
331
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
255
332
|
with self.lock:
|
|
@@ -20,7 +20,7 @@ import re
|
|
|
20
20
|
import sqlite3
|
|
21
21
|
import time
|
|
22
22
|
from logging import DEBUG, ERROR
|
|
23
|
-
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
|
|
23
|
+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
|
|
24
24
|
from uuid import UUID, uuid4
|
|
25
25
|
|
|
26
26
|
from flwr.common import log, now
|
|
@@ -36,7 +36,21 @@ SQL_CREATE_TABLE_NODE = """
|
|
|
36
36
|
CREATE TABLE IF NOT EXISTS node(
|
|
37
37
|
node_id INTEGER UNIQUE,
|
|
38
38
|
online_until REAL,
|
|
39
|
-
ping_interval REAL
|
|
39
|
+
ping_interval REAL,
|
|
40
|
+
public_key BLOB
|
|
41
|
+
);
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
SQL_CREATE_TABLE_CREDENTIAL = """
|
|
45
|
+
CREATE TABLE IF NOT EXISTS credential(
|
|
46
|
+
private_key BLOB PRIMARY KEY,
|
|
47
|
+
public_key BLOB
|
|
48
|
+
);
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
SQL_CREATE_TABLE_PUBLIC_KEY = """
|
|
52
|
+
CREATE TABLE IF NOT EXISTS public_key(
|
|
53
|
+
public_key BLOB UNIQUE
|
|
40
54
|
);
|
|
41
55
|
"""
|
|
42
56
|
|
|
@@ -46,7 +60,9 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
|
46
60
|
|
|
47
61
|
SQL_CREATE_TABLE_RUN = """
|
|
48
62
|
CREATE TABLE IF NOT EXISTS run(
|
|
49
|
-
run_id
|
|
63
|
+
run_id INTEGER UNIQUE,
|
|
64
|
+
fab_id TEXT,
|
|
65
|
+
fab_version TEXT
|
|
50
66
|
);
|
|
51
67
|
"""
|
|
52
68
|
|
|
@@ -70,7 +86,6 @@ CREATE TABLE IF NOT EXISTS task_ins(
|
|
|
70
86
|
);
|
|
71
87
|
"""
|
|
72
88
|
|
|
73
|
-
|
|
74
89
|
SQL_CREATE_TABLE_TASK_RES = """
|
|
75
90
|
CREATE TABLE IF NOT EXISTS task_res(
|
|
76
91
|
task_id TEXT UNIQUE,
|
|
@@ -94,7 +109,7 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
94
109
|
DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
|
|
95
110
|
|
|
96
111
|
|
|
97
|
-
class SqliteState(State):
|
|
112
|
+
class SqliteState(State): # pylint: disable=R0904
|
|
98
113
|
"""SQLite-based state implementation."""
|
|
99
114
|
|
|
100
115
|
def __init__(
|
|
@@ -132,6 +147,8 @@ class SqliteState(State):
|
|
|
132
147
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
133
148
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
134
149
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
150
|
+
cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
|
|
151
|
+
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
135
152
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
136
153
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
137
154
|
|
|
@@ -140,7 +157,7 @@ class SqliteState(State):
|
|
|
140
157
|
def query(
|
|
141
158
|
self,
|
|
142
159
|
query: str,
|
|
143
|
-
data: Optional[Union[
|
|
160
|
+
data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
|
|
144
161
|
) -> List[Dict[str, Any]]:
|
|
145
162
|
"""Execute a SQL query."""
|
|
146
163
|
if self.conn is None:
|
|
@@ -518,26 +535,54 @@ class SqliteState(State):
|
|
|
518
535
|
|
|
519
536
|
return None
|
|
520
537
|
|
|
521
|
-
def create_node(
|
|
538
|
+
def create_node(
|
|
539
|
+
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
540
|
+
) -> int:
|
|
522
541
|
"""Create, store in state, and return `node_id`."""
|
|
523
542
|
# Sample a random int64 as node_id
|
|
524
543
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
525
544
|
|
|
545
|
+
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
|
546
|
+
row = self.query(query, {"public_key": public_key})
|
|
547
|
+
|
|
548
|
+
if len(row) > 0:
|
|
549
|
+
log(ERROR, "Unexpected node registration failure.")
|
|
550
|
+
return 0
|
|
551
|
+
|
|
526
552
|
query = (
|
|
527
|
-
"INSERT INTO node
|
|
553
|
+
"INSERT INTO node "
|
|
554
|
+
"(node_id, online_until, ping_interval, public_key) "
|
|
555
|
+
"VALUES (?, ?, ?, ?)"
|
|
528
556
|
)
|
|
529
557
|
|
|
530
558
|
try:
|
|
531
|
-
self.query(
|
|
559
|
+
self.query(
|
|
560
|
+
query, (node_id, time.time() + ping_interval, ping_interval, public_key)
|
|
561
|
+
)
|
|
532
562
|
except sqlite3.IntegrityError:
|
|
533
563
|
log(ERROR, "Unexpected node registration failure.")
|
|
534
564
|
return 0
|
|
535
565
|
return node_id
|
|
536
566
|
|
|
537
|
-
def delete_node(self, node_id: int) -> None:
|
|
567
|
+
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
|
|
538
568
|
"""Delete a client node."""
|
|
539
|
-
query = "DELETE FROM node WHERE node_id =
|
|
540
|
-
|
|
569
|
+
query = "DELETE FROM node WHERE node_id = ?"
|
|
570
|
+
params = (node_id,)
|
|
571
|
+
|
|
572
|
+
if public_key is not None:
|
|
573
|
+
query += " AND public_key = ?"
|
|
574
|
+
params += (public_key,) # type: ignore
|
|
575
|
+
|
|
576
|
+
if self.conn is None:
|
|
577
|
+
raise AttributeError("State is not initialized.")
|
|
578
|
+
|
|
579
|
+
try:
|
|
580
|
+
with self.conn:
|
|
581
|
+
rows = self.conn.execute(query, params)
|
|
582
|
+
if rows.rowcount < 1:
|
|
583
|
+
raise ValueError("Public key or node_id not found")
|
|
584
|
+
except KeyError as exc:
|
|
585
|
+
log(ERROR, {"query": query, "data": params, "exception": exc})
|
|
541
586
|
|
|
542
587
|
def get_nodes(self, run_id: int) -> Set[int]:
|
|
543
588
|
"""Retrieve all currently stored node IDs as a set.
|
|
@@ -558,8 +603,17 @@ class SqliteState(State):
|
|
|
558
603
|
result: Set[int] = {row["node_id"] for row in rows}
|
|
559
604
|
return result
|
|
560
605
|
|
|
561
|
-
def
|
|
562
|
-
"""
|
|
606
|
+
def get_node_id(self, client_public_key: bytes) -> Optional[int]:
|
|
607
|
+
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
608
|
+
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
|
609
|
+
row = self.query(query, {"public_key": client_public_key})
|
|
610
|
+
if len(row) > 0:
|
|
611
|
+
node_id: int = row[0]["node_id"]
|
|
612
|
+
return node_id
|
|
613
|
+
return None
|
|
614
|
+
|
|
615
|
+
def create_run(self, fab_id: str, fab_version: str) -> int:
|
|
616
|
+
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
563
617
|
# Sample a random int64 as run_id
|
|
564
618
|
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
565
619
|
|
|
@@ -567,12 +621,75 @@ class SqliteState(State):
|
|
|
567
621
|
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|
|
568
622
|
# If run_id does not exist
|
|
569
623
|
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
|
|
570
|
-
query = "INSERT INTO run
|
|
571
|
-
self.query(query,
|
|
624
|
+
query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
|
|
625
|
+
self.query(query, (run_id, fab_id, fab_version))
|
|
572
626
|
return run_id
|
|
573
627
|
log(ERROR, "Unexpected run creation failure.")
|
|
574
628
|
return 0
|
|
575
629
|
|
|
630
|
+
def store_server_private_public_key(
|
|
631
|
+
self, private_key: bytes, public_key: bytes
|
|
632
|
+
) -> None:
|
|
633
|
+
"""Store `server_private_key` and `server_public_key` in state."""
|
|
634
|
+
query = "SELECT COUNT(*) FROM credential"
|
|
635
|
+
count = self.query(query)[0]["COUNT(*)"]
|
|
636
|
+
if count < 1:
|
|
637
|
+
query = (
|
|
638
|
+
"INSERT OR REPLACE INTO credential (private_key, public_key) "
|
|
639
|
+
"VALUES (:private_key, :public_key)"
|
|
640
|
+
)
|
|
641
|
+
self.query(query, {"private_key": private_key, "public_key": public_key})
|
|
642
|
+
else:
|
|
643
|
+
raise RuntimeError("Server private and public key already set")
|
|
644
|
+
|
|
645
|
+
def get_server_private_key(self) -> Optional[bytes]:
|
|
646
|
+
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
647
|
+
query = "SELECT private_key FROM credential"
|
|
648
|
+
rows = self.query(query)
|
|
649
|
+
try:
|
|
650
|
+
private_key: Optional[bytes] = rows[0]["private_key"]
|
|
651
|
+
except IndexError:
|
|
652
|
+
private_key = None
|
|
653
|
+
return private_key
|
|
654
|
+
|
|
655
|
+
def get_server_public_key(self) -> Optional[bytes]:
|
|
656
|
+
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
657
|
+
query = "SELECT public_key FROM credential"
|
|
658
|
+
rows = self.query(query)
|
|
659
|
+
try:
|
|
660
|
+
public_key: Optional[bytes] = rows[0]["public_key"]
|
|
661
|
+
except IndexError:
|
|
662
|
+
public_key = None
|
|
663
|
+
return public_key
|
|
664
|
+
|
|
665
|
+
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
|
|
666
|
+
"""Store a set of `client_public_keys` in state."""
|
|
667
|
+
query = "INSERT INTO public_key (public_key) VALUES (?)"
|
|
668
|
+
data = [(key,) for key in public_keys]
|
|
669
|
+
self.query(query, data)
|
|
670
|
+
|
|
671
|
+
def store_client_public_key(self, public_key: bytes) -> None:
|
|
672
|
+
"""Store a `client_public_key` in state."""
|
|
673
|
+
query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
|
|
674
|
+
self.query(query, {"public_key": public_key})
|
|
675
|
+
|
|
676
|
+
def get_client_public_keys(self) -> Set[bytes]:
|
|
677
|
+
"""Retrieve all currently stored `client_public_keys` as a set."""
|
|
678
|
+
query = "SELECT public_key FROM public_key"
|
|
679
|
+
rows = self.query(query)
|
|
680
|
+
result: Set[bytes] = {row["public_key"] for row in rows}
|
|
681
|
+
return result
|
|
682
|
+
|
|
683
|
+
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
684
|
+
"""Retrieve information about the run with the specified `run_id`."""
|
|
685
|
+
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
686
|
+
try:
|
|
687
|
+
row = self.query(query, (run_id,))[0]
|
|
688
|
+
return run_id, row["fab_id"], row["fab_version"]
|
|
689
|
+
except sqlite3.IntegrityError:
|
|
690
|
+
log(ERROR, "`run_id` does not exist.")
|
|
691
|
+
return 0, "", ""
|
|
692
|
+
|
|
576
693
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
577
694
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
578
695
|
# Update `online_until` and `ping_interval` for the given `node_id`
|
|
@@ -16,13 +16,13 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import List, Optional, Set
|
|
19
|
+
from typing import List, Optional, Set, Tuple
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
22
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
class State(abc.ABC):
|
|
25
|
+
class State(abc.ABC): # pylint: disable=R0904
|
|
26
26
|
"""Abstract State."""
|
|
27
27
|
|
|
28
28
|
@abc.abstractmethod
|
|
@@ -132,11 +132,13 @@ class State(abc.ABC):
|
|
|
132
132
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
133
133
|
|
|
134
134
|
@abc.abstractmethod
|
|
135
|
-
def create_node(
|
|
135
|
+
def create_node(
|
|
136
|
+
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
137
|
+
) -> int:
|
|
136
138
|
"""Create, store in state, and return `node_id`."""
|
|
137
139
|
|
|
138
140
|
@abc.abstractmethod
|
|
139
|
-
def delete_node(self, node_id: int) -> None:
|
|
141
|
+
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
|
|
140
142
|
"""Remove `node_id` from state."""
|
|
141
143
|
|
|
142
144
|
@abc.abstractmethod
|
|
@@ -150,8 +152,56 @@ class State(abc.ABC):
|
|
|
150
152
|
"""
|
|
151
153
|
|
|
152
154
|
@abc.abstractmethod
|
|
153
|
-
def
|
|
154
|
-
"""
|
|
155
|
+
def get_node_id(self, client_public_key: bytes) -> Optional[int]:
|
|
156
|
+
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
157
|
+
|
|
158
|
+
@abc.abstractmethod
|
|
159
|
+
def create_run(self, fab_id: str, fab_version: str) -> int:
|
|
160
|
+
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
161
|
+
|
|
162
|
+
@abc.abstractmethod
|
|
163
|
+
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
164
|
+
"""Retrieve information about the run with the specified `run_id`.
|
|
165
|
+
|
|
166
|
+
Parameters
|
|
167
|
+
----------
|
|
168
|
+
run_id : int
|
|
169
|
+
The identifier of the run.
|
|
170
|
+
|
|
171
|
+
Returns
|
|
172
|
+
-------
|
|
173
|
+
Tuple[int, str, str]
|
|
174
|
+
A tuple containing three elements:
|
|
175
|
+
- `run_id`: The identifier of the run, same as the specified `run_id`.
|
|
176
|
+
- `fab_id`: The identifier of the FAB used in the specified run.
|
|
177
|
+
- `fab_version`: The version of the FAB used in the specified run.
|
|
178
|
+
"""
|
|
179
|
+
|
|
180
|
+
@abc.abstractmethod
|
|
181
|
+
def store_server_private_public_key(
|
|
182
|
+
self, private_key: bytes, public_key: bytes
|
|
183
|
+
) -> None:
|
|
184
|
+
"""Store `server_private_key` and `server_public_key` in state."""
|
|
185
|
+
|
|
186
|
+
@abc.abstractmethod
|
|
187
|
+
def get_server_private_key(self) -> Optional[bytes]:
|
|
188
|
+
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
189
|
+
|
|
190
|
+
@abc.abstractmethod
|
|
191
|
+
def get_server_public_key(self) -> Optional[bytes]:
|
|
192
|
+
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
193
|
+
|
|
194
|
+
@abc.abstractmethod
|
|
195
|
+
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
|
|
196
|
+
"""Store a set of `client_public_keys` in state."""
|
|
197
|
+
|
|
198
|
+
@abc.abstractmethod
|
|
199
|
+
def store_client_public_key(self, public_key: bytes) -> None:
|
|
200
|
+
"""Store a `client_public_key` in state."""
|
|
201
|
+
|
|
202
|
+
@abc.abstractmethod
|
|
203
|
+
def get_client_public_keys(self) -> Set[bytes]:
|
|
204
|
+
"""Retrieve all currently stored `client_public_keys` as a set."""
|
|
155
205
|
|
|
156
206
|
@abc.abstractmethod
|
|
157
207
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
flwr/simulation/__init__.py
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
import importlib
|
|
19
19
|
|
|
20
|
-
from flwr.simulation.run_simulation import run_simulation
|
|
20
|
+
from flwr.simulation.run_simulation import run_simulation
|
|
21
21
|
|
|
22
22
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
|
23
23
|
|
|
@@ -36,4 +36,4 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
|
|
|
36
36
|
raise ImportError(RAY_IMPORT_ERROR)
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
__all__ = ["start_simulation", "
|
|
39
|
+
__all__ = ["start_simulation", "run_simulation"]
|
flwr/simulation/app.py
CHANGED
|
@@ -15,6 +15,8 @@
|
|
|
15
15
|
"""Flower simulation app."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import asyncio
|
|
19
|
+
import logging
|
|
18
20
|
import sys
|
|
19
21
|
import threading
|
|
20
22
|
import traceback
|
|
@@ -27,7 +29,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
|
|
27
29
|
|
|
28
30
|
from flwr.client import ClientFn
|
|
29
31
|
from flwr.common import EventType, event
|
|
30
|
-
from flwr.common.logger import log
|
|
32
|
+
from flwr.common.logger import log, set_logger_propagation
|
|
31
33
|
from flwr.server.client_manager import ClientManager
|
|
32
34
|
from flwr.server.history import History
|
|
33
35
|
from flwr.server.server import Server, init_defaults, run_fl
|
|
@@ -156,6 +158,7 @@ def start_simulation(
|
|
|
156
158
|
is an advanced feature. For all details, please refer to the Ray documentation:
|
|
157
159
|
https://docs.ray.io/en/latest/ray-core/scheduling/index.html
|
|
158
160
|
|
|
161
|
+
|
|
159
162
|
Returns
|
|
160
163
|
-------
|
|
161
164
|
hist : flwr.server.history.History
|
|
@@ -167,6 +170,18 @@ def start_simulation(
|
|
|
167
170
|
{"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
|
|
168
171
|
)
|
|
169
172
|
|
|
173
|
+
# Set logger propagation
|
|
174
|
+
loop: Optional[asyncio.AbstractEventLoop] = None
|
|
175
|
+
try:
|
|
176
|
+
loop = asyncio.get_running_loop()
|
|
177
|
+
except RuntimeError:
|
|
178
|
+
loop = None
|
|
179
|
+
finally:
|
|
180
|
+
if loop and loop.is_running():
|
|
181
|
+
# Set logger propagation to False to prevent duplicated log output in Colab.
|
|
182
|
+
logger = logging.getLogger("flwr")
|
|
183
|
+
_ = set_logger_propagation(logger, False)
|
|
184
|
+
|
|
170
185
|
# Initialize server and server config
|
|
171
186
|
initialized_server, initialized_config = init_defaults(
|
|
172
187
|
server=server,
|
|
@@ -28,8 +28,9 @@ import grpc
|
|
|
28
28
|
|
|
29
29
|
from flwr.client import ClientApp
|
|
30
30
|
from flwr.common import EventType, event, log
|
|
31
|
+
from flwr.common.logger import set_logger_propagation, update_console_handler
|
|
31
32
|
from flwr.common.typing import ConfigsRecordValues
|
|
32
|
-
from flwr.server.driver
|
|
33
|
+
from flwr.server.driver import Driver, GrpcDriver
|
|
33
34
|
from flwr.server.run_serverapp import run
|
|
34
35
|
from flwr.server.server_app import ServerApp
|
|
35
36
|
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
|
|
@@ -154,7 +155,7 @@ def run_serverapp_th(
|
|
|
154
155
|
# Upon completion, trigger stop event if one was passed
|
|
155
156
|
if stop_event is not None:
|
|
156
157
|
stop_event.set()
|
|
157
|
-
log(
|
|
158
|
+
log(DEBUG, "Triggered stop event for Simulation Engine.")
|
|
158
159
|
|
|
159
160
|
serverapp_th = threading.Thread(
|
|
160
161
|
target=server_th_with_start_checks,
|
|
@@ -204,7 +205,7 @@ def _main_loop(
|
|
|
204
205
|
serverapp_th = None
|
|
205
206
|
try:
|
|
206
207
|
# Initialize Driver
|
|
207
|
-
driver =
|
|
208
|
+
driver = GrpcDriver(
|
|
208
209
|
driver_service_address=driver_api_address,
|
|
209
210
|
root_certificates=None,
|
|
210
211
|
)
|
|
@@ -248,7 +249,7 @@ def _main_loop(
|
|
|
248
249
|
if serverapp_th:
|
|
249
250
|
serverapp_th.join()
|
|
250
251
|
|
|
251
|
-
log(
|
|
252
|
+
log(DEBUG, "Stopping Simulation Engine now.")
|
|
252
253
|
|
|
253
254
|
|
|
254
255
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
@@ -317,9 +318,9 @@ def _run_simulation(
|
|
|
317
318
|
enabled, DEBUG-level logs will be displayed.
|
|
318
319
|
"""
|
|
319
320
|
# Set logging level
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
321
|
+
logger = logging.getLogger("flwr")
|
|
322
|
+
if verbose_logging:
|
|
323
|
+
update_console_handler(level=DEBUG, timestamps=True, colored=True)
|
|
323
324
|
|
|
324
325
|
if backend_config is None:
|
|
325
326
|
backend_config = {}
|
|
@@ -364,6 +365,8 @@ def _run_simulation(
|
|
|
364
365
|
|
|
365
366
|
finally:
|
|
366
367
|
if run_in_thread:
|
|
368
|
+
# Set logger propagation to False to prevent duplicated log output in Colab.
|
|
369
|
+
logger = set_logger_propagation(logger, False)
|
|
367
370
|
log(DEBUG, "Starting Simulation Engine on a new thread.")
|
|
368
371
|
simulation_engine_th = threading.Thread(target=_main_loop, args=args)
|
|
369
372
|
simulation_engine_th.start()
|
{flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: flwr-nightly
|
|
3
|
-
Version: 1.9.0.
|
|
3
|
+
Version: 1.9.0.dev20240507
|
|
4
4
|
Summary: Flower: A Friendly Federated Learning Framework
|
|
5
5
|
Home-page: https://flower.ai
|
|
6
6
|
License: Apache-2.0
|
|
@@ -36,6 +36,7 @@ Requires-Dist: cryptography (>=42.0.4,<43.0.0)
|
|
|
36
36
|
Requires-Dist: grpcio (>=1.60.0,<2.0.0)
|
|
37
37
|
Requires-Dist: iterators (>=0.0.2,<0.0.3)
|
|
38
38
|
Requires-Dist: numpy (>=1.21.0,<2.0.0)
|
|
39
|
+
Requires-Dist: pathspec (>=0.12.1,<0.13.0)
|
|
39
40
|
Requires-Dist: protobuf (>=4.25.2,<5.0.0)
|
|
40
41
|
Requires-Dist: pycryptodome (>=3.18.0,<4.0.0)
|
|
41
42
|
Requires-Dist: ray (==2.6.3) ; (python_version >= "3.8" and python_version < "3.12") and (extra == "simulation")
|
|
@@ -193,7 +194,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
|
|
|
193
194
|
- [PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated)
|
|
194
195
|
- [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl)
|
|
195
196
|
- [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning)
|
|
196
|
-
- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/
|
|
197
|
+
- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/llm-flowertune)
|
|
197
198
|
- [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/vit-finetune)
|
|
198
199
|
- [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow)
|
|
199
200
|
- [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)
|