flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +42 -18
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +5 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +181 -7
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +17 -5
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/server/app.py +111 -1
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +315 -0
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +76 -8
- flwr/server/superlink/state/sqlite_state.py +116 -11
- flwr/server/superlink/state/state.py +35 -3
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0

flwr/server/superlink/fleet/vce/backend/raybackend.py
CHANGED

@@ -15,7 +15,7 @@
 """Ray backend for the Fleet API using the Simulation Engine."""

 import pathlib
-from logging import ERROR, INFO
+from logging import DEBUG, ERROR, INFO
 from typing import Callable, Dict, List, Tuple, Union

 import ray

@@ -46,7 +46,7 @@ class RayBackend(Backend):
     ) -> None:
         """Prepare RayBackend by initialising Ray and creating the ActorPool."""
         log(INFO, "Initialising: %s", self.__class__.__name__)
-        log(
+        log(DEBUG, "Backend config: %s", backend_config)

         if not pathlib.Path(work_dir).exists():
             raise ValueError(f"Specified work_dir {work_dir} does not exist.")

@@ -109,7 +109,7 @@ class RayBackend(Backend):
         else:
             client_resources = {"num_cpus": 2, "num_gpus": 0.0}
             log(
-
+                DEBUG,
                 "`%s` not specified in backend config. Applying default setting: %s",
                 self.client_resources_key,
                 client_resources,

@@ -129,7 +129,7 @@ class RayBackend(Backend):
     async def build(self) -> None:
         """Build pool of Ray actors that this backend will submit jobs to."""
         await self.pool.add_actors_to_pool(self.pool.actors_capacity)
-        log(
+        log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)

     async def process_message(
         self,

@@ -173,4 +173,4 @@ class RayBackend(Backend):
         """Terminate all actors in actor pool."""
         await self.pool.terminate_all_actors()
         ray.shutdown()
-        log(
+        log(DEBUG, "Terminated %s", self.__class__.__name__)

flwr/server/superlink/fleet/vce/vce_api.py
CHANGED

@@ -293,7 +293,7 @@ def start_vce(
         node_states[node_id] = NodeState()

     # Load backend config
-    log(
+    log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
    backend_config = json.loads(backend_config_json_stream)

     try:
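
Note that these backend and Simulation Engine messages are now emitted at DEBUG level rather than INFO, so they no longer appear with default logging. A minimal sketch, using only the standard library, of how to surface them again (the "flwr" logger name is taken from the hunks further down; whether raising the level alone is sufficient depends on how Flower's console handler is configured):

import logging

# Raise the verbosity of the Flower logger so the DEBUG-level
# backend messages shown in this diff become visible again.
logging.getLogger("flwr").setLevel(logging.DEBUG)

The Simulation Engine's own verbose_logging path (see the run_simulation.py hunks below) instead calls update_console_handler(level=DEBUG, timestamps=True, colored=True).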

flwr/server/superlink/state/in_memory_state.py
CHANGED

@@ -30,16 +30,24 @@ from flwr.server.utils import validate_task_ins_or_res
 from .utils import make_node_unavailable_taskres


-class InMemoryState(State):
+class InMemoryState(State):  # pylint: disable=R0902,R0904
     """In-memory State implementation."""

     def __init__(self) -> None:
+
         # Map node_id to (online_until, ping_interval)
         self.node_ids: Dict[int, Tuple[float, float]] = {}
+        self.public_key_to_node_id: Dict[bytes, int] = {}
+
         # Map run_id to (fab_id, fab_version)
         self.run_ids: Dict[int, Tuple[str, str]] = {}
         self.task_ins_store: Dict[UUID, TaskIns] = {}
         self.task_res_store: Dict[UUID, TaskRes] = {}
+
+        self.client_public_keys: Set[bytes] = set()
+        self.server_public_key: Optional[bytes] = None
+        self.server_private_key: Optional[bytes] = None
+
         self.lock = threading.Lock()

     def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:

@@ -202,23 +210,46 @@ class InMemoryState(State):
         """
         return len(self.task_res_store)

-    def create_node(
+    def create_node(
+        self, ping_interval: float, public_key: Optional[bytes] = None
+    ) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
         node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

         with self.lock:
-            if node_id
-
-                return
-
-
+            if node_id in self.node_ids:
+                log(ERROR, "Unexpected node registration failure.")
+                return 0
+
+            if public_key is not None:
+                if (
+                    public_key in self.public_key_to_node_id
+                    or node_id in self.public_key_to_node_id.values()
+                ):
+                    log(ERROR, "Unexpected node registration failure.")
+                    return 0
+
+                self.public_key_to_node_id[public_key] = node_id
+
+            self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
+            return node_id

-    def delete_node(self, node_id: int) -> None:
+    def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
         """Delete a client node."""
         with self.lock:
             if node_id not in self.node_ids:
                 raise ValueError(f"Node {node_id} not found")
+
+            if public_key is not None:
+                if (
+                    public_key not in self.public_key_to_node_id
+                    or node_id not in self.public_key_to_node_id.values()
+                ):
+                    raise ValueError("Public key or node_id not found")
+
+                del self.public_key_to_node_id[public_key]
+
             del self.node_ids[node_id]

     def get_nodes(self, run_id: int) -> Set[int]:

@@ -239,6 +270,10 @@ class InMemoryState(State):
             if online_until > current_time
         }

+    def get_node_id(self, client_public_key: bytes) -> Optional[int]:
+        """Retrieve stored `node_id` filtered by `client_public_keys`."""
+        return self.public_key_to_node_id.get(client_public_key)
+
     def create_run(self, fab_id: str, fab_version: str) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id

@@ -251,6 +286,39 @@ class InMemoryState(State):
             log(ERROR, "Unexpected run creation failure.")
             return 0

+    def store_server_private_public_key(
+        self, private_key: bytes, public_key: bytes
+    ) -> None:
+        """Store `server_private_key` and `server_public_key` in state."""
+        with self.lock:
+            if self.server_private_key is None and self.server_public_key is None:
+                self.server_private_key = private_key
+                self.server_public_key = public_key
+            else:
+                raise RuntimeError("Server private and public key already set")
+
+    def get_server_private_key(self) -> Optional[bytes]:
+        """Retrieve `server_private_key` in urlsafe bytes."""
+        return self.server_private_key
+
+    def get_server_public_key(self) -> Optional[bytes]:
+        """Retrieve `server_public_key` in urlsafe bytes."""
+        return self.server_public_key
+
+    def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
+        """Store a set of `client_public_keys` in state."""
+        with self.lock:
+            self.client_public_keys = public_keys
+
+    def store_client_public_key(self, public_key: bytes) -> None:
+        """Store a `client_public_key` in state."""
+        with self.lock:
+            self.client_public_keys.add(public_key)
+
+    def get_client_public_keys(self) -> Set[bytes]:
+        """Retrieve all currently stored `client_public_keys` as a set."""
+        return self.client_public_keys
+
     def get_run(self, run_id: int) -> Tuple[int, str, str]:
         """Retrieve information about the run with the specified `run_id`."""
         with self.lock:
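
Taken together, the InMemoryState changes let a node be registered with an optional public key and later be looked up or deregistered by that key. A rough usage sketch against the signatures shown above; the import path mirrors the file path in this diff, and the key bytes and ping interval are made-up illustration values:

from flwr.server.superlink.state.in_memory_state import InMemoryState

state = InMemoryState()

# Hypothetical client public key (in practice, serialized key bytes sent by a node).
client_pk = b"example-client-public-key"

# Register the node together with its public key and a 30-second ping interval.
node_id = state.create_node(ping_interval=30.0, public_key=client_pk)

# The public key can be resolved back to the node_id ...
assert state.get_node_id(client_pk) == node_id

# ... and deregistration checks the key/node_id pairing when a key is given.
state.delete_node(node_id, public_key=client_pk)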

flwr/server/superlink/state/sqlite_state.py
CHANGED

@@ -20,7 +20,7 @@ import re
 import sqlite3
 import time
 from logging import DEBUG, ERROR
-from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
 from uuid import UUID, uuid4

 from flwr.common import log, now

@@ -36,7 +36,21 @@ SQL_CREATE_TABLE_NODE = """
 CREATE TABLE IF NOT EXISTS node(
     node_id INTEGER UNIQUE,
     online_until REAL,
-    ping_interval REAL
+    ping_interval REAL,
+    public_key BLOB
+);
+"""
+
+SQL_CREATE_TABLE_CREDENTIAL = """
+CREATE TABLE IF NOT EXISTS credential(
+    private_key BLOB PRIMARY KEY,
+    public_key BLOB
+);
+"""
+
+SQL_CREATE_TABLE_PUBLIC_KEY = """
+CREATE TABLE IF NOT EXISTS public_key(
+    public_key BLOB UNIQUE
 );
 """

@@ -72,7 +86,6 @@ CREATE TABLE IF NOT EXISTS task_ins(
 );
 """

-
 SQL_CREATE_TABLE_TASK_RES = """
 CREATE TABLE IF NOT EXISTS task_res(
     task_id TEXT UNIQUE,

@@ -96,7 +109,7 @@ CREATE TABLE IF NOT EXISTS task_res(
 DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]


-class SqliteState(State):
+class SqliteState(State):  # pylint: disable=R0904
     """SQLite-based state implementation."""

     def __init__(

@@ -134,6 +147,8 @@ class SqliteState(State):
         cur.execute(SQL_CREATE_TABLE_TASK_INS)
         cur.execute(SQL_CREATE_TABLE_TASK_RES)
         cur.execute(SQL_CREATE_TABLE_NODE)
+        cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
+        cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
         cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
         res = cur.execute("SELECT name FROM sqlite_schema;")

@@ -142,7 +157,7 @@ class SqliteState(State):
     def query(
         self,
         query: str,
-        data: Optional[Union[
+        data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
     ) -> List[Dict[str, Any]]:
         """Execute a SQL query."""
         if self.conn is None:

@@ -520,26 +535,54 @@ class SqliteState(State):

         return None

-    def create_node(
+    def create_node(
+        self, ping_interval: float, public_key: Optional[bytes] = None
+    ) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
         node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

+        query = "SELECT node_id FROM node WHERE public_key = :public_key;"
+        row = self.query(query, {"public_key": public_key})
+
+        if len(row) > 0:
+            log(ERROR, "Unexpected node registration failure.")
+            return 0
+
         query = (
-            "INSERT INTO node
+            "INSERT INTO node "
+            "(node_id, online_until, ping_interval, public_key) "
+            "VALUES (?, ?, ?, ?)"
         )

         try:
-            self.query(
+            self.query(
+                query, (node_id, time.time() + ping_interval, ping_interval, public_key)
+            )
         except sqlite3.IntegrityError:
             log(ERROR, "Unexpected node registration failure.")
             return 0
         return node_id

-    def delete_node(self, node_id: int) -> None:
+    def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
         """Delete a client node."""
-        query = "DELETE FROM node WHERE node_id =
-
+        query = "DELETE FROM node WHERE node_id = ?"
+        params = (node_id,)
+
+        if public_key is not None:
+            query += " AND public_key = ?"
+            params += (public_key,)  # type: ignore
+
+        if self.conn is None:
+            raise AttributeError("State is not initialized.")
+
+        try:
+            with self.conn:
+                rows = self.conn.execute(query, params)
+                if rows.rowcount < 1:
+                    raise ValueError("Public key or node_id not found")
+        except KeyError as exc:
+            log(ERROR, {"query": query, "data": params, "exception": exc})

     def get_nodes(self, run_id: int) -> Set[int]:
         """Retrieve all currently stored node IDs as a set.

@@ -560,6 +603,15 @@ class SqliteState(State):
         result: Set[int] = {row["node_id"] for row in rows}
         return result

+    def get_node_id(self, client_public_key: bytes) -> Optional[int]:
+        """Retrieve stored `node_id` filtered by `client_public_keys`."""
+        query = "SELECT node_id FROM node WHERE public_key = :public_key;"
+        row = self.query(query, {"public_key": client_public_key})
+        if len(row) > 0:
+            node_id: int = row[0]["node_id"]
+            return node_id
+        return None
+
     def create_run(self, fab_id: str, fab_version: str) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id

@@ -575,6 +627,59 @@ class SqliteState(State):
             log(ERROR, "Unexpected run creation failure.")
             return 0

+    def store_server_private_public_key(
+        self, private_key: bytes, public_key: bytes
+    ) -> None:
+        """Store `server_private_key` and `server_public_key` in state."""
+        query = "SELECT COUNT(*) FROM credential"
+        count = self.query(query)[0]["COUNT(*)"]
+        if count < 1:
+            query = (
+                "INSERT OR REPLACE INTO credential (private_key, public_key) "
+                "VALUES (:private_key, :public_key)"
+            )
+            self.query(query, {"private_key": private_key, "public_key": public_key})
+        else:
+            raise RuntimeError("Server private and public key already set")
+
+    def get_server_private_key(self) -> Optional[bytes]:
+        """Retrieve `server_private_key` in urlsafe bytes."""
+        query = "SELECT private_key FROM credential"
+        rows = self.query(query)
+        try:
+            private_key: Optional[bytes] = rows[0]["private_key"]
+        except IndexError:
+            private_key = None
+        return private_key
+
+    def get_server_public_key(self) -> Optional[bytes]:
+        """Retrieve `server_public_key` in urlsafe bytes."""
+        query = "SELECT public_key FROM credential"
+        rows = self.query(query)
+        try:
+            public_key: Optional[bytes] = rows[0]["public_key"]
+        except IndexError:
+            public_key = None
+        return public_key
+
+    def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
+        """Store a set of `client_public_keys` in state."""
+        query = "INSERT INTO public_key (public_key) VALUES (?)"
+        data = [(key,) for key in public_keys]
+        self.query(query, data)
+
+    def store_client_public_key(self, public_key: bytes) -> None:
+        """Store a `client_public_key` in state."""
+        query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
+        self.query(query, {"public_key": public_key})
+
+    def get_client_public_keys(self) -> Set[bytes]:
+        """Retrieve all currently stored `client_public_keys` as a set."""
+        query = "SELECT public_key FROM public_key"
+        rows = self.query(query)
+        result: Set[bytes] = {row["public_key"] for row in rows}
+        return result
+
     def get_run(self, run_id: int) -> Tuple[int, str, str]:
         """Retrieve information about the run with the specified `run_id`."""
         query = "SELECT * FROM run WHERE run_id = ?;"
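
The same flow against the SQLite backend touches the new credential and public_key tables. A sketch of how the key-storage methods above might be exercised; the SqliteState(":memory:") constructor argument and the initialize() call that creates the tables are assumptions based on the surrounding code, not shown in this diff:

from flwr.server.superlink.state.sqlite_state import SqliteState

state = SqliteState(":memory:")  # assumed constructor signature (database path)
state.initialize()               # assumed to run the CREATE TABLE statements above

# The server key pair can only be stored once; a second call raises RuntimeError.
state.store_server_private_public_key(b"server-private-key", b"server-public-key")
assert state.get_server_public_key() == b"server-public-key"

# Register the client public keys that are allowed to connect.
state.store_client_public_keys({b"client-key-1", b"client-key-2"})
state.store_client_public_key(b"client-key-3")
assert b"client-key-3" in state.get_client_public_keys()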

flwr/server/superlink/state/state.py
CHANGED

@@ -22,7 +22,7 @@ from uuid import UUID
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611


-class State(abc.ABC):
+class State(abc.ABC):  # pylint: disable=R0904
     """Abstract State."""

     @abc.abstractmethod

@@ -132,11 +132,13 @@ class State(abc.ABC):
         """Delete all delivered TaskIns/TaskRes pairs."""

     @abc.abstractmethod
-    def create_node(
+    def create_node(
+        self, ping_interval: float, public_key: Optional[bytes] = None
+    ) -> int:
         """Create, store in state, and return `node_id`."""

     @abc.abstractmethod
-    def delete_node(self, node_id: int) -> None:
+    def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
         """Remove `node_id` from state."""

     @abc.abstractmethod

@@ -149,6 +151,10 @@ class State(abc.ABC):
        an empty `Set` MUST be returned.
        """

+    @abc.abstractmethod
+    def get_node_id(self, client_public_key: bytes) -> Optional[int]:
+        """Retrieve stored `node_id` filtered by `client_public_keys`."""
+
     @abc.abstractmethod
     def create_run(self, fab_id: str, fab_version: str) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""

@@ -171,6 +177,32 @@ class State(abc.ABC):
         - `fab_version`: The version of the FAB used in the specified run.
         """

+    @abc.abstractmethod
+    def store_server_private_public_key(
+        self, private_key: bytes, public_key: bytes
+    ) -> None:
+        """Store `server_private_key` and `server_public_key` in state."""
+
+    @abc.abstractmethod
+    def get_server_private_key(self) -> Optional[bytes]:
+        """Retrieve `server_private_key` in urlsafe bytes."""
+
+    @abc.abstractmethod
+    def get_server_public_key(self) -> Optional[bytes]:
+        """Retrieve `server_public_key` in urlsafe bytes."""
+
+    @abc.abstractmethod
+    def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
+        """Store a set of `client_public_keys` in state."""
+
+    @abc.abstractmethod
+    def store_client_public_key(self, public_key: bytes) -> None:
+        """Store a `client_public_key` in state."""
+
+    @abc.abstractmethod
+    def get_client_public_keys(self) -> Set[bytes]:
+        """Retrieve all currently stored `client_public_keys` as a set."""
+
     @abc.abstractmethod
     def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
         """Acknowledge a ping received from a node, serving as a heartbeat.
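
The abstract State interface treats keys as opaque bytes. For illustration only, one plausible way to produce such bytes is to generate an elliptic-curve key pair with the cryptography package (already a declared dependency, see the METADATA diff below) and serialize the public key; the actual key format used by Flower's node authentication is defined in flwr/common/secure_aggregation/crypto/symmetric_encryption.py and is not shown in this diff:

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

# Generate an EC key pair (the curve choice here is illustrative).
private_key = ec.generate_private_key(ec.SECP384R1())

# Serialize the public key to bytes, e.g. for State.store_client_public_key().
public_key_bytes = private_key.public_key().public_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PublicFormat.SubjectPublicKeyInfo,
)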

flwr/simulation/__init__.py
CHANGED

@@ -17,7 +17,7 @@

 import importlib

-from flwr.simulation.run_simulation import run_simulation
+from flwr.simulation.run_simulation import run_simulation

 is_ray_installed = importlib.util.find_spec("ray") is not None

@@ -36,4 +36,4 @@ To install the necessary dependencies, install `flwr` with the `simulation` extra
     raise ImportError(RAY_IMPORT_ERROR)


-__all__ = ["start_simulation", "
+__all__ = ["start_simulation", "run_simulation"]
flwr/simulation/app.py
CHANGED

@@ -15,6 +15,8 @@
 """Flower simulation app."""


+import asyncio
+import logging
 import sys
 import threading
 import traceback

@@ -27,7 +29,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

 from flwr.client import ClientFn
 from flwr.common import EventType, event
-from flwr.common.logger import log
+from flwr.common.logger import log, set_logger_propagation
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History
 from flwr.server.server import Server, init_defaults, run_fl

@@ -156,6 +158,7 @@ def start_simulation(
     is an advanced feature. For all details, please refer to the Ray documentation:
     https://docs.ray.io/en/latest/ray-core/scheduling/index.html

+
     Returns
     -------
     hist : flwr.server.history.History

@@ -167,6 +170,18 @@ def start_simulation(
         {"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
     )

+    # Set logger propagation
+    loop: Optional[asyncio.AbstractEventLoop] = None
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = None
+    finally:
+        if loop and loop.is_running():
+            # Set logger propagation to False to prevent duplicated log output in Colab.
+            logger = logging.getLogger("flwr")
+            _ = set_logger_propagation(logger, False)
+
     # Initialize server and server config
     initialized_server, initialized_config = init_defaults(
         server=server,

flwr/simulation/run_simulation.py
CHANGED

@@ -28,8 +28,9 @@ import grpc

 from flwr.client import ClientApp
 from flwr.common import EventType, event, log
+from flwr.common.logger import set_logger_propagation, update_console_handler
 from flwr.common.typing import ConfigsRecordValues
-from flwr.server.driver
+from flwr.server.driver import Driver, GrpcDriver
 from flwr.server.run_serverapp import run
 from flwr.server.server_app import ServerApp
 from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc

@@ -154,7 +155,7 @@ def run_serverapp_th(
         # Upon completion, trigger stop event if one was passed
         if stop_event is not None:
             stop_event.set()
-            log(
+            log(DEBUG, "Triggered stop event for Simulation Engine.")

     serverapp_th = threading.Thread(
         target=server_th_with_start_checks,

@@ -204,7 +205,7 @@ def _main_loop(
     serverapp_th = None
     try:
         # Initialize Driver
-        driver =
+        driver = GrpcDriver(
            driver_service_address=driver_api_address,
            root_certificates=None,
        )

@@ -248,7 +249,7 @@ def _main_loop(
        if serverapp_th:
            serverapp_th.join()

-    log(
+    log(DEBUG, "Stopping Simulation Engine now.")


 # pylint: disable=too-many-arguments,too-many-locals

@@ -317,9 +318,9 @@ def _run_simulation(
         enabled, DEBUG-level logs will be displayed.
     """
     # Set logging level
-
-
-
+    logger = logging.getLogger("flwr")
+    if verbose_logging:
+        update_console_handler(level=DEBUG, timestamps=True, colored=True)

     if backend_config is None:
         backend_config = {}

@@ -364,6 +365,8 @@ def _run_simulation(

     finally:
         if run_in_thread:
+            # Set logger propagation to False to prevent duplicated log output in Colab.
+            logger = set_logger_propagation(logger, False)
             log(DEBUG, "Starting Simulation Engine on a new thread.")
             simulation_engine_th = threading.Thread(target=_main_loop, args=args)
             simulation_engine_th.start()

{flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flwr-nightly
-Version: 1.9.0.dev20240420
+Version: 1.9.0.dev20240507
 Summary: Flower: A Friendly Federated Learning Framework
 Home-page: https://flower.ai
 License: Apache-2.0

@@ -36,6 +36,7 @@ Requires-Dist: cryptography (>=42.0.4,<43.0.0)
 Requires-Dist: grpcio (>=1.60.0,<2.0.0)
 Requires-Dist: iterators (>=0.0.2,<0.0.3)
 Requires-Dist: numpy (>=1.21.0,<2.0.0)
+Requires-Dist: pathspec (>=0.12.1,<0.13.0)
 Requires-Dist: protobuf (>=4.25.2,<5.0.0)
 Requires-Dist: pycryptodome (>=3.18.0,<4.0.0)
 Requires-Dist: ray (==2.6.3) ; (python_version >= "3.8" and python_version < "3.12") and (extra == "simulation")

@@ -193,7 +194,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
 - [PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated)
 - [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl)
 - [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning)
-- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/
+- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/llm-flowertune)
 - [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/vit-finetune)
 - [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow)
 - [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)