flwr 1.14.0__py3-none-any.whl → 1.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. flwr/cli/auth_plugin/__init__.py +31 -0
  2. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  3. flwr/cli/cli_user_auth_interceptor.py +6 -2
  4. flwr/cli/config_utils.py +24 -147
  5. flwr/cli/constant.py +27 -0
  6. flwr/cli/install.py +1 -1
  7. flwr/cli/log.py +18 -3
  8. flwr/cli/login/login.py +43 -8
  9. flwr/cli/ls.py +14 -5
  10. flwr/cli/new/templates/app/README.md.tpl +3 -2
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  13. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  14. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  20. flwr/cli/run/run.py +21 -11
  21. flwr/cli/stop.py +13 -4
  22. flwr/cli/utils.py +54 -40
  23. flwr/client/app.py +36 -48
  24. flwr/client/clientapp/app.py +19 -25
  25. flwr/client/clientapp/utils.py +1 -1
  26. flwr/client/grpc_client/connection.py +1 -12
  27. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  28. flwr/client/grpc_rere_client/connection.py +46 -36
  29. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  30. flwr/client/message_handler/task_handler.py +0 -17
  31. flwr/client/rest_client/connection.py +34 -26
  32. flwr/client/supernode/app.py +18 -72
  33. flwr/common/args.py +25 -47
  34. flwr/common/auth_plugin/auth_plugin.py +34 -23
  35. flwr/common/config.py +166 -16
  36. flwr/common/constant.py +24 -9
  37. flwr/common/differential_privacy.py +2 -1
  38. flwr/common/exit/__init__.py +24 -0
  39. flwr/common/exit/exit.py +99 -0
  40. flwr/common/exit/exit_code.py +93 -0
  41. flwr/common/exit_handlers.py +32 -30
  42. flwr/common/grpc.py +167 -4
  43. flwr/common/logger.py +26 -7
  44. flwr/common/object_ref.py +0 -14
  45. flwr/common/record/recordset.py +1 -1
  46. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  47. flwr/common/serde.py +6 -4
  48. flwr/common/typing.py +20 -0
  49. flwr/proto/clientappio_pb2.py +1 -1
  50. flwr/proto/error_pb2.py +1 -1
  51. flwr/proto/exec_pb2.py +13 -25
  52. flwr/proto/exec_pb2.pyi +27 -54
  53. flwr/proto/fab_pb2.py +1 -1
  54. flwr/proto/fleet_pb2.py +31 -31
  55. flwr/proto/fleet_pb2.pyi +23 -23
  56. flwr/proto/fleet_pb2_grpc.py +30 -30
  57. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  58. flwr/proto/grpcadapter_pb2.py +1 -1
  59. flwr/proto/log_pb2.py +1 -1
  60. flwr/proto/message_pb2.py +1 -1
  61. flwr/proto/node_pb2.py +3 -3
  62. flwr/proto/node_pb2.pyi +1 -4
  63. flwr/proto/recordset_pb2.py +1 -1
  64. flwr/proto/run_pb2.py +1 -1
  65. flwr/proto/serverappio_pb2.py +24 -25
  66. flwr/proto/serverappio_pb2.pyi +26 -32
  67. flwr/proto/serverappio_pb2_grpc.py +28 -28
  68. flwr/proto/serverappio_pb2_grpc.pyi +16 -16
  69. flwr/proto/simulationio_pb2.py +1 -1
  70. flwr/proto/task_pb2.py +1 -1
  71. flwr/proto/transport_pb2.py +1 -1
  72. flwr/server/app.py +116 -128
  73. flwr/server/compat/app_utils.py +0 -1
  74. flwr/server/compat/driver_client_proxy.py +1 -2
  75. flwr/server/driver/grpc_driver.py +32 -27
  76. flwr/server/driver/inmemory_driver.py +2 -1
  77. flwr/server/serverapp/app.py +12 -10
  78. flwr/server/superlink/driver/serverappio_grpc.py +1 -1
  79. flwr/server/superlink/driver/serverappio_servicer.py +74 -48
  80. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  81. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  82. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
  83. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +110 -168
  84. flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
  85. flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
  86. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  87. flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
  88. flwr/server/superlink/linkstate/linkstate.py +17 -38
  89. flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
  90. flwr/server/superlink/linkstate/utils.py +18 -8
  91. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  92. flwr/server/utils/validator.py +9 -34
  93. flwr/simulation/app.py +4 -6
  94. flwr/simulation/legacy_app.py +4 -2
  95. flwr/simulation/run_simulation.py +1 -1
  96. flwr/simulation/simulationio_connection.py +2 -1
  97. flwr/superexec/exec_grpc.py +1 -1
  98. flwr/superexec/exec_servicer.py +23 -2
  99. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/METADATA +8 -8
  100. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/RECORD +103 -97
  101. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/LICENSE +0 -0
  102. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/WHEEL +0 -0
  103. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/entry_points.txt +0 -0
@@ -31,6 +31,7 @@ from flwr.common.constant import (
31
31
  MESSAGE_TTL_TOLERANCE,
32
32
  NODE_ID_NUM_BYTES,
33
33
  RUN_ID_NUM_BYTES,
34
+ SUPERLINK_NODE_ID,
34
35
  Status,
35
36
  )
36
37
  from flwr.common.record import ConfigsRecord
@@ -70,16 +71,9 @@ CREATE TABLE IF NOT EXISTS node(
70
71
  );
71
72
  """
72
73
 
73
- SQL_CREATE_TABLE_CREDENTIAL = """
74
- CREATE TABLE IF NOT EXISTS credential(
75
- private_key BLOB PRIMARY KEY,
76
- public_key BLOB
77
- );
78
- """
79
-
80
74
  SQL_CREATE_TABLE_PUBLIC_KEY = """
81
75
  CREATE TABLE IF NOT EXISTS public_key(
82
- public_key BLOB UNIQUE
76
+ public_key BLOB PRIMARY KEY
83
77
  );
84
78
  """
85
79
 
@@ -128,9 +122,7 @@ CREATE TABLE IF NOT EXISTS task_ins(
128
122
  task_id TEXT UNIQUE,
129
123
  group_id TEXT,
130
124
  run_id INTEGER,
131
- producer_anonymous BOOLEAN,
132
125
  producer_node_id INTEGER,
133
- consumer_anonymous BOOLEAN,
134
126
  consumer_node_id INTEGER,
135
127
  created_at REAL,
136
128
  delivered_at TEXT,
@@ -148,9 +140,7 @@ CREATE TABLE IF NOT EXISTS task_res(
148
140
  task_id TEXT UNIQUE,
149
141
  group_id TEXT,
150
142
  run_id INTEGER,
151
- producer_anonymous BOOLEAN,
152
143
  producer_node_id INTEGER,
153
- consumer_anonymous BOOLEAN,
154
144
  consumer_node_id INTEGER,
155
145
  created_at REAL,
156
146
  delivered_at TEXT,
@@ -211,7 +201,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
211
201
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
212
202
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
213
203
  cur.execute(SQL_CREATE_TABLE_NODE)
214
- cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
215
204
  cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
216
205
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
217
206
  res = cur.execute("SELECT name FROM sqlite_schema;")
@@ -263,11 +252,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
263
252
 
264
253
  Constraints
265
254
  -----------
266
- If `task_ins.task.consumer.anonymous` is `True`, then
267
- `task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
268
255
 
269
- If `task_ins.task.consumer.anonymous` is `False`, then
270
- `task_ins.task.consumer.node_id` MUST be set (not 0)
256
+ `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
271
257
  """
272
258
  # Validate task
273
259
  errors = validate_task_ins_or_res(task_ins)
@@ -292,7 +278,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
292
278
  log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
293
279
  return None
294
280
  # Validate source node ID
295
- if task_ins.task.producer.node_id != 0:
281
+ if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
296
282
  log(
297
283
  ERROR,
298
284
  "Invalid source node ID for TaskIns: %s",
@@ -301,14 +287,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
301
287
  return None
302
288
  # Validate destination node ID
303
289
  query = "SELECT node_id FROM node WHERE node_id = ?;"
304
- if not task_ins.task.consumer.anonymous:
305
- if not self.query(query, (data[0]["consumer_node_id"],)):
306
- log(
307
- ERROR,
308
- "Invalid destination node ID for TaskIns: %s",
309
- task_ins.task.consumer.node_id,
310
- )
311
- return None
290
+ if not self.query(query, (data[0]["consumer_node_id"],)):
291
+ log(
292
+ ERROR,
293
+ "Invalid destination node ID for TaskIns: %s",
294
+ task_ins.task.consumer.node_id,
295
+ )
296
+ return None
312
297
 
313
298
  columns = ", ".join([f":{key}" for key in data[0]])
314
299
  query = f"INSERT INTO task_ins VALUES({columns});"
@@ -319,25 +304,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
319
304
 
320
305
  return task_id
321
306
 
322
- def get_task_ins(
323
- self, node_id: Optional[int], limit: Optional[int]
324
- ) -> list[TaskIns]:
325
- """Get undelivered TaskIns for one node (either anonymous or with ID).
307
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
308
+ """Get undelivered TaskIns for one node.
326
309
 
327
310
  Usually, the Fleet API calls this for Nodes planning to work on one or more
328
311
  TaskIns.
329
312
 
330
313
  Constraints
331
314
  -----------
332
- If `node_id` is not `None`, retrieve all TaskIns where
315
+ Retrieve all TaskIns where
333
316
 
334
317
  1. the `task_ins.task.consumer.node_id` equals `node_id` AND
335
- 2. the `task_ins.task.consumer.anonymous` equals `False` AND
336
- 3. the `task_ins.task.delivered_at` equals `""`.
337
-
338
- If `node_id` is `None`, retrieve all TaskIns where the
339
- `task_ins.task.consumer.node_id` equals `0` and
340
- `task_ins.task.consumer.anonymous` is set to `True`.
318
+ 2. the `task_ins.task.delivered_at` equals `""`.
341
319
 
342
320
  `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
343
321
  the result.
@@ -348,38 +326,23 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
348
326
  if limit is not None and limit < 1:
349
327
  raise AssertionError("`limit` must be >= 1")
350
328
 
351
- if node_id == 0:
352
- msg = (
353
- "`node_id` must be >= 1"
354
- "\n\n For requesting anonymous tasks use `node_id` equal `None`"
355
- )
329
+ if node_id == SUPERLINK_NODE_ID:
330
+ msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
356
331
  raise AssertionError(msg)
357
332
 
358
333
  data: dict[str, Union[str, int]] = {}
359
334
 
360
- if node_id is None:
361
- # Retrieve all anonymous Tasks
362
- query = """
363
- SELECT task_id
364
- FROM task_ins
365
- WHERE consumer_anonymous == 1
366
- AND consumer_node_id == 0
367
- AND delivered_at = ""
368
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
369
- """
370
- else:
371
- # Convert the uint64 value to sint64 for SQLite
372
- data["node_id"] = convert_uint64_to_sint64(node_id)
335
+ # Convert the uint64 value to sint64 for SQLite
336
+ data["node_id"] = convert_uint64_to_sint64(node_id)
373
337
 
374
- # Retrieve all TaskIns for node_id
375
- query = """
376
- SELECT task_id
377
- FROM task_ins
378
- WHERE consumer_anonymous == 0
379
- AND consumer_node_id == :node_id
380
- AND delivered_at = ""
381
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
382
- """
338
+ # Retrieve all TaskIns for node_id
339
+ query = """
340
+ SELECT task_id
341
+ FROM task_ins
342
+ WHERE consumer_node_id == :node_id
343
+ AND delivered_at = ""
344
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
345
+ """
383
346
 
384
347
  if limit is not None:
385
348
  query += " LIMIT :limit"
@@ -429,11 +392,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
429
392
 
430
393
  Constraints
431
394
  -----------
432
- If `task_res.task.consumer.anonymous` is `True`, then
433
- `task_res.task.consumer.node_id` MUST NOT be set (equal 0).
434
-
435
- If `task_res.task.consumer.anonymous` is `False`, then
436
- `task_res.task.consumer.node_id` MUST be set (not 0)
395
+ `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
437
396
  """
438
397
  # Validate task
439
398
  errors = validate_task_ins_or_res(task_res)
@@ -459,7 +418,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
459
418
  if (
460
419
  task_ins
461
420
  and task_res
462
- and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
463
421
  and convert_sint64_to_uint64(task_ins["consumer_node_id"])
464
422
  != task_res.task.producer.node_id
465
423
  ):
@@ -635,23 +593,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
635
593
 
636
594
  return {UUID(row["task_id"]) for row in rows}
637
595
 
638
- def create_node(
639
- self, ping_interval: float, public_key: Optional[bytes] = None
640
- ) -> int:
596
+ def create_node(self, ping_interval: float) -> int:
641
597
  """Create, store in the link state, and return `node_id`."""
642
598
  # Sample a random uint64 as node_id
643
- uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
599
+ uint64_node_id = generate_rand_int_from_bytes(
600
+ NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
601
+ )
644
602
 
645
603
  # Convert the uint64 value to sint64 for SQLite
646
604
  sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
647
605
 
648
- query = "SELECT node_id FROM node WHERE public_key = :public_key;"
649
- row = self.query(query, {"public_key": public_key})
650
-
651
- if len(row) > 0:
652
- log(ERROR, "Unexpected node registration failure.")
653
- return 0
654
-
655
606
  query = (
656
607
  "INSERT INTO node "
657
608
  "(node_id, online_until, ping_interval, public_key) "
@@ -665,7 +616,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
665
616
  sint64_node_id,
666
617
  time.time() + ping_interval,
667
618
  ping_interval,
668
- public_key,
619
+ b"", # Initialize with an empty public key
669
620
  ),
670
621
  )
671
622
  except sqlite3.IntegrityError:
@@ -675,7 +626,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
675
626
  # Note: we need to return the uint64 value of the node_id
676
627
  return uint64_node_id
677
628
 
678
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
629
+ def delete_node(self, node_id: int) -> None:
679
630
  """Delete a node."""
680
631
  # Convert the uint64 value to sint64 for SQLite
681
632
  sint64_node_id = convert_uint64_to_sint64(node_id)
@@ -683,10 +634,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
683
634
  query = "DELETE FROM node WHERE node_id = ?"
684
635
  params = (sint64_node_id,)
685
636
 
686
- if public_key is not None:
687
- query += " AND public_key = ?"
688
- params += (public_key,) # type: ignore
689
-
690
637
  if self.conn is None:
691
638
  raise AttributeError("LinkState is not initialized.")
692
639
 
@@ -694,7 +641,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
694
641
  with self.conn:
695
642
  rows = self.conn.execute(query, params)
696
643
  if rows.rowcount < 1:
697
- raise ValueError("Public key or node_id not found")
644
+ raise ValueError(f"Node {node_id} not found")
698
645
  except KeyError as exc:
699
646
  log(ERROR, {"query": query, "data": params, "exception": exc})
700
647
 
@@ -722,6 +669,41 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
722
669
  result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
723
670
  return result
724
671
 
672
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
673
+ """Set `public_key` for the specified `node_id`."""
674
+ # Convert the uint64 value to sint64 for SQLite
675
+ sint64_node_id = convert_uint64_to_sint64(node_id)
676
+
677
+ # Check if the node exists in the `node` table
678
+ query = "SELECT 1 FROM node WHERE node_id = ?"
679
+ if not self.query(query, (sint64_node_id,)):
680
+ raise ValueError(f"Node {node_id} not found")
681
+
682
+ # Check if the public key is already in use in the `node` table
683
+ query = "SELECT 1 FROM node WHERE public_key = ?"
684
+ if self.query(query, (public_key,)):
685
+ raise ValueError("Public key already in use")
686
+
687
+ # Update the `node` table to set the public key for the given node ID
688
+ query = "UPDATE node SET public_key = ? WHERE node_id = ?"
689
+ self.query(query, (public_key, sint64_node_id))
690
+
691
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
692
+ """Get `public_key` for the specified `node_id`."""
693
+ # Convert the uint64 value to sint64 for SQLite
694
+ sint64_node_id = convert_uint64_to_sint64(node_id)
695
+
696
+ # Query the public key for the given node_id
697
+ query = "SELECT public_key FROM node WHERE node_id = ?"
698
+ rows = self.query(query, (sint64_node_id,))
699
+
700
+ # If no result is found, return None
701
+ if not rows:
702
+ raise ValueError(f"Node {node_id} not found")
703
+
704
+ # Return the public key if it is not empty, otherwise return None
705
+ return rows[0]["public_key"] or None
706
+
725
707
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
726
708
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
727
709
  query = "SELECT node_id FROM node WHERE public_key = :public_key;"
@@ -783,46 +765,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
783
765
  log(ERROR, "Unexpected run creation failure.")
784
766
  return 0
785
767
 
786
- def store_server_private_public_key(
787
- self, private_key: bytes, public_key: bytes
788
- ) -> None:
789
- """Store `server_private_key` and `server_public_key` in the link state."""
790
- query = "SELECT COUNT(*) FROM credential"
791
- count = self.query(query)[0]["COUNT(*)"]
792
- if count < 1:
793
- query = (
794
- "INSERT OR REPLACE INTO credential (private_key, public_key) "
795
- "VALUES (:private_key, :public_key)"
796
- )
797
- self.query(query, {"private_key": private_key, "public_key": public_key})
798
- else:
799
- raise RuntimeError("Server private and public key already set")
800
-
801
- def get_server_private_key(self) -> Optional[bytes]:
802
- """Retrieve `server_private_key` in urlsafe bytes."""
803
- query = "SELECT private_key FROM credential"
804
- rows = self.query(query)
805
- try:
806
- private_key: Optional[bytes] = rows[0]["private_key"]
807
- except IndexError:
808
- private_key = None
809
- return private_key
810
-
811
- def get_server_public_key(self) -> Optional[bytes]:
812
- """Retrieve `server_public_key` in urlsafe bytes."""
813
- query = "SELECT public_key FROM credential"
814
- rows = self.query(query)
815
- try:
816
- public_key: Optional[bytes] = rows[0]["public_key"]
817
- except IndexError:
818
- public_key = None
819
- return public_key
820
-
821
- def clear_supernode_auth_keys_and_credentials(self) -> None:
822
- """Clear stored `node_public_keys` and credentials in the link state if any."""
823
- queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
824
- for query in queries:
825
- self.query(query)
768
+ def clear_supernode_auth_keys(self) -> None:
769
+ """Clear stored `node_public_keys` in the link state if any."""
770
+ self.query("DELETE FROM public_key;")
826
771
 
827
772
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
828
773
  """Store a set of `node_public_keys` in the link state."""
@@ -982,17 +927,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
982
927
  """Acknowledge a ping received from a node, serving as a heartbeat."""
983
928
  sint64_node_id = convert_uint64_to_sint64(node_id)
984
929
 
985
- # Update `online_until` and `ping_interval` for the given `node_id`
986
- query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
987
- try:
988
- self.query(
989
- query, (time.time() + ping_interval, ping_interval, sint64_node_id)
990
- )
991
- return True
992
- except sqlite3.IntegrityError:
993
- log(ERROR, "`node_id` does not exist.")
930
+ # Check if the node exists in the `node` table
931
+ query = "SELECT 1 FROM node WHERE node_id = ?"
932
+ if not self.query(query, (sint64_node_id,)):
994
933
  return False
995
934
 
935
+ # Update `online_until` and `ping_interval` for the given `node_id`
936
+ query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
937
+ self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
938
+ return True
939
+
996
940
  def get_serverapp_context(self, run_id: int) -> Optional[Context]:
997
941
  """Get the context for the specified `run_id`."""
998
942
  # Retrieve context if any
@@ -1105,9 +1049,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
1105
1049
  "task_id": task_msg.task_id,
1106
1050
  "group_id": task_msg.group_id,
1107
1051
  "run_id": task_msg.run_id,
1108
- "producer_anonymous": task_msg.task.producer.anonymous,
1109
1052
  "producer_node_id": task_msg.task.producer.node_id,
1110
- "consumer_anonymous": task_msg.task.consumer.anonymous,
1111
1053
  "consumer_node_id": task_msg.task.consumer.node_id,
1112
1054
  "created_at": task_msg.task.created_at,
1113
1055
  "delivered_at": task_msg.task.delivered_at,
@@ -1126,9 +1068,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
1126
1068
  "task_id": task_msg.task_id,
1127
1069
  "group_id": task_msg.group_id,
1128
1070
  "run_id": task_msg.run_id,
1129
- "producer_anonymous": task_msg.task.producer.anonymous,
1130
1071
  "producer_node_id": task_msg.task.producer.node_id,
1131
- "consumer_anonymous": task_msg.task.consumer.anonymous,
1132
1072
  "consumer_node_id": task_msg.task.consumer.node_id,
1133
1073
  "created_at": task_msg.task.created_at,
1134
1074
  "delivered_at": task_msg.task.delivered_at,
@@ -1153,11 +1093,9 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
1153
1093
  task=Task(
1154
1094
  producer=Node(
1155
1095
  node_id=task_dict["producer_node_id"],
1156
- anonymous=task_dict["producer_anonymous"],
1157
1096
  ),
1158
1097
  consumer=Node(
1159
1098
  node_id=task_dict["consumer_node_id"],
1160
- anonymous=task_dict["consumer_anonymous"],
1161
1099
  ),
1162
1100
  created_at=task_dict["created_at"],
1163
1101
  delivered_at=task_dict["delivered_at"],
@@ -1183,11 +1121,9 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1183
1121
  task=Task(
1184
1122
  producer=Node(
1185
1123
  node_id=task_dict["producer_node_id"],
1186
- anonymous=task_dict["producer_anonymous"],
1187
1124
  ),
1188
1125
  consumer=Node(
1189
1126
  node_id=task_dict["consumer_node_id"],
1190
- anonymous=task_dict["consumer_anonymous"],
1191
1127
  ),
1192
1128
  created_at=task_dict["created_at"],
1193
1129
  delivered_at=task_dict["delivered_at"],
@@ -21,7 +21,7 @@ from typing import Optional, Union
21
21
  from uuid import UUID, uuid4
22
22
 
23
23
  from flwr.common import ConfigsRecord, Context, log, now, serde
24
- from flwr.common.constant import ErrorCode, Status, SubStatus
24
+ from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
25
25
  from flwr.common.typing import RunStatus
26
26
 
27
27
  # pylint: disable=E0611
@@ -60,9 +60,19 @@ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
60
60
  )
61
61
 
62
62
 
63
- def generate_rand_int_from_bytes(num_bytes: int) -> int:
64
- """Generate a random unsigned integer from `num_bytes` bytes."""
65
- return int.from_bytes(urandom(num_bytes), "little", signed=False)
63
+ def generate_rand_int_from_bytes(
64
+ num_bytes: int, exclude: Optional[list[int]] = None
65
+ ) -> int:
66
+ """Generate a random unsigned integer from `num_bytes` bytes.
67
+
68
+ If `exclude` is set, this function guarantees such number is not returned.
69
+ """
70
+ num = int.from_bytes(urandom(num_bytes), "little", signed=False)
71
+
72
+ if exclude:
73
+ while num in exclude:
74
+ num = int.from_bytes(urandom(num_bytes), "little", signed=False)
75
+ return num
66
76
 
67
77
 
68
78
  def convert_uint64_to_sint64(u: int) -> int:
@@ -246,8 +256,8 @@ def create_taskres_for_unavailable_taskins(taskins_id: Union[str, UUID]) -> Task
246
256
  run_id=0, # Unknown run ID
247
257
  task=Task(
248
258
  # This function is only called by SuperLink, and thus it's the producer.
249
- producer=Node(node_id=0, anonymous=False),
250
- consumer=Node(node_id=0, anonymous=False),
259
+ producer=Node(node_id=SUPERLINK_NODE_ID),
260
+ consumer=Node(node_id=SUPERLINK_NODE_ID),
251
261
  created_at=current_time,
252
262
  ttl=0,
253
263
  ancestry=[str(taskins_id)],
@@ -285,8 +295,8 @@ def create_taskres_for_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
285
295
  run_id=ref_taskins.run_id,
286
296
  task=Task(
287
297
  # This function is only called by SuperLink, and thus it's the producer.
288
- producer=Node(node_id=0, anonymous=False),
289
- consumer=Node(node_id=0, anonymous=False),
298
+ producer=Node(node_id=SUPERLINK_NODE_ID),
299
+ consumer=Node(node_id=SUPERLINK_NODE_ID),
290
300
  created_at=current_time,
291
301
  ttl=ttl,
292
302
  ancestry=[ref_taskins.task_id],
@@ -21,6 +21,7 @@ from typing import Optional
21
21
  import grpc
22
22
 
23
23
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
24
+ from flwr.common.grpc import generic_create_grpc_server
24
25
  from flwr.common.logger import log
25
26
  from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
26
27
  add_SimulationIoServicer_to_server,
@@ -28,7 +29,6 @@ from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
28
29
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
29
30
  from flwr.server.superlink.linkstate import LinkStateFactory
30
31
 
31
- from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
32
32
  from .simulationio_servicer import SimulationIoServicer
33
33
 
34
34
 
@@ -18,6 +18,7 @@
18
18
  import time
19
19
  from typing import Union
20
20
 
21
+ from flwr.common.constant import SUPERLINK_NODE_ID
21
22
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
22
23
 
23
24
 
@@ -58,24 +59,14 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str
58
59
  # Task producer
59
60
  if not tasks_ins_res.task.HasField("producer"):
60
61
  validation_errors.append("`producer` does not set field `producer`")
61
- if tasks_ins_res.task.producer.node_id != 0:
62
- validation_errors.append("`producer.node_id` is not 0")
63
- if not tasks_ins_res.task.producer.anonymous:
64
- validation_errors.append("`producer` is not anonymous")
62
+ if tasks_ins_res.task.producer.node_id != SUPERLINK_NODE_ID:
63
+ validation_errors.append(f"`producer.node_id` is not {SUPERLINK_NODE_ID}")
65
64
 
66
65
  # Task consumer
67
66
  if not tasks_ins_res.task.HasField("consumer"):
68
67
  validation_errors.append("`consumer` does not set field `consumer`")
69
- if (
70
- tasks_ins_res.task.consumer.anonymous
71
- and tasks_ins_res.task.consumer.node_id != 0
72
- ):
73
- validation_errors.append("anonymous consumers MUST NOT set a `node_id`")
74
- if (
75
- not tasks_ins_res.task.consumer.anonymous
76
- and tasks_ins_res.task.consumer.node_id == 0
77
- ):
78
- validation_errors.append("non-anonymous consumer MUST provide a `node_id`")
68
+ if tasks_ins_res.task.consumer.node_id == SUPERLINK_NODE_ID:
69
+ validation_errors.append("consumer MUST provide a valid `node_id`")
79
70
 
80
71
  # Content check
81
72
  if tasks_ins_res.task.task_type == "":
@@ -95,30 +86,14 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str
95
86
  # Task producer
96
87
  if not tasks_ins_res.task.HasField("producer"):
97
88
  validation_errors.append("`producer` does not set field `producer`")
98
- if (
99
- tasks_ins_res.task.producer.anonymous
100
- and tasks_ins_res.task.producer.node_id != 0
101
- ):
102
- validation_errors.append("anonymous producers MUST NOT set a `node_id`")
103
- if (
104
- not tasks_ins_res.task.producer.anonymous
105
- and tasks_ins_res.task.producer.node_id == 0
106
- ):
107
- validation_errors.append("non-anonymous producer MUST provide a `node_id`")
89
+ if tasks_ins_res.task.producer.node_id == SUPERLINK_NODE_ID:
90
+ validation_errors.append("producer MUST provide a valid `node_id`")
108
91
 
109
92
  # Task consumer
110
93
  if not tasks_ins_res.task.HasField("consumer"):
111
94
  validation_errors.append("`consumer` does not set field `consumer`")
112
- if (
113
- tasks_ins_res.task.consumer.anonymous
114
- and tasks_ins_res.task.consumer.node_id != 0
115
- ):
116
- validation_errors.append("anonymous consumers MUST NOT set a `node_id`")
117
- if (
118
- not tasks_ins_res.task.consumer.anonymous
119
- and tasks_ins_res.task.consumer.node_id == 0
120
- ):
121
- validation_errors.append("non-anonymous consumer MUST provide a `node_id`")
95
+ if tasks_ins_res.task.consumer.node_id != SUPERLINK_NODE_ID:
96
+ validation_errors.append(f"consumer is not {SUPERLINK_NODE_ID}")
122
97
 
123
98
  # Content check
124
99
  if tasks_ins_res.task.task_type == "":
flwr/simulation/app.py CHANGED
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import argparse
19
- import sys
20
19
  from logging import DEBUG, ERROR, INFO
21
20
  from queue import Queue
22
21
  from time import sleep
@@ -39,6 +38,7 @@ from flwr.common.constant import (
39
38
  Status,
40
39
  SubStatus,
41
40
  )
41
+ from flwr.common.exit import ExitCode, flwr_exit
42
42
  from flwr.common.logger import (
43
43
  log,
44
44
  mirror_output_to_queue,
@@ -81,12 +81,10 @@ def flwr_simulation() -> None:
81
81
  log(INFO, "Starting Flower Simulation")
82
82
 
83
83
  if not args.insecure:
84
- log(
85
- ERROR,
86
- "`flwr-simulation` does not support TLS yet. "
87
- "Please use the '--insecure' flag.",
84
+ flwr_exit(
85
+ ExitCode.COMMON_TLS_NOT_SUPPORTED,
86
+ "`flwr-simulation` does not support TLS yet. ",
88
87
  )
89
- sys.exit(1)
90
88
 
91
89
  log(
92
90
  DEBUG,
@@ -29,7 +29,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
29
29
 
30
30
  from flwr.client import ClientFnExt
31
31
  from flwr.common import EventType, event
32
- from flwr.common.constant import NODE_ID_NUM_BYTES
32
+ from flwr.common.constant import NODE_ID_NUM_BYTES, SUPERLINK_NODE_ID
33
33
  from flwr.common.logger import (
34
34
  log,
35
35
  set_logger_propagation,
@@ -87,7 +87,9 @@ def _create_node_id_to_partition_mapping(
87
87
  nodes_mapping: NodeToPartitionMapping = {} # {node-id; partition-id}
88
88
  for i in range(num_clients):
89
89
  while True:
90
- node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
90
+ node_id = generate_rand_int_from_bytes(
91
+ NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
92
+ )
91
93
  if node_id not in nodes_mapping:
92
94
  break
93
95
  nodes_mapping[node_id] = i
@@ -350,7 +350,7 @@ def _main_loop(
350
350
  # Initialize Driver
351
351
  driver = InMemoryDriver(state_factory=state_factory)
352
352
  driver.set_run(run_id=run.run_id)
353
- output_context_queue: "Queue[Context]" = Queue()
353
+ output_context_queue: Queue[Context] = Queue()
354
354
 
355
355
  # Get and run ServerApp thread
356
356
  serverapp_th = run_serverapp_th(
@@ -21,7 +21,7 @@ from typing import Optional, cast
21
21
  import grpc
22
22
 
23
23
  from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS
24
- from flwr.common.grpc import create_channel
24
+ from flwr.common.grpc import create_channel, on_channel_state_change
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
27
27
  from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
@@ -73,6 +73,7 @@ class SimulationIoConnection:
73
73
  insecure=(self._cert is None),
74
74
  root_certificates=self._cert,
75
75
  )
76
+ self._channel.subscribe(on_channel_state_change)
76
77
  self._grpc_stub = SimulationIoStub(self._channel)
77
78
  _wrap_stub(self._grpc_stub, self._retry_invoker)
78
79
  log(DEBUG, "[SimulationIO] Connected to %s", self._addr)
@@ -23,11 +23,11 @@ import grpc
23
23
 
24
24
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
25
25
  from flwr.common.auth_plugin import ExecAuthPlugin
26
+ from flwr.common.grpc import generic_create_grpc_server
26
27
  from flwr.common.logger import log
27
28
  from flwr.common.typing import UserConfig
28
29
  from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
29
30
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
30
- from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
31
31
  from flwr.server.superlink.linkstate import LinkStateFactory
32
32
  from flwr.superexec.exec_user_auth_interceptor import ExecUserAuthInterceptor
33
33