flwr 1.13.1__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158) hide show
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/auth_plugin/__init__.py +31 -0
  3. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  4. flwr/cli/build.py +1 -0
  5. flwr/cli/cli_user_auth_interceptor.py +90 -0
  6. flwr/cli/config_utils.py +43 -149
  7. flwr/cli/constant.py +27 -0
  8. flwr/cli/example.py +1 -0
  9. flwr/cli/install.py +2 -1
  10. flwr/cli/log.py +34 -37
  11. flwr/cli/login/__init__.py +22 -0
  12. flwr/cli/login/login.py +116 -0
  13. flwr/cli/ls.py +214 -106
  14. flwr/cli/new/__init__.py +1 -0
  15. flwr/cli/new/new.py +2 -1
  16. flwr/cli/new/templates/app/.gitignore.tpl +3 -0
  17. flwr/cli/new/templates/app/README.md.tpl +3 -2
  18. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  19. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  20. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  21. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  22. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -4
  23. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  24. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  25. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  26. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  27. flwr/cli/run/__init__.py +1 -0
  28. flwr/cli/run/run.py +103 -43
  29. flwr/cli/stop.py +139 -0
  30. flwr/cli/utils.py +186 -8
  31. flwr/client/app.py +49 -50
  32. flwr/client/client.py +1 -32
  33. flwr/client/clientapp/app.py +23 -26
  34. flwr/client/clientapp/utils.py +2 -1
  35. flwr/client/grpc_adapter_client/connection.py +1 -1
  36. flwr/client/grpc_client/connection.py +2 -13
  37. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  38. flwr/client/grpc_rere_client/connection.py +59 -43
  39. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  40. flwr/client/message_handler/message_handler.py +1 -2
  41. flwr/client/message_handler/task_handler.py +0 -17
  42. flwr/client/mod/comms_mods.py +1 -0
  43. flwr/client/mod/localdp_mod.py +1 -1
  44. flwr/client/nodestate/__init__.py +1 -0
  45. flwr/client/nodestate/nodestate.py +1 -0
  46. flwr/client/nodestate/nodestate_factory.py +1 -0
  47. flwr/client/numpy_client.py +0 -44
  48. flwr/client/rest_client/connection.py +37 -29
  49. flwr/client/supernode/app.py +20 -74
  50. flwr/common/address.py +1 -0
  51. flwr/common/args.py +26 -47
  52. flwr/common/auth_plugin/__init__.py +24 -0
  53. flwr/common/auth_plugin/auth_plugin.py +122 -0
  54. flwr/common/config.py +169 -17
  55. flwr/common/constant.py +38 -9
  56. flwr/common/differential_privacy.py +2 -1
  57. flwr/common/exit/__init__.py +24 -0
  58. flwr/common/exit/exit.py +99 -0
  59. flwr/common/exit/exit_code.py +93 -0
  60. flwr/common/exit_handlers.py +24 -10
  61. flwr/common/grpc.py +167 -4
  62. flwr/common/logger.py +66 -7
  63. flwr/common/message.py +1 -0
  64. flwr/common/object_ref.py +57 -54
  65. flwr/common/pyproject.py +1 -0
  66. flwr/common/record/__init__.py +1 -0
  67. flwr/common/record/parametersrecord.py +1 -0
  68. flwr/common/record/recordset.py +1 -1
  69. flwr/common/retry_invoker.py +77 -0
  70. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  71. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  72. flwr/common/serde.py +6 -4
  73. flwr/common/telemetry.py +15 -4
  74. flwr/common/typing.py +32 -0
  75. flwr/common/version.py +1 -0
  76. flwr/proto/clientappio_pb2.py +1 -1
  77. flwr/proto/error_pb2.py +1 -1
  78. flwr/proto/exec_pb2.py +27 -15
  79. flwr/proto/exec_pb2.pyi +80 -2
  80. flwr/proto/exec_pb2_grpc.py +102 -0
  81. flwr/proto/exec_pb2_grpc.pyi +39 -0
  82. flwr/proto/fab_pb2.py +5 -5
  83. flwr/proto/fab_pb2.pyi +4 -1
  84. flwr/proto/fleet_pb2.py +31 -31
  85. flwr/proto/fleet_pb2.pyi +23 -23
  86. flwr/proto/fleet_pb2_grpc.py +30 -30
  87. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  88. flwr/proto/grpcadapter_pb2.py +1 -1
  89. flwr/proto/log_pb2.py +1 -1
  90. flwr/proto/message_pb2.py +1 -1
  91. flwr/proto/node_pb2.py +3 -3
  92. flwr/proto/node_pb2.pyi +1 -4
  93. flwr/proto/recordset_pb2.py +1 -1
  94. flwr/proto/run_pb2.py +1 -1
  95. flwr/proto/serverappio_pb2.py +24 -25
  96. flwr/proto/serverappio_pb2.pyi +32 -32
  97. flwr/proto/serverappio_pb2_grpc.py +62 -28
  98. flwr/proto/serverappio_pb2_grpc.pyi +29 -16
  99. flwr/proto/simulationio_pb2.py +3 -3
  100. flwr/proto/simulationio_pb2_grpc.py +34 -0
  101. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  102. flwr/proto/task_pb2.py +1 -1
  103. flwr/proto/transport_pb2.py +1 -1
  104. flwr/server/app.py +152 -112
  105. flwr/server/compat/app_utils.py +7 -2
  106. flwr/server/compat/driver_client_proxy.py +1 -2
  107. flwr/server/driver/grpc_driver.py +38 -85
  108. flwr/server/driver/inmemory_driver.py +7 -2
  109. flwr/server/run_serverapp.py +8 -9
  110. flwr/server/serverapp/app.py +37 -13
  111. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  112. flwr/server/superlink/driver/serverappio_grpc.py +2 -1
  113. flwr/server/superlink/driver/serverappio_servicer.py +148 -63
  114. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  115. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -87
  116. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  117. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  118. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +56 -35
  119. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +99 -169
  120. flwr/server/superlink/fleet/message_handler/message_handler.py +69 -29
  121. flwr/server/superlink/fleet/rest_rere/rest_api.py +20 -19
  122. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  123. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  124. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  125. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  126. flwr/server/superlink/linkstate/in_memory_linkstate.py +60 -99
  127. flwr/server/superlink/linkstate/linkstate.py +30 -36
  128. flwr/server/superlink/linkstate/sqlite_linkstate.py +105 -188
  129. flwr/server/superlink/linkstate/utils.py +18 -8
  130. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  131. flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
  132. flwr/server/superlink/utils.py +65 -0
  133. flwr/server/utils/validator.py +9 -34
  134. flwr/simulation/app.py +20 -10
  135. flwr/simulation/legacy_app.py +4 -2
  136. flwr/simulation/ray_transport/ray_actor.py +1 -0
  137. flwr/simulation/ray_transport/utils.py +1 -0
  138. flwr/simulation/run_simulation.py +36 -22
  139. flwr/simulation/simulationio_connection.py +5 -1
  140. flwr/superexec/app.py +1 -0
  141. flwr/superexec/deployment.py +1 -0
  142. flwr/superexec/exec_grpc.py +20 -2
  143. flwr/superexec/exec_servicer.py +97 -2
  144. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  145. flwr/superexec/executor.py +1 -0
  146. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/METADATA +14 -13
  147. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/RECORD +150 -144
  148. flwr/proto/common_pb2.py +0 -36
  149. flwr/proto/common_pb2.pyi +0 -121
  150. flwr/proto/common_pb2_grpc.py +0 -4
  151. flwr/proto/common_pb2_grpc.pyi +0 -4
  152. flwr/proto/control_pb2.py +0 -27
  153. flwr/proto/control_pb2.pyi +0 -7
  154. flwr/proto/control_pb2_grpc.py +0 -135
  155. flwr/proto/control_pb2_grpc.pyi +0 -53
  156. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
  157. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
  158. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
@@ -14,12 +14,12 @@
14
14
  # ==============================================================================
15
15
  """SQLite based implemenation of the link state."""
16
16
 
17
+
17
18
  # pylint: disable=too-many-lines
18
19
 
19
20
  import json
20
21
  import re
21
22
  import sqlite3
22
- import threading
23
23
  import time
24
24
  from collections.abc import Sequence
25
25
  from logging import DEBUG, ERROR, WARNING
@@ -31,6 +31,7 @@ from flwr.common.constant import (
31
31
  MESSAGE_TTL_TOLERANCE,
32
32
  NODE_ID_NUM_BYTES,
33
33
  RUN_ID_NUM_BYTES,
34
+ SUPERLINK_NODE_ID,
34
35
  Status,
35
36
  )
36
37
  from flwr.common.record import ConfigsRecord
@@ -70,16 +71,9 @@ CREATE TABLE IF NOT EXISTS node(
70
71
  );
71
72
  """
72
73
 
73
- SQL_CREATE_TABLE_CREDENTIAL = """
74
- CREATE TABLE IF NOT EXISTS credential(
75
- private_key BLOB PRIMARY KEY,
76
- public_key BLOB
77
- );
78
- """
79
-
80
74
  SQL_CREATE_TABLE_PUBLIC_KEY = """
81
75
  CREATE TABLE IF NOT EXISTS public_key(
82
- public_key BLOB UNIQUE
76
+ public_key BLOB PRIMARY KEY
83
77
  );
84
78
  """
85
79
 
@@ -128,9 +122,7 @@ CREATE TABLE IF NOT EXISTS task_ins(
128
122
  task_id TEXT UNIQUE,
129
123
  group_id TEXT,
130
124
  run_id INTEGER,
131
- producer_anonymous BOOLEAN,
132
125
  producer_node_id INTEGER,
133
- consumer_anonymous BOOLEAN,
134
126
  consumer_node_id INTEGER,
135
127
  created_at REAL,
136
128
  delivered_at TEXT,
@@ -148,9 +140,7 @@ CREATE TABLE IF NOT EXISTS task_res(
148
140
  task_id TEXT UNIQUE,
149
141
  group_id TEXT,
150
142
  run_id INTEGER,
151
- producer_anonymous BOOLEAN,
152
143
  producer_node_id INTEGER,
153
- consumer_anonymous BOOLEAN,
154
144
  consumer_node_id INTEGER,
155
145
  created_at REAL,
156
146
  delivered_at TEXT,
@@ -183,7 +173,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
183
173
  """
184
174
  self.database_path = database_path
185
175
  self.conn: Optional[sqlite3.Connection] = None
186
- self.lock = threading.RLock()
187
176
 
188
177
  def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
189
178
  """Create tables if they don't exist yet.
@@ -212,11 +201,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
212
201
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
213
202
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
214
203
  cur.execute(SQL_CREATE_TABLE_NODE)
215
- cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
216
204
  cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
217
205
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
218
206
  res = cur.execute("SELECT name FROM sqlite_schema;")
219
-
220
207
  return res.fetchall()
221
208
 
222
209
  def query(
@@ -265,11 +252,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
265
252
 
266
253
  Constraints
267
254
  -----------
268
- If `task_ins.task.consumer.anonymous` is `True`, then
269
- `task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
270
255
 
271
- If `task_ins.task.consumer.anonymous` is `False`, then
272
- `task_ins.task.consumer.node_id` MUST be set (not 0)
256
+ `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
273
257
  """
274
258
  # Validate task
275
259
  errors = validate_task_ins_or_res(task_ins)
@@ -294,7 +278,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
294
278
  log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
295
279
  return None
296
280
  # Validate source node ID
297
- if task_ins.task.producer.node_id != 0:
281
+ if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
298
282
  log(
299
283
  ERROR,
300
284
  "Invalid source node ID for TaskIns: %s",
@@ -303,14 +287,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
303
287
  return None
304
288
  # Validate destination node ID
305
289
  query = "SELECT node_id FROM node WHERE node_id = ?;"
306
- if not task_ins.task.consumer.anonymous:
307
- if not self.query(query, (data[0]["consumer_node_id"],)):
308
- log(
309
- ERROR,
310
- "Invalid destination node ID for TaskIns: %s",
311
- task_ins.task.consumer.node_id,
312
- )
313
- return None
290
+ if not self.query(query, (data[0]["consumer_node_id"],)):
291
+ log(
292
+ ERROR,
293
+ "Invalid destination node ID for TaskIns: %s",
294
+ task_ins.task.consumer.node_id,
295
+ )
296
+ return None
314
297
 
315
298
  columns = ", ".join([f":{key}" for key in data[0]])
316
299
  query = f"INSERT INTO task_ins VALUES({columns});"
@@ -321,25 +304,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
321
304
 
322
305
  return task_id
323
306
 
324
- def get_task_ins(
325
- self, node_id: Optional[int], limit: Optional[int]
326
- ) -> list[TaskIns]:
327
- """Get undelivered TaskIns for one node (either anonymous or with ID).
307
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
308
+ """Get undelivered TaskIns for one node.
328
309
 
329
310
  Usually, the Fleet API calls this for Nodes planning to work on one or more
330
311
  TaskIns.
331
312
 
332
313
  Constraints
333
314
  -----------
334
- If `node_id` is not `None`, retrieve all TaskIns where
315
+ Retrieve all TaskIns where
335
316
 
336
317
  1. the `task_ins.task.consumer.node_id` equals `node_id` AND
337
- 2. the `task_ins.task.consumer.anonymous` equals `False` AND
338
- 3. the `task_ins.task.delivered_at` equals `""`.
339
-
340
- If `node_id` is `None`, retrieve all TaskIns where the
341
- `task_ins.task.consumer.node_id` equals `0` and
342
- `task_ins.task.consumer.anonymous` is set to `True`.
318
+ 2. the `task_ins.task.delivered_at` equals `""`.
343
319
 
344
320
  `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
345
321
  the result.
@@ -350,38 +326,23 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
350
326
  if limit is not None and limit < 1:
351
327
  raise AssertionError("`limit` must be >= 1")
352
328
 
353
- if node_id == 0:
354
- msg = (
355
- "`node_id` must be >= 1"
356
- "\n\n For requesting anonymous tasks use `node_id` equal `None`"
357
- )
329
+ if node_id == SUPERLINK_NODE_ID:
330
+ msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
358
331
  raise AssertionError(msg)
359
332
 
360
333
  data: dict[str, Union[str, int]] = {}
361
334
 
362
- if node_id is None:
363
- # Retrieve all anonymous Tasks
364
- query = """
365
- SELECT task_id
366
- FROM task_ins
367
- WHERE consumer_anonymous == 1
368
- AND consumer_node_id == 0
369
- AND delivered_at = ""
370
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
371
- """
372
- else:
373
- # Convert the uint64 value to sint64 for SQLite
374
- data["node_id"] = convert_uint64_to_sint64(node_id)
335
+ # Convert the uint64 value to sint64 for SQLite
336
+ data["node_id"] = convert_uint64_to_sint64(node_id)
375
337
 
376
- # Retrieve all TaskIns for node_id
377
- query = """
378
- SELECT task_id
379
- FROM task_ins
380
- WHERE consumer_anonymous == 0
381
- AND consumer_node_id == :node_id
382
- AND delivered_at = ""
383
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
384
- """
338
+ # Retrieve all TaskIns for node_id
339
+ query = """
340
+ SELECT task_id
341
+ FROM task_ins
342
+ WHERE consumer_node_id == :node_id
343
+ AND delivered_at = ""
344
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
345
+ """
385
346
 
386
347
  if limit is not None:
387
348
  query += " LIMIT :limit"
@@ -431,11 +392,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
431
392
 
432
393
  Constraints
433
394
  -----------
434
- If `task_res.task.consumer.anonymous` is `True`, then
435
- `task_res.task.consumer.node_id` MUST NOT be set (equal 0).
436
-
437
- If `task_res.task.consumer.anonymous` is `False`, then
438
- `task_res.task.consumer.node_id` MUST be set (not 0)
395
+ `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
439
396
  """
440
397
  # Validate task
441
398
  errors = validate_task_ins_or_res(task_res)
@@ -461,7 +418,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
461
418
  if (
462
419
  task_ins
463
420
  and task_res
464
- and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
465
421
  and convert_sint64_to_uint64(task_ins["consumer_node_id"])
466
422
  != task_res.task.producer.node_id
467
423
  ):
@@ -569,9 +525,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
569
525
  data: list[Any] = [delivered_at] + task_res_ids
570
526
  self.query(query, data)
571
527
 
572
- # Cleanup
573
- self._force_delete_tasks_by_ids(set(ret.keys()))
574
-
575
528
  return list(ret.values())
576
529
 
577
530
  def num_task_ins(self) -> int:
@@ -595,86 +548,61 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
595
548
  result: dict[str, int] = rows[0]
596
549
  return result["num"]
597
550
 
598
- def delete_tasks(self, task_ids: set[UUID]) -> None:
599
- """Delete all delivered TaskIns/TaskRes pairs."""
600
- ids = list(task_ids)
601
- if len(ids) == 0:
602
- return None
551
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
552
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
553
+ if not task_ins_ids:
554
+ return
555
+ if self.conn is None:
556
+ raise AttributeError("LinkState not initialized")
603
557
 
604
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
605
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
558
+ placeholders = ",".join(["?"] * len(task_ins_ids))
559
+ data = tuple(str(task_id) for task_id in task_ins_ids)
606
560
 
607
- # 1. Query: Delete task_ins which have a delivered task_res
561
+ # Delete task_ins
608
562
  query_1 = f"""
609
563
  DELETE FROM task_ins
610
- WHERE delivered_at != ''
611
- AND task_id IN (
612
- SELECT ancestry
613
- FROM task_res
614
- WHERE ancestry IN ({placeholders})
615
- AND delivered_at != ''
616
- );
564
+ WHERE task_id IN ({placeholders});
617
565
  """
618
566
 
619
- # 2. Query: Delete delivered task_res to be run after 1. Query
567
+ # Delete task_res
620
568
  query_2 = f"""
621
569
  DELETE FROM task_res
622
- WHERE ancestry IN ({placeholders})
623
- AND delivered_at != '';
570
+ WHERE ancestry IN ({placeholders});
624
571
  """
625
572
 
626
- if self.conn is None:
627
- raise AttributeError("LinkState not intitialized")
628
-
629
573
  with self.conn:
630
574
  self.conn.execute(query_1, data)
631
575
  self.conn.execute(query_2, data)
632
576
 
633
- return None
634
-
635
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
636
- """Delete tasks based on a set of TaskIns IDs."""
637
- if not task_ids:
638
- return
577
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
578
+ """Get all TaskIns IDs for the given run_id."""
639
579
  if self.conn is None:
640
580
  raise AttributeError("LinkState not initialized")
641
581
 
642
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
643
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
644
-
645
- # Delete task_ins
646
- query_1 = f"""
647
- DELETE FROM task_ins
648
- WHERE task_id IN ({placeholders});
582
+ query = """
583
+ SELECT task_id
584
+ FROM task_ins
585
+ WHERE run_id = :run_id;
649
586
  """
650
587
 
651
- # Delete task_res
652
- query_2 = f"""
653
- DELETE FROM task_res
654
- WHERE ancestry IN ({placeholders});
655
- """
588
+ sint64_run_id = convert_uint64_to_sint64(run_id)
589
+ data = {"run_id": sint64_run_id}
656
590
 
657
591
  with self.conn:
658
- self.conn.execute(query_1, data)
659
- self.conn.execute(query_2, data)
592
+ rows = self.conn.execute(query, data).fetchall()
660
593
 
661
- def create_node(
662
- self, ping_interval: float, public_key: Optional[bytes] = None
663
- ) -> int:
594
+ return {UUID(row["task_id"]) for row in rows}
595
+
596
+ def create_node(self, ping_interval: float) -> int:
664
597
  """Create, store in the link state, and return `node_id`."""
665
598
  # Sample a random uint64 as node_id
666
- uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
599
+ uint64_node_id = generate_rand_int_from_bytes(
600
+ NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
601
+ )
667
602
 
668
603
  # Convert the uint64 value to sint64 for SQLite
669
604
  sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
670
605
 
671
- query = "SELECT node_id FROM node WHERE public_key = :public_key;"
672
- row = self.query(query, {"public_key": public_key})
673
-
674
- if len(row) > 0:
675
- log(ERROR, "Unexpected node registration failure.")
676
- return 0
677
-
678
606
  query = (
679
607
  "INSERT INTO node "
680
608
  "(node_id, online_until, ping_interval, public_key) "
@@ -688,7 +616,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
688
616
  sint64_node_id,
689
617
  time.time() + ping_interval,
690
618
  ping_interval,
691
- public_key,
619
+ b"", # Initialize with an empty public key
692
620
  ),
693
621
  )
694
622
  except sqlite3.IntegrityError:
@@ -698,7 +626,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
698
626
  # Note: we need to return the uint64 value of the node_id
699
627
  return uint64_node_id
700
628
 
701
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
629
+ def delete_node(self, node_id: int) -> None:
702
630
  """Delete a node."""
703
631
  # Convert the uint64 value to sint64 for SQLite
704
632
  sint64_node_id = convert_uint64_to_sint64(node_id)
@@ -706,10 +634,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
706
634
  query = "DELETE FROM node WHERE node_id = ?"
707
635
  params = (sint64_node_id,)
708
636
 
709
- if public_key is not None:
710
- query += " AND public_key = ?"
711
- params += (public_key,) # type: ignore
712
-
713
637
  if self.conn is None:
714
638
  raise AttributeError("LinkState is not initialized.")
715
639
 
@@ -717,7 +641,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
717
641
  with self.conn:
718
642
  rows = self.conn.execute(query, params)
719
643
  if rows.rowcount < 1:
720
- raise ValueError("Public key or node_id not found")
644
+ raise ValueError(f"Node {node_id} not found")
721
645
  except KeyError as exc:
722
646
  log(ERROR, {"query": query, "data": params, "exception": exc})
723
647
 
@@ -745,6 +669,41 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
745
669
  result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
746
670
  return result
747
671
 
672
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
673
+ """Set `public_key` for the specified `node_id`."""
674
+ # Convert the uint64 value to sint64 for SQLite
675
+ sint64_node_id = convert_uint64_to_sint64(node_id)
676
+
677
+ # Check if the node exists in the `node` table
678
+ query = "SELECT 1 FROM node WHERE node_id = ?"
679
+ if not self.query(query, (sint64_node_id,)):
680
+ raise ValueError(f"Node {node_id} not found")
681
+
682
+ # Check if the public key is already in use in the `node` table
683
+ query = "SELECT 1 FROM node WHERE public_key = ?"
684
+ if self.query(query, (public_key,)):
685
+ raise ValueError("Public key already in use")
686
+
687
+ # Update the `node` table to set the public key for the given node ID
688
+ query = "UPDATE node SET public_key = ? WHERE node_id = ?"
689
+ self.query(query, (public_key, sint64_node_id))
690
+
691
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
692
+ """Get `public_key` for the specified `node_id`."""
693
+ # Convert the uint64 value to sint64 for SQLite
694
+ sint64_node_id = convert_uint64_to_sint64(node_id)
695
+
696
+ # Query the public key for the given node_id
697
+ query = "SELECT public_key FROM node WHERE node_id = ?"
698
+ rows = self.query(query, (sint64_node_id,))
699
+
700
+ # If no result is found, return None
701
+ if not rows:
702
+ raise ValueError(f"Node {node_id} not found")
703
+
704
+ # Return the public key if it is not empty, otherwise return None
705
+ return rows[0]["public_key"] or None
706
+
748
707
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
749
708
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
750
709
  query = "SELECT node_id FROM node WHERE public_key = :public_key;"
@@ -784,8 +743,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
784
743
  "federation_options, pending_at, starting_at, running_at, finished_at, "
785
744
  "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
786
745
  )
787
- if fab_hash:
788
- fab_id, fab_version = "", ""
789
746
  override_config_json = json.dumps(override_config)
790
747
  data = [
791
748
  sint64_run_id,
@@ -808,40 +765,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
808
765
  log(ERROR, "Unexpected run creation failure.")
809
766
  return 0
810
767
 
811
- def store_server_private_public_key(
812
- self, private_key: bytes, public_key: bytes
813
- ) -> None:
814
- """Store `server_private_key` and `server_public_key` in the link state."""
815
- query = "SELECT COUNT(*) FROM credential"
816
- count = self.query(query)[0]["COUNT(*)"]
817
- if count < 1:
818
- query = (
819
- "INSERT OR REPLACE INTO credential (private_key, public_key) "
820
- "VALUES (:private_key, :public_key)"
821
- )
822
- self.query(query, {"private_key": private_key, "public_key": public_key})
823
- else:
824
- raise RuntimeError("Server private and public key already set")
825
-
826
- def get_server_private_key(self) -> Optional[bytes]:
827
- """Retrieve `server_private_key` in urlsafe bytes."""
828
- query = "SELECT private_key FROM credential"
829
- rows = self.query(query)
830
- try:
831
- private_key: Optional[bytes] = rows[0]["private_key"]
832
- except IndexError:
833
- private_key = None
834
- return private_key
835
-
836
- def get_server_public_key(self) -> Optional[bytes]:
837
- """Retrieve `server_public_key` in urlsafe bytes."""
838
- query = "SELECT public_key FROM credential"
839
- rows = self.query(query)
840
- try:
841
- public_key: Optional[bytes] = rows[0]["public_key"]
842
- except IndexError:
843
- public_key = None
844
- return public_key
768
+ def clear_supernode_auth_keys(self) -> None:
769
+ """Clear stored `node_public_keys` in the link state if any."""
770
+ self.query("DELETE FROM public_key;")
845
771
 
846
772
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
847
773
  """Store a set of `node_public_keys` in the link state."""
@@ -1001,17 +927,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1001
927
  """Acknowledge a ping received from a node, serving as a heartbeat."""
1002
928
  sint64_node_id = convert_uint64_to_sint64(node_id)
1003
929
 
1004
- # Update `online_until` and `ping_interval` for the given `node_id`
1005
- query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
1006
- try:
1007
- self.query(
1008
- query, (time.time() + ping_interval, ping_interval, sint64_node_id)
1009
- )
1010
- return True
1011
- except sqlite3.IntegrityError:
1012
- log(ERROR, "`node_id` does not exist.")
930
+ # Check if the node exists in the `node` table
931
+ query = "SELECT 1 FROM node WHERE node_id = ?"
932
+ if not self.query(query, (sint64_node_id,)):
1013
933
  return False
1014
934
 
935
+ # Update `online_until` and `ping_interval` for the given `node_id`
936
+ query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
937
+ self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
938
+ return True
939
+
1015
940
  def get_serverapp_context(self, run_id: int) -> Optional[Context]:
1016
941
  """Get the context for the specified `run_id`."""
1017
942
  # Retrieve context if any
@@ -1124,9 +1049,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
1124
1049
  "task_id": task_msg.task_id,
1125
1050
  "group_id": task_msg.group_id,
1126
1051
  "run_id": task_msg.run_id,
1127
- "producer_anonymous": task_msg.task.producer.anonymous,
1128
1052
  "producer_node_id": task_msg.task.producer.node_id,
1129
- "consumer_anonymous": task_msg.task.consumer.anonymous,
1130
1053
  "consumer_node_id": task_msg.task.consumer.node_id,
1131
1054
  "created_at": task_msg.task.created_at,
1132
1055
  "delivered_at": task_msg.task.delivered_at,
@@ -1145,9 +1068,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
1145
1068
  "task_id": task_msg.task_id,
1146
1069
  "group_id": task_msg.group_id,
1147
1070
  "run_id": task_msg.run_id,
1148
- "producer_anonymous": task_msg.task.producer.anonymous,
1149
1071
  "producer_node_id": task_msg.task.producer.node_id,
1150
- "consumer_anonymous": task_msg.task.consumer.anonymous,
1151
1072
  "consumer_node_id": task_msg.task.consumer.node_id,
1152
1073
  "created_at": task_msg.task.created_at,
1153
1074
  "delivered_at": task_msg.task.delivered_at,
@@ -1172,11 +1093,9 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
1172
1093
  task=Task(
1173
1094
  producer=Node(
1174
1095
  node_id=task_dict["producer_node_id"],
1175
- anonymous=task_dict["producer_anonymous"],
1176
1096
  ),
1177
1097
  consumer=Node(
1178
1098
  node_id=task_dict["consumer_node_id"],
1179
- anonymous=task_dict["consumer_anonymous"],
1180
1099
  ),
1181
1100
  created_at=task_dict["created_at"],
1182
1101
  delivered_at=task_dict["delivered_at"],
@@ -1202,11 +1121,9 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1202
1121
  task=Task(
1203
1122
  producer=Node(
1204
1123
  node_id=task_dict["producer_node_id"],
1205
- anonymous=task_dict["producer_anonymous"],
1206
1124
  ),
1207
1125
  consumer=Node(
1208
1126
  node_id=task_dict["consumer_node_id"],
1209
- anonymous=task_dict["consumer_anonymous"],
1210
1127
  ),
1211
1128
  created_at=task_dict["created_at"],
1212
1129
  delivered_at=task_dict["delivered_at"],
@@ -21,7 +21,7 @@ from typing import Optional, Union
21
21
  from uuid import UUID, uuid4
22
22
 
23
23
  from flwr.common import ConfigsRecord, Context, log, now, serde
24
- from flwr.common.constant import ErrorCode, Status, SubStatus
24
+ from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
25
25
  from flwr.common.typing import RunStatus
26
26
 
27
27
  # pylint: disable=E0611
@@ -60,9 +60,19 @@ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
60
60
  )
61
61
 
62
62
 
63
- def generate_rand_int_from_bytes(num_bytes: int) -> int:
64
- """Generate a random unsigned integer from `num_bytes` bytes."""
65
- return int.from_bytes(urandom(num_bytes), "little", signed=False)
63
+ def generate_rand_int_from_bytes(
64
+ num_bytes: int, exclude: Optional[list[int]] = None
65
+ ) -> int:
66
+ """Generate a random unsigned integer from `num_bytes` bytes.
67
+
68
+ If `exclude` is set, this function guarantees such number is not returned.
69
+ """
70
+ num = int.from_bytes(urandom(num_bytes), "little", signed=False)
71
+
72
+ if exclude:
73
+ while num in exclude:
74
+ num = int.from_bytes(urandom(num_bytes), "little", signed=False)
75
+ return num
66
76
 
67
77
 
68
78
  def convert_uint64_to_sint64(u: int) -> int:
@@ -246,8 +256,8 @@ def create_taskres_for_unavailable_taskins(taskins_id: Union[str, UUID]) -> Task
246
256
  run_id=0, # Unknown run ID
247
257
  task=Task(
248
258
  # This function is only called by SuperLink, and thus it's the producer.
249
- producer=Node(node_id=0, anonymous=False),
250
- consumer=Node(node_id=0, anonymous=False),
259
+ producer=Node(node_id=SUPERLINK_NODE_ID),
260
+ consumer=Node(node_id=SUPERLINK_NODE_ID),
251
261
  created_at=current_time,
252
262
  ttl=0,
253
263
  ancestry=[str(taskins_id)],
@@ -285,8 +295,8 @@ def create_taskres_for_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
285
295
  run_id=ref_taskins.run_id,
286
296
  task=Task(
287
297
  # This function is only called by SuperLink, and thus it's the producer.
288
- producer=Node(node_id=0, anonymous=False),
289
- consumer=Node(node_id=0, anonymous=False),
298
+ producer=Node(node_id=SUPERLINK_NODE_ID),
299
+ consumer=Node(node_id=SUPERLINK_NODE_ID),
290
300
  created_at=current_time,
291
301
  ttl=ttl,
292
302
  ancestry=[ref_taskins.task_id],
@@ -21,6 +21,7 @@ from typing import Optional
21
21
  import grpc
22
22
 
23
23
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
24
+ from flwr.common.grpc import generic_create_grpc_server
24
25
  from flwr.common.logger import log
25
26
  from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
26
27
  add_SimulationIoServicer_to_server,
@@ -28,7 +29,6 @@ from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
28
29
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
29
30
  from flwr.server.superlink.linkstate import LinkStateFactory
30
31
 
31
- from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
32
32
  from .simulationio_servicer import SimulationIoServicer
33
33
 
34
34