flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; consult the registry's release page for more details.

Files changed (66):
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +19 -14
  4. flwr/cli/new/new.py +51 -22
  5. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  6. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  7. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  8. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  9. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
  10. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  11. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  13. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
  14. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  15. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  16. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  17. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  18. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  19. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  20. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  21. flwr/cli/run/run.py +1 -1
  22. flwr/cli/utils.py +18 -17
  23. flwr/client/__init__.py +3 -1
  24. flwr/client/app.py +20 -142
  25. flwr/client/grpc_client/connection.py +8 -2
  26. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  27. flwr/client/grpc_rere_client/connection.py +33 -4
  28. flwr/client/mod/centraldp_mods.py +4 -2
  29. flwr/client/mod/localdp_mod.py +9 -3
  30. flwr/client/rest_client/connection.py +92 -169
  31. flwr/client/supernode/__init__.py +24 -0
  32. flwr/client/supernode/app.py +281 -0
  33. flwr/common/grpc.py +5 -1
  34. flwr/common/logger.py +37 -4
  35. flwr/common/message.py +105 -86
  36. flwr/common/record/parametersrecord.py +0 -1
  37. flwr/common/record/recordset.py +78 -27
  38. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  39. flwr/common/telemetry.py +4 -0
  40. flwr/server/app.py +116 -6
  41. flwr/server/compat/app.py +2 -2
  42. flwr/server/compat/app_utils.py +1 -1
  43. flwr/server/compat/driver_client_proxy.py +27 -70
  44. flwr/server/driver/__init__.py +2 -1
  45. flwr/server/driver/driver.py +12 -139
  46. flwr/server/driver/grpc_driver.py +199 -13
  47. flwr/server/run_serverapp.py +18 -4
  48. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  49. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  50. flwr/server/superlink/driver/driver_servicer.py +1 -1
  51. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  52. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  53. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
  54. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  55. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  56. flwr/server/superlink/state/in_memory_state.py +89 -12
  57. flwr/server/superlink/state/sqlite_state.py +133 -16
  58. flwr/server/superlink/state/state.py +56 -6
  59. flwr/simulation/__init__.py +2 -2
  60. flwr/simulation/app.py +16 -1
  61. flwr/simulation/run_simulation.py +10 -7
  62. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
  63. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
  64. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
  65. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
  66. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
@@ -293,7 +293,7 @@ def start_vce(
293
293
  node_states[node_id] = NodeState()
294
294
 
295
295
  # Load backend config
296
- log(INFO, "Supported backends: %s", list(supported_backends.keys()))
296
+ log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
297
297
  backend_config = json.loads(backend_config_json_stream)
298
298
 
299
299
  try:
@@ -30,15 +30,24 @@ from flwr.server.utils import validate_task_ins_or_res
30
30
  from .utils import make_node_unavailable_taskres
31
31
 
32
32
 
33
- class InMemoryState(State):
33
+ class InMemoryState(State): # pylint: disable=R0902,R0904
34
34
  """In-memory State implementation."""
35
35
 
36
36
  def __init__(self) -> None:
37
+
37
38
  # Map node_id to (online_until, ping_interval)
38
39
  self.node_ids: Dict[int, Tuple[float, float]] = {}
39
- self.run_ids: Set[int] = set()
40
+ self.public_key_to_node_id: Dict[bytes, int] = {}
41
+
42
+ # Map run_id to (fab_id, fab_version)
43
+ self.run_ids: Dict[int, Tuple[str, str]] = {}
40
44
  self.task_ins_store: Dict[UUID, TaskIns] = {}
41
45
  self.task_res_store: Dict[UUID, TaskRes] = {}
46
+
47
+ self.client_public_keys: Set[bytes] = set()
48
+ self.server_public_key: Optional[bytes] = None
49
+ self.server_private_key: Optional[bytes] = None
50
+
42
51
  self.lock = threading.Lock()
43
52
 
44
53
  def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
@@ -201,23 +210,46 @@ class InMemoryState(State):
201
210
  """
202
211
  return len(self.task_res_store)
203
212
 
204
- def create_node(self, ping_interval: float) -> int:
213
+ def create_node(
214
+ self, ping_interval: float, public_key: Optional[bytes] = None
215
+ ) -> int:
205
216
  """Create, store in state, and return `node_id`."""
206
217
  # Sample a random int64 as node_id
207
218
  node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
208
219
 
209
220
  with self.lock:
210
- if node_id not in self.node_ids:
211
- self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
212
- return node_id
213
- log(ERROR, "Unexpected node registration failure.")
214
- return 0
221
+ if node_id in self.node_ids:
222
+ log(ERROR, "Unexpected node registration failure.")
223
+ return 0
224
+
225
+ if public_key is not None:
226
+ if (
227
+ public_key in self.public_key_to_node_id
228
+ or node_id in self.public_key_to_node_id.values()
229
+ ):
230
+ log(ERROR, "Unexpected node registration failure.")
231
+ return 0
215
232
 
216
- def delete_node(self, node_id: int) -> None:
233
+ self.public_key_to_node_id[public_key] = node_id
234
+
235
+ self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
236
+ return node_id
237
+
238
+ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
217
239
  """Delete a client node."""
218
240
  with self.lock:
219
241
  if node_id not in self.node_ids:
220
242
  raise ValueError(f"Node {node_id} not found")
243
+
244
+ if public_key is not None:
245
+ if (
246
+ public_key not in self.public_key_to_node_id
247
+ or node_id not in self.public_key_to_node_id.values()
248
+ ):
249
+ raise ValueError("Public key or node_id not found")
250
+
251
+ del self.public_key_to_node_id[public_key]
252
+
221
253
  del self.node_ids[node_id]
222
254
 
223
255
  def get_nodes(self, run_id: int) -> Set[int]:
@@ -238,18 +270,63 @@ class InMemoryState(State):
238
270
  if online_until > current_time
239
271
  }
240
272
 
241
- def create_run(self) -> int:
242
- """Create one run."""
273
+ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
274
+ """Retrieve stored `node_id` filtered by `client_public_keys`."""
275
+ return self.public_key_to_node_id.get(client_public_key)
276
+
277
+ def create_run(self, fab_id: str, fab_version: str) -> int:
278
+ """Create a new run for the specified `fab_id` and `fab_version`."""
243
279
  # Sample a random int64 as run_id
244
280
  with self.lock:
245
281
  run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
246
282
 
247
283
  if run_id not in self.run_ids:
248
- self.run_ids.add(run_id)
284
+ self.run_ids[run_id] = (fab_id, fab_version)
249
285
  return run_id
250
286
  log(ERROR, "Unexpected run creation failure.")
251
287
  return 0
252
288
 
289
+ def store_server_private_public_key(
290
+ self, private_key: bytes, public_key: bytes
291
+ ) -> None:
292
+ """Store `server_private_key` and `server_public_key` in state."""
293
+ with self.lock:
294
+ if self.server_private_key is None and self.server_public_key is None:
295
+ self.server_private_key = private_key
296
+ self.server_public_key = public_key
297
+ else:
298
+ raise RuntimeError("Server private and public key already set")
299
+
300
+ def get_server_private_key(self) -> Optional[bytes]:
301
+ """Retrieve `server_private_key` in urlsafe bytes."""
302
+ return self.server_private_key
303
+
304
+ def get_server_public_key(self) -> Optional[bytes]:
305
+ """Retrieve `server_public_key` in urlsafe bytes."""
306
+ return self.server_public_key
307
+
308
+ def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
309
+ """Store a set of `client_public_keys` in state."""
310
+ with self.lock:
311
+ self.client_public_keys = public_keys
312
+
313
+ def store_client_public_key(self, public_key: bytes) -> None:
314
+ """Store a `client_public_key` in state."""
315
+ with self.lock:
316
+ self.client_public_keys.add(public_key)
317
+
318
+ def get_client_public_keys(self) -> Set[bytes]:
319
+ """Retrieve all currently stored `client_public_keys` as a set."""
320
+ return self.client_public_keys
321
+
322
+ def get_run(self, run_id: int) -> Tuple[int, str, str]:
323
+ """Retrieve information about the run with the specified `run_id`."""
324
+ with self.lock:
325
+ if run_id not in self.run_ids:
326
+ log(ERROR, "`run_id` is invalid")
327
+ return 0, "", ""
328
+ return run_id, *self.run_ids[run_id]
329
+
253
330
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
254
331
  """Acknowledge a ping received from a node, serving as a heartbeat."""
255
332
  with self.lock:
@@ -20,7 +20,7 @@ import re
20
20
  import sqlite3
21
21
  import time
22
22
  from logging import DEBUG, ERROR
23
- from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
23
+ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
24
24
  from uuid import UUID, uuid4
25
25
 
26
26
  from flwr.common import log, now
@@ -36,7 +36,21 @@ SQL_CREATE_TABLE_NODE = """
36
36
  CREATE TABLE IF NOT EXISTS node(
37
37
  node_id INTEGER UNIQUE,
38
38
  online_until REAL,
39
- ping_interval REAL
39
+ ping_interval REAL,
40
+ public_key BLOB
41
+ );
42
+ """
43
+
44
+ SQL_CREATE_TABLE_CREDENTIAL = """
45
+ CREATE TABLE IF NOT EXISTS credential(
46
+ private_key BLOB PRIMARY KEY,
47
+ public_key BLOB
48
+ );
49
+ """
50
+
51
+ SQL_CREATE_TABLE_PUBLIC_KEY = """
52
+ CREATE TABLE IF NOT EXISTS public_key(
53
+ public_key BLOB UNIQUE
40
54
  );
41
55
  """
42
56
 
@@ -46,7 +60,9 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
46
60
 
47
61
  SQL_CREATE_TABLE_RUN = """
48
62
  CREATE TABLE IF NOT EXISTS run(
49
- run_id INTEGER UNIQUE
63
+ run_id INTEGER UNIQUE,
64
+ fab_id TEXT,
65
+ fab_version TEXT
50
66
  );
51
67
  """
52
68
 
@@ -70,7 +86,6 @@ CREATE TABLE IF NOT EXISTS task_ins(
70
86
  );
71
87
  """
72
88
 
73
-
74
89
  SQL_CREATE_TABLE_TASK_RES = """
75
90
  CREATE TABLE IF NOT EXISTS task_res(
76
91
  task_id TEXT UNIQUE,
@@ -94,7 +109,7 @@ CREATE TABLE IF NOT EXISTS task_res(
94
109
  DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
95
110
 
96
111
 
97
- class SqliteState(State):
112
+ class SqliteState(State): # pylint: disable=R0904
98
113
  """SQLite-based state implementation."""
99
114
 
100
115
  def __init__(
@@ -132,6 +147,8 @@ class SqliteState(State):
132
147
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
133
148
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
134
149
  cur.execute(SQL_CREATE_TABLE_NODE)
150
+ cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
151
+ cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
135
152
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
136
153
  res = cur.execute("SELECT name FROM sqlite_schema;")
137
154
 
@@ -140,7 +157,7 @@ class SqliteState(State):
140
157
  def query(
141
158
  self,
142
159
  query: str,
143
- data: Optional[Union[List[DictOrTuple], DictOrTuple]] = None,
160
+ data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
144
161
  ) -> List[Dict[str, Any]]:
145
162
  """Execute a SQL query."""
146
163
  if self.conn is None:
@@ -518,26 +535,54 @@ class SqliteState(State):
518
535
 
519
536
  return None
520
537
 
521
- def create_node(self, ping_interval: float) -> int:
538
+ def create_node(
539
+ self, ping_interval: float, public_key: Optional[bytes] = None
540
+ ) -> int:
522
541
  """Create, store in state, and return `node_id`."""
523
542
  # Sample a random int64 as node_id
524
543
  node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
525
544
 
545
+ query = "SELECT node_id FROM node WHERE public_key = :public_key;"
546
+ row = self.query(query, {"public_key": public_key})
547
+
548
+ if len(row) > 0:
549
+ log(ERROR, "Unexpected node registration failure.")
550
+ return 0
551
+
526
552
  query = (
527
- "INSERT INTO node (node_id, online_until, ping_interval) VALUES (?, ?, ?)"
553
+ "INSERT INTO node "
554
+ "(node_id, online_until, ping_interval, public_key) "
555
+ "VALUES (?, ?, ?, ?)"
528
556
  )
529
557
 
530
558
  try:
531
- self.query(query, (node_id, time.time() + ping_interval, ping_interval))
559
+ self.query(
560
+ query, (node_id, time.time() + ping_interval, ping_interval, public_key)
561
+ )
532
562
  except sqlite3.IntegrityError:
533
563
  log(ERROR, "Unexpected node registration failure.")
534
564
  return 0
535
565
  return node_id
536
566
 
537
- def delete_node(self, node_id: int) -> None:
567
+ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
538
568
  """Delete a client node."""
539
- query = "DELETE FROM node WHERE node_id = :node_id;"
540
- self.query(query, {"node_id": node_id})
569
+ query = "DELETE FROM node WHERE node_id = ?"
570
+ params = (node_id,)
571
+
572
+ if public_key is not None:
573
+ query += " AND public_key = ?"
574
+ params += (public_key,) # type: ignore
575
+
576
+ if self.conn is None:
577
+ raise AttributeError("State is not initialized.")
578
+
579
+ try:
580
+ with self.conn:
581
+ rows = self.conn.execute(query, params)
582
+ if rows.rowcount < 1:
583
+ raise ValueError("Public key or node_id not found")
584
+ except KeyError as exc:
585
+ log(ERROR, {"query": query, "data": params, "exception": exc})
541
586
 
542
587
  def get_nodes(self, run_id: int) -> Set[int]:
543
588
  """Retrieve all currently stored node IDs as a set.
@@ -558,8 +603,17 @@ class SqliteState(State):
558
603
  result: Set[int] = {row["node_id"] for row in rows}
559
604
  return result
560
605
 
561
- def create_run(self) -> int:
562
- """Create one run and store it in state."""
606
+ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
607
+ """Retrieve stored `node_id` filtered by `client_public_keys`."""
608
+ query = "SELECT node_id FROM node WHERE public_key = :public_key;"
609
+ row = self.query(query, {"public_key": client_public_key})
610
+ if len(row) > 0:
611
+ node_id: int = row[0]["node_id"]
612
+ return node_id
613
+ return None
614
+
615
+ def create_run(self, fab_id: str, fab_version: str) -> int:
616
+ """Create a new run for the specified `fab_id` and `fab_version`."""
563
617
  # Sample a random int64 as run_id
564
618
  run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
565
619
 
@@ -567,12 +621,75 @@ class SqliteState(State):
567
621
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
568
622
  # If run_id does not exist
569
623
  if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
570
- query = "INSERT INTO run VALUES(:run_id);"
571
- self.query(query, {"run_id": run_id})
624
+ query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
625
+ self.query(query, (run_id, fab_id, fab_version))
572
626
  return run_id
573
627
  log(ERROR, "Unexpected run creation failure.")
574
628
  return 0
575
629
 
630
+ def store_server_private_public_key(
631
+ self, private_key: bytes, public_key: bytes
632
+ ) -> None:
633
+ """Store `server_private_key` and `server_public_key` in state."""
634
+ query = "SELECT COUNT(*) FROM credential"
635
+ count = self.query(query)[0]["COUNT(*)"]
636
+ if count < 1:
637
+ query = (
638
+ "INSERT OR REPLACE INTO credential (private_key, public_key) "
639
+ "VALUES (:private_key, :public_key)"
640
+ )
641
+ self.query(query, {"private_key": private_key, "public_key": public_key})
642
+ else:
643
+ raise RuntimeError("Server private and public key already set")
644
+
645
+ def get_server_private_key(self) -> Optional[bytes]:
646
+ """Retrieve `server_private_key` in urlsafe bytes."""
647
+ query = "SELECT private_key FROM credential"
648
+ rows = self.query(query)
649
+ try:
650
+ private_key: Optional[bytes] = rows[0]["private_key"]
651
+ except IndexError:
652
+ private_key = None
653
+ return private_key
654
+
655
+ def get_server_public_key(self) -> Optional[bytes]:
656
+ """Retrieve `server_public_key` in urlsafe bytes."""
657
+ query = "SELECT public_key FROM credential"
658
+ rows = self.query(query)
659
+ try:
660
+ public_key: Optional[bytes] = rows[0]["public_key"]
661
+ except IndexError:
662
+ public_key = None
663
+ return public_key
664
+
665
+ def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
666
+ """Store a set of `client_public_keys` in state."""
667
+ query = "INSERT INTO public_key (public_key) VALUES (?)"
668
+ data = [(key,) for key in public_keys]
669
+ self.query(query, data)
670
+
671
+ def store_client_public_key(self, public_key: bytes) -> None:
672
+ """Store a `client_public_key` in state."""
673
+ query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
674
+ self.query(query, {"public_key": public_key})
675
+
676
+ def get_client_public_keys(self) -> Set[bytes]:
677
+ """Retrieve all currently stored `client_public_keys` as a set."""
678
+ query = "SELECT public_key FROM public_key"
679
+ rows = self.query(query)
680
+ result: Set[bytes] = {row["public_key"] for row in rows}
681
+ return result
682
+
683
+ def get_run(self, run_id: int) -> Tuple[int, str, str]:
684
+ """Retrieve information about the run with the specified `run_id`."""
685
+ query = "SELECT * FROM run WHERE run_id = ?;"
686
+ try:
687
+ row = self.query(query, (run_id,))[0]
688
+ return run_id, row["fab_id"], row["fab_version"]
689
+ except sqlite3.IntegrityError:
690
+ log(ERROR, "`run_id` does not exist.")
691
+ return 0, "", ""
692
+
576
693
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
577
694
  """Acknowledge a ping received from a node, serving as a heartbeat."""
578
695
  # Update `online_until` and `ping_interval` for the given `node_id`
@@ -16,13 +16,13 @@
16
16
 
17
17
 
18
18
  import abc
19
- from typing import List, Optional, Set
19
+ from typing import List, Optional, Set, Tuple
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
23
23
 
24
24
 
25
- class State(abc.ABC):
25
+ class State(abc.ABC): # pylint: disable=R0904
26
26
  """Abstract State."""
27
27
 
28
28
  @abc.abstractmethod
@@ -132,11 +132,13 @@ class State(abc.ABC):
132
132
  """Delete all delivered TaskIns/TaskRes pairs."""
133
133
 
134
134
  @abc.abstractmethod
135
- def create_node(self, ping_interval: float) -> int:
135
+ def create_node(
136
+ self, ping_interval: float, public_key: Optional[bytes] = None
137
+ ) -> int:
136
138
  """Create, store in state, and return `node_id`."""
137
139
 
138
140
  @abc.abstractmethod
139
- def delete_node(self, node_id: int) -> None:
141
+ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
140
142
  """Remove `node_id` from state."""
141
143
 
142
144
  @abc.abstractmethod
@@ -150,8 +152,56 @@ class State(abc.ABC):
150
152
  """
151
153
 
152
154
  @abc.abstractmethod
153
- def create_run(self) -> int:
154
- """Create one run."""
155
+ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
156
+ """Retrieve stored `node_id` filtered by `client_public_keys`."""
157
+
158
+ @abc.abstractmethod
159
+ def create_run(self, fab_id: str, fab_version: str) -> int:
160
+ """Create a new run for the specified `fab_id` and `fab_version`."""
161
+
162
+ @abc.abstractmethod
163
+ def get_run(self, run_id: int) -> Tuple[int, str, str]:
164
+ """Retrieve information about the run with the specified `run_id`.
165
+
166
+ Parameters
167
+ ----------
168
+ run_id : int
169
+ The identifier of the run.
170
+
171
+ Returns
172
+ -------
173
+ Tuple[int, str, str]
174
+ A tuple containing three elements:
175
+ - `run_id`: The identifier of the run, same as the specified `run_id`.
176
+ - `fab_id`: The identifier of the FAB used in the specified run.
177
+ - `fab_version`: The version of the FAB used in the specified run.
178
+ """
179
+
180
+ @abc.abstractmethod
181
+ def store_server_private_public_key(
182
+ self, private_key: bytes, public_key: bytes
183
+ ) -> None:
184
+ """Store `server_private_key` and `server_public_key` in state."""
185
+
186
+ @abc.abstractmethod
187
+ def get_server_private_key(self) -> Optional[bytes]:
188
+ """Retrieve `server_private_key` in urlsafe bytes."""
189
+
190
+ @abc.abstractmethod
191
+ def get_server_public_key(self) -> Optional[bytes]:
192
+ """Retrieve `server_public_key` in urlsafe bytes."""
193
+
194
+ @abc.abstractmethod
195
+ def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
196
+ """Store a set of `client_public_keys` in state."""
197
+
198
+ @abc.abstractmethod
199
+ def store_client_public_key(self, public_key: bytes) -> None:
200
+ """Store a `client_public_key` in state."""
201
+
202
+ @abc.abstractmethod
203
+ def get_client_public_keys(self) -> Set[bytes]:
204
+ """Retrieve all currently stored `client_public_keys` as a set."""
155
205
 
156
206
  @abc.abstractmethod
157
207
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
@@ -17,7 +17,7 @@
17
17
 
18
18
  import importlib
19
19
 
20
- from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli
20
+ from flwr.simulation.run_simulation import run_simulation
21
21
 
22
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
23
23
 
@@ -36,4 +36,4 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
36
  raise ImportError(RAY_IMPORT_ERROR)
37
37
 
38
38
 
39
- __all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"]
39
+ __all__ = ["start_simulation", "run_simulation"]
flwr/simulation/app.py CHANGED
@@ -15,6 +15,8 @@
15
15
  """Flower simulation app."""
16
16
 
17
17
 
18
+ import asyncio
19
+ import logging
18
20
  import sys
19
21
  import threading
20
22
  import traceback
@@ -27,7 +29,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
27
29
 
28
30
  from flwr.client import ClientFn
29
31
  from flwr.common import EventType, event
30
- from flwr.common.logger import log
32
+ from flwr.common.logger import log, set_logger_propagation
31
33
  from flwr.server.client_manager import ClientManager
32
34
  from flwr.server.history import History
33
35
  from flwr.server.server import Server, init_defaults, run_fl
@@ -156,6 +158,7 @@ def start_simulation(
156
158
  is an advanced feature. For all details, please refer to the Ray documentation:
157
159
  https://docs.ray.io/en/latest/ray-core/scheduling/index.html
158
160
 
161
+
159
162
  Returns
160
163
  -------
161
164
  hist : flwr.server.history.History
@@ -167,6 +170,18 @@ def start_simulation(
167
170
  {"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
168
171
  )
169
172
 
173
+ # Set logger propagation
174
+ loop: Optional[asyncio.AbstractEventLoop] = None
175
+ try:
176
+ loop = asyncio.get_running_loop()
177
+ except RuntimeError:
178
+ loop = None
179
+ finally:
180
+ if loop and loop.is_running():
181
+ # Set logger propagation to False to prevent duplicated log output in Colab.
182
+ logger = logging.getLogger("flwr")
183
+ _ = set_logger_propagation(logger, False)
184
+
170
185
  # Initialize server and server config
171
186
  initialized_server, initialized_config = init_defaults(
172
187
  server=server,
@@ -28,8 +28,9 @@ import grpc
28
28
 
29
29
  from flwr.client import ClientApp
30
30
  from flwr.common import EventType, event, log
31
+ from flwr.common.logger import set_logger_propagation, update_console_handler
31
32
  from flwr.common.typing import ConfigsRecordValues
32
- from flwr.server.driver.driver import Driver
33
+ from flwr.server.driver import Driver, GrpcDriver
33
34
  from flwr.server.run_serverapp import run
34
35
  from flwr.server.server_app import ServerApp
35
36
  from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
@@ -154,7 +155,7 @@ def run_serverapp_th(
154
155
  # Upon completion, trigger stop event if one was passed
155
156
  if stop_event is not None:
156
157
  stop_event.set()
157
- log(WARNING, "Triggered stop event for Simulation Engine.")
158
+ log(DEBUG, "Triggered stop event for Simulation Engine.")
158
159
 
159
160
  serverapp_th = threading.Thread(
160
161
  target=server_th_with_start_checks,
@@ -204,7 +205,7 @@ def _main_loop(
204
205
  serverapp_th = None
205
206
  try:
206
207
  # Initialize Driver
207
- driver = Driver(
208
+ driver = GrpcDriver(
208
209
  driver_service_address=driver_api_address,
209
210
  root_certificates=None,
210
211
  )
@@ -248,7 +249,7 @@ def _main_loop(
248
249
  if serverapp_th:
249
250
  serverapp_th.join()
250
251
 
251
- log(INFO, "Stopping Simulation Engine now.")
252
+ log(DEBUG, "Stopping Simulation Engine now.")
252
253
 
253
254
 
254
255
  # pylint: disable=too-many-arguments,too-many-locals
@@ -317,9 +318,9 @@ def _run_simulation(
317
318
  enabled, DEBUG-level logs will be displayed.
318
319
  """
319
320
  # Set logging level
320
- if not verbose_logging:
321
- logger = logging.getLogger("flwr")
322
- logger.setLevel(INFO)
321
+ logger = logging.getLogger("flwr")
322
+ if verbose_logging:
323
+ update_console_handler(level=DEBUG, timestamps=True, colored=True)
323
324
 
324
325
  if backend_config is None:
325
326
  backend_config = {}
@@ -364,6 +365,8 @@ def _run_simulation(
364
365
 
365
366
  finally:
366
367
  if run_in_thread:
368
+ # Set logger propagation to False to prevent duplicated log output in Colab.
369
+ logger = set_logger_propagation(logger, False)
367
370
  log(DEBUG, "Starting Simulation Engine on a new thread.")
368
371
  simulation_engine_th = threading.Thread(target=_main_loop, args=args)
369
372
  simulation_engine_th.start()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.9.0.dev20240417
3
+ Version: 1.9.0.dev20240507
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -36,6 +36,7 @@ Requires-Dist: cryptography (>=42.0.4,<43.0.0)
36
36
  Requires-Dist: grpcio (>=1.60.0,<2.0.0)
37
37
  Requires-Dist: iterators (>=0.0.2,<0.0.3)
38
38
  Requires-Dist: numpy (>=1.21.0,<2.0.0)
39
+ Requires-Dist: pathspec (>=0.12.1,<0.13.0)
39
40
  Requires-Dist: protobuf (>=4.25.2,<5.0.0)
40
41
  Requires-Dist: pycryptodome (>=3.18.0,<4.0.0)
41
42
  Requires-Dist: ray (==2.6.3) ; (python_version >= "3.8" and python_version < "3.12") and (extra == "simulation")
@@ -193,7 +194,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
193
194
  - [PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated)
194
195
  - [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl)
195
196
  - [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning)
196
- - [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/fedllm-finetune)
197
+ - [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/llm-flowertune)
197
198
  - [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/vit-finetune)
198
199
  - [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow)
199
200
  - [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)