flwr-nightly 1.15.0.dev20250104__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98):
  1. flwr/cli/cli_user_auth_interceptor.py +6 -2
  2. flwr/cli/config_utils.py +23 -146
  3. flwr/cli/constant.py +27 -0
  4. flwr/cli/install.py +1 -1
  5. flwr/cli/log.py +17 -2
  6. flwr/cli/login/login.py +20 -5
  7. flwr/cli/ls.py +10 -2
  8. flwr/cli/run/run.py +20 -10
  9. flwr/cli/stop.py +9 -1
  10. flwr/cli/utils.py +4 -4
  11. flwr/client/app.py +36 -48
  12. flwr/client/clientapp/app.py +4 -6
  13. flwr/client/clientapp/utils.py +1 -1
  14. flwr/client/grpc_client/connection.py +0 -6
  15. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  16. flwr/client/grpc_rere_client/connection.py +34 -24
  17. flwr/client/grpc_rere_client/grpc_adapter.py +16 -0
  18. flwr/client/rest_client/connection.py +34 -26
  19. flwr/client/supernode/app.py +14 -20
  20. flwr/common/auth_plugin/auth_plugin.py +34 -23
  21. flwr/common/config.py +152 -15
  22. flwr/common/constant.py +11 -8
  23. flwr/common/exit/__init__.py +24 -0
  24. flwr/common/exit/exit.py +99 -0
  25. flwr/common/exit/exit_code.py +93 -0
  26. flwr/common/exit_handlers.py +24 -10
  27. flwr/common/grpc.py +161 -3
  28. flwr/common/logger.py +1 -1
  29. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  30. flwr/common/serde.py +6 -4
  31. flwr/common/typing.py +20 -0
  32. flwr/proto/clientappio_pb2.py +13 -3
  33. flwr/proto/clientappio_pb2_grpc.py +63 -12
  34. flwr/proto/error_pb2.py +13 -3
  35. flwr/proto/error_pb2_grpc.py +20 -0
  36. flwr/proto/exec_pb2.py +27 -29
  37. flwr/proto/exec_pb2.pyi +27 -54
  38. flwr/proto/exec_pb2_grpc.py +105 -24
  39. flwr/proto/fab_pb2.py +13 -3
  40. flwr/proto/fab_pb2_grpc.py +20 -0
  41. flwr/proto/fleet_pb2.py +54 -31
  42. flwr/proto/fleet_pb2.pyi +84 -0
  43. flwr/proto/fleet_pb2_grpc.py +207 -28
  44. flwr/proto/fleet_pb2_grpc.pyi +26 -0
  45. flwr/proto/grpcadapter_pb2.py +14 -4
  46. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  47. flwr/proto/log_pb2.py +13 -3
  48. flwr/proto/log_pb2_grpc.py +20 -0
  49. flwr/proto/message_pb2.py +15 -5
  50. flwr/proto/message_pb2_grpc.py +20 -0
  51. flwr/proto/node_pb2.py +15 -5
  52. flwr/proto/node_pb2.pyi +1 -4
  53. flwr/proto/node_pb2_grpc.py +20 -0
  54. flwr/proto/recordset_pb2.py +18 -8
  55. flwr/proto/recordset_pb2_grpc.py +20 -0
  56. flwr/proto/run_pb2.py +16 -6
  57. flwr/proto/run_pb2_grpc.py +20 -0
  58. flwr/proto/serverappio_pb2.py +32 -14
  59. flwr/proto/serverappio_pb2.pyi +56 -0
  60. flwr/proto/serverappio_pb2_grpc.py +261 -44
  61. flwr/proto/serverappio_pb2_grpc.pyi +20 -0
  62. flwr/proto/simulationio_pb2.py +13 -3
  63. flwr/proto/simulationio_pb2_grpc.py +105 -24
  64. flwr/proto/task_pb2.py +13 -3
  65. flwr/proto/task_pb2_grpc.py +20 -0
  66. flwr/proto/transport_pb2.py +20 -10
  67. flwr/proto/transport_pb2_grpc.py +35 -4
  68. flwr/server/app.py +87 -38
  69. flwr/server/compat/app_utils.py +0 -1
  70. flwr/server/compat/driver_client_proxy.py +1 -2
  71. flwr/server/driver/grpc_driver.py +5 -2
  72. flwr/server/driver/inmemory_driver.py +2 -1
  73. flwr/server/serverapp/app.py +5 -6
  74. flwr/server/superlink/driver/serverappio_grpc.py +1 -1
  75. flwr/server/superlink/driver/serverappio_servicer.py +132 -14
  76. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  77. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  78. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +38 -0
  79. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -168
  80. flwr/server/superlink/fleet/message_handler/message_handler.py +66 -5
  81. flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -3
  82. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  83. flwr/server/superlink/linkstate/in_memory_linkstate.py +40 -48
  84. flwr/server/superlink/linkstate/linkstate.py +15 -22
  85. flwr/server/superlink/linkstate/sqlite_linkstate.py +80 -99
  86. flwr/server/superlink/linkstate/utils.py +18 -8
  87. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  88. flwr/server/utils/validator.py +9 -34
  89. flwr/simulation/app.py +4 -6
  90. flwr/simulation/legacy_app.py +4 -2
  91. flwr/simulation/run_simulation.py +1 -1
  92. flwr/superexec/exec_grpc.py +1 -1
  93. flwr/superexec/exec_servicer.py +23 -2
  94. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +7 -7
  95. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +98 -94
  96. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
  97. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
  98. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
@@ -28,6 +28,7 @@ from flwr.common.constant import (
28
28
  MESSAGE_TTL_TOLERANCE,
29
29
  NODE_ID_NUM_BYTES,
30
30
  RUN_ID_NUM_BYTES,
31
+ SUPERLINK_NODE_ID,
31
32
  Status,
32
33
  )
33
34
  from flwr.common.record import ConfigsRecord
@@ -62,6 +63,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
62
63
  # Map node_id to (online_until, ping_interval)
63
64
  self.node_ids: dict[int, tuple[float, float]] = {}
64
65
  self.public_key_to_node_id: dict[bytes, int] = {}
66
+ self.node_id_to_public_key: dict[int, bytes] = {}
65
67
 
66
68
  # Map run_id to RunRecord
67
69
  self.run_ids: dict[int, RunRecord] = {}
@@ -89,7 +91,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
89
91
  log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
90
92
  return None
91
93
  # Validate source node ID
92
- if task_ins.task.producer.node_id != 0:
94
+ if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
93
95
  log(
94
96
  ERROR,
95
97
  "Invalid source node ID for TaskIns: %s",
@@ -97,14 +99,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
97
99
  )
98
100
  return None
99
101
  # Validate destination node ID
100
- if not task_ins.task.consumer.anonymous:
101
- if task_ins.task.consumer.node_id not in self.node_ids:
102
- log(
103
- ERROR,
104
- "Invalid destination node ID for TaskIns: %s",
105
- task_ins.task.consumer.node_id,
106
- )
107
- return None
102
+ if task_ins.task.consumer.node_id not in self.node_ids:
103
+ log(
104
+ ERROR,
105
+ "Invalid destination node ID for TaskIns: %s",
106
+ task_ins.task.consumer.node_id,
107
+ )
108
+ return None
108
109
 
109
110
  # Create task_id
110
111
  task_id = uuid4()
@@ -117,9 +118,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
117
118
  # Return the new task_id
118
119
  return task_id
119
120
 
120
- def get_task_ins(
121
- self, node_id: Optional[int], limit: Optional[int]
122
- ) -> list[TaskIns]:
121
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
123
122
  """Get all TaskIns that have not been delivered yet."""
124
123
  if limit is not None and limit < 1:
125
124
  raise AssertionError("`limit` must be >= 1")
@@ -129,17 +128,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
129
128
  current_time = time.time()
130
129
  with self.lock:
131
130
  for _, task_ins in self.task_ins_store.items():
132
- # pylint: disable=too-many-boolean-expressions
133
131
  if (
134
- node_id is not None # Not anonymous
135
- and task_ins.task.consumer.anonymous is False
136
- and task_ins.task.consumer.node_id == node_id
137
- and task_ins.task.delivered_at == ""
138
- and task_ins.task.created_at + task_ins.task.ttl > current_time
139
- ) or (
140
- node_id is None # Anonymous
141
- and task_ins.task.consumer.anonymous is True
142
- and task_ins.task.consumer.node_id == 0
132
+ task_ins.task.consumer.node_id == node_id
143
133
  and task_ins.task.delivered_at == ""
144
134
  and task_ins.task.created_at + task_ins.task.ttl > current_time
145
135
  ):
@@ -173,9 +163,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
173
163
  if (
174
164
  task_ins
175
165
  and task_res
176
- and not (
177
- task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
178
- )
179
166
  and task_ins.task.consumer.node_id != task_res.task.producer.node_id
180
167
  ):
181
168
  return None
@@ -306,45 +293,30 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
306
293
  """
307
294
  return len(self.task_res_store)
308
295
 
309
- def create_node(
310
- self, ping_interval: float, public_key: Optional[bytes] = None
311
- ) -> int:
296
+ def create_node(self, ping_interval: float) -> int:
312
297
  """Create, store in the link state, and return `node_id`."""
313
298
  # Sample a random int64 as node_id
314
- node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
299
+ node_id = generate_rand_int_from_bytes(
300
+ NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
301
+ )
315
302
 
316
303
  with self.lock:
317
304
  if node_id in self.node_ids:
318
305
  log(ERROR, "Unexpected node registration failure.")
319
306
  return 0
320
307
 
321
- if public_key is not None:
322
- if (
323
- public_key in self.public_key_to_node_id
324
- or node_id in self.public_key_to_node_id.values()
325
- ):
326
- log(ERROR, "Unexpected node registration failure.")
327
- return 0
328
-
329
- self.public_key_to_node_id[public_key] = node_id
330
-
331
308
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
332
309
  return node_id
333
310
 
334
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
311
+ def delete_node(self, node_id: int) -> None:
335
312
  """Delete a node."""
336
313
  with self.lock:
337
314
  if node_id not in self.node_ids:
338
315
  raise ValueError(f"Node {node_id} not found")
339
316
 
340
- if public_key is not None:
341
- if (
342
- public_key not in self.public_key_to_node_id
343
- or node_id not in self.public_key_to_node_id.values()
344
- ):
345
- raise ValueError("Public key or node_id not found")
346
-
347
- del self.public_key_to_node_id[public_key]
317
+ # Remove node ID <> public key mappings
318
+ if pk := self.node_id_to_public_key.pop(node_id, None):
319
+ del self.public_key_to_node_id[pk]
348
320
 
349
321
  del self.node_ids[node_id]
350
322
 
@@ -366,6 +338,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
366
338
  if online_until > current_time
367
339
  }
368
340
 
341
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
342
+ """Set `public_key` for the specified `node_id`."""
343
+ with self.lock:
344
+ if node_id not in self.node_ids:
345
+ raise ValueError(f"Node {node_id} not found")
346
+
347
+ if public_key in self.public_key_to_node_id:
348
+ raise ValueError("Public key already in use")
349
+
350
+ self.public_key_to_node_id[public_key] = node_id
351
+ self.node_id_to_public_key[node_id] = public_key
352
+
353
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
354
+ """Get `public_key` for the specified `node_id`."""
355
+ with self.lock:
356
+ if node_id not in self.node_ids:
357
+ raise ValueError(f"Node {node_id} not found")
358
+
359
+ return self.node_id_to_public_key.get(node_id)
360
+
369
361
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
370
362
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
371
363
  return self.public_key_to_node_id.get(node_public_key)
@@ -40,20 +40,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
40
40
 
41
41
  Constraints
42
42
  -----------
43
- If `task_ins.task.consumer.anonymous` is `True`, then
44
- `task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
45
-
46
- If `task_ins.task.consumer.anonymous` is `False`, then
47
- `task_ins.task.consumer.node_id` MUST be set (not 0)
43
+ `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
48
44
 
49
45
  If `task_ins.run_id` is invalid, then
50
46
  storing the `task_ins` MUST fail.
51
47
  """
52
48
 
53
49
  @abc.abstractmethod
54
- def get_task_ins(
55
- self, node_id: Optional[int], limit: Optional[int]
56
- ) -> list[TaskIns]:
50
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
57
51
  """Get TaskIns optionally filtered by node_id.
58
52
 
59
53
  Usually, the Fleet API calls this for Nodes planning to work on one or more
@@ -61,15 +55,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
61
55
 
62
56
  Constraints
63
57
  -----------
64
- If `node_id` is not `None`, retrieve all TaskIns where
58
+ Retrieve all TaskIns where
65
59
 
66
60
  1. the `task_ins.task.consumer.node_id` equals `node_id` AND
67
- 2. the `task_ins.task.consumer.anonymous` equals `False` AND
68
- 3. the `task_ins.task.delivered_at` equals `""`.
61
+ 2. the `task_ins.task.delivered_at` equals `""`.
69
62
 
70
- If `node_id` is `None`, retrieve all TaskIns where the
71
- `task_ins.task.consumer.node_id` equals `0` and
72
- `task_ins.task.consumer.anonymous` is set to `True`.
73
63
 
74
64
  If `delivered_at` MUST BE set (not `""`) otherwise the TaskIns MUST not be in
75
65
  the result.
@@ -89,11 +79,8 @@ class LinkState(abc.ABC): # pylint: disable=R0904
89
79
 
90
80
  Constraints
91
81
  -----------
92
- If `task_res.task.consumer.anonymous` is `True`, then
93
- `task_res.task.consumer.node_id` MUST NOT be set (equal 0).
94
82
 
95
- If `task_res.task.consumer.anonymous` is `False`, then
96
- `task_res.task.consumer.node_id` MUST be set (not 0)
83
+ `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
97
84
 
98
85
  If `task_res.run_id` is invalid, then
99
86
  storing the `task_res` MUST fail.
@@ -154,13 +141,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
154
141
  """Get all TaskIns IDs for the given run_id."""
155
142
 
156
143
  @abc.abstractmethod
157
- def create_node(
158
- self, ping_interval: float, public_key: Optional[bytes] = None
159
- ) -> int:
144
+ def create_node(self, ping_interval: float) -> int:
160
145
  """Create, store in the link state, and return `node_id`."""
161
146
 
162
147
  @abc.abstractmethod
163
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
148
+ def delete_node(self, node_id: int) -> None:
164
149
  """Remove `node_id` from the link state."""
165
150
 
166
151
  @abc.abstractmethod
@@ -173,6 +158,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
173
158
  an empty `Set` MUST be returned.
174
159
  """
175
160
 
161
+ @abc.abstractmethod
162
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
163
+ """Set `public_key` for the specified `node_id`."""
164
+
165
+ @abc.abstractmethod
166
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
167
+ """Get `public_key` for the specified `node_id`."""
168
+
176
169
  @abc.abstractmethod
177
170
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
178
171
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
@@ -31,6 +31,7 @@ from flwr.common.constant import (
31
31
  MESSAGE_TTL_TOLERANCE,
32
32
  NODE_ID_NUM_BYTES,
33
33
  RUN_ID_NUM_BYTES,
34
+ SUPERLINK_NODE_ID,
34
35
  Status,
35
36
  )
36
37
  from flwr.common.record import ConfigsRecord
@@ -72,14 +73,14 @@ CREATE TABLE IF NOT EXISTS node(
72
73
 
73
74
  SQL_CREATE_TABLE_CREDENTIAL = """
74
75
  CREATE TABLE IF NOT EXISTS credential(
75
- private_key BLOB PRIMARY KEY,
76
- public_key BLOB
76
+ private_key BLOB PRIMARY KEY,
77
+ public_key BLOB
77
78
  );
78
79
  """
79
80
 
80
81
  SQL_CREATE_TABLE_PUBLIC_KEY = """
81
82
  CREATE TABLE IF NOT EXISTS public_key(
82
- public_key BLOB UNIQUE
83
+ public_key BLOB PRIMARY KEY
83
84
  );
84
85
  """
85
86
 
@@ -128,9 +129,7 @@ CREATE TABLE IF NOT EXISTS task_ins(
128
129
  task_id TEXT UNIQUE,
129
130
  group_id TEXT,
130
131
  run_id INTEGER,
131
- producer_anonymous BOOLEAN,
132
132
  producer_node_id INTEGER,
133
- consumer_anonymous BOOLEAN,
134
133
  consumer_node_id INTEGER,
135
134
  created_at REAL,
136
135
  delivered_at TEXT,
@@ -148,9 +147,7 @@ CREATE TABLE IF NOT EXISTS task_res(
148
147
  task_id TEXT UNIQUE,
149
148
  group_id TEXT,
150
149
  run_id INTEGER,
151
- producer_anonymous BOOLEAN,
152
150
  producer_node_id INTEGER,
153
- consumer_anonymous BOOLEAN,
154
151
  consumer_node_id INTEGER,
155
152
  created_at REAL,
156
153
  delivered_at TEXT,
@@ -263,11 +260,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
263
260
 
264
261
  Constraints
265
262
  -----------
266
- If `task_ins.task.consumer.anonymous` is `True`, then
267
- `task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
268
263
 
269
- If `task_ins.task.consumer.anonymous` is `False`, then
270
- `task_ins.task.consumer.node_id` MUST be set (not 0)
264
+ `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
271
265
  """
272
266
  # Validate task
273
267
  errors = validate_task_ins_or_res(task_ins)
@@ -292,7 +286,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
292
286
  log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
293
287
  return None
294
288
  # Validate source node ID
295
- if task_ins.task.producer.node_id != 0:
289
+ if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
296
290
  log(
297
291
  ERROR,
298
292
  "Invalid source node ID for TaskIns: %s",
@@ -301,14 +295,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
301
295
  return None
302
296
  # Validate destination node ID
303
297
  query = "SELECT node_id FROM node WHERE node_id = ?;"
304
- if not task_ins.task.consumer.anonymous:
305
- if not self.query(query, (data[0]["consumer_node_id"],)):
306
- log(
307
- ERROR,
308
- "Invalid destination node ID for TaskIns: %s",
309
- task_ins.task.consumer.node_id,
310
- )
311
- return None
298
+ if not self.query(query, (data[0]["consumer_node_id"],)):
299
+ log(
300
+ ERROR,
301
+ "Invalid destination node ID for TaskIns: %s",
302
+ task_ins.task.consumer.node_id,
303
+ )
304
+ return None
312
305
 
313
306
  columns = ", ".join([f":{key}" for key in data[0]])
314
307
  query = f"INSERT INTO task_ins VALUES({columns});"
@@ -319,25 +312,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
319
312
 
320
313
  return task_id
321
314
 
322
- def get_task_ins(
323
- self, node_id: Optional[int], limit: Optional[int]
324
- ) -> list[TaskIns]:
325
- """Get undelivered TaskIns for one node (either anonymous or with ID).
315
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
316
+ """Get undelivered TaskIns for one node.
326
317
 
327
318
  Usually, the Fleet API calls this for Nodes planning to work on one or more
328
319
  TaskIns.
329
320
 
330
321
  Constraints
331
322
  -----------
332
- If `node_id` is not `None`, retrieve all TaskIns where
323
+ Retrieve all TaskIns where
333
324
 
334
325
  1. the `task_ins.task.consumer.node_id` equals `node_id` AND
335
- 2. the `task_ins.task.consumer.anonymous` equals `False` AND
336
- 3. the `task_ins.task.delivered_at` equals `""`.
337
-
338
- If `node_id` is `None`, retrieve all TaskIns where the
339
- `task_ins.task.consumer.node_id` equals `0` and
340
- `task_ins.task.consumer.anonymous` is set to `True`.
326
+ 2. the `task_ins.task.delivered_at` equals `""`.
341
327
 
342
328
  `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
343
329
  the result.
@@ -348,38 +334,23 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
348
334
  if limit is not None and limit < 1:
349
335
  raise AssertionError("`limit` must be >= 1")
350
336
 
351
- if node_id == 0:
352
- msg = (
353
- "`node_id` must be >= 1"
354
- "\n\n For requesting anonymous tasks use `node_id` equal `None`"
355
- )
337
+ if node_id == SUPERLINK_NODE_ID:
338
+ msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
356
339
  raise AssertionError(msg)
357
340
 
358
341
  data: dict[str, Union[str, int]] = {}
359
342
 
360
- if node_id is None:
361
- # Retrieve all anonymous Tasks
362
- query = """
363
- SELECT task_id
364
- FROM task_ins
365
- WHERE consumer_anonymous == 1
366
- AND consumer_node_id == 0
367
- AND delivered_at = ""
368
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
369
- """
370
- else:
371
- # Convert the uint64 value to sint64 for SQLite
372
- data["node_id"] = convert_uint64_to_sint64(node_id)
343
+ # Convert the uint64 value to sint64 for SQLite
344
+ data["node_id"] = convert_uint64_to_sint64(node_id)
373
345
 
374
- # Retrieve all TaskIns for node_id
375
- query = """
376
- SELECT task_id
377
- FROM task_ins
378
- WHERE consumer_anonymous == 0
379
- AND consumer_node_id == :node_id
380
- AND delivered_at = ""
381
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
382
- """
346
+ # Retrieve all TaskIns for node_id
347
+ query = """
348
+ SELECT task_id
349
+ FROM task_ins
350
+ WHERE consumer_node_id == :node_id
351
+ AND delivered_at = ""
352
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
353
+ """
383
354
 
384
355
  if limit is not None:
385
356
  query += " LIMIT :limit"
@@ -429,11 +400,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
429
400
 
430
401
  Constraints
431
402
  -----------
432
- If `task_res.task.consumer.anonymous` is `True`, then
433
- `task_res.task.consumer.node_id` MUST NOT be set (equal 0).
434
-
435
- If `task_res.task.consumer.anonymous` is `False`, then
436
- `task_res.task.consumer.node_id` MUST be set (not 0)
403
+ `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
437
404
  """
438
405
  # Validate task
439
406
  errors = validate_task_ins_or_res(task_res)
@@ -459,7 +426,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
459
426
  if (
460
427
  task_ins
461
428
  and task_res
462
- and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
463
429
  and convert_sint64_to_uint64(task_ins["consumer_node_id"])
464
430
  != task_res.task.producer.node_id
465
431
  ):
@@ -635,23 +601,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
635
601
 
636
602
  return {UUID(row["task_id"]) for row in rows}
637
603
 
638
- def create_node(
639
- self, ping_interval: float, public_key: Optional[bytes] = None
640
- ) -> int:
604
+ def create_node(self, ping_interval: float) -> int:
641
605
  """Create, store in the link state, and return `node_id`."""
642
606
  # Sample a random uint64 as node_id
643
- uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
607
+ uint64_node_id = generate_rand_int_from_bytes(
608
+ NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
609
+ )
644
610
 
645
611
  # Convert the uint64 value to sint64 for SQLite
646
612
  sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
647
613
 
648
- query = "SELECT node_id FROM node WHERE public_key = :public_key;"
649
- row = self.query(query, {"public_key": public_key})
650
-
651
- if len(row) > 0:
652
- log(ERROR, "Unexpected node registration failure.")
653
- return 0
654
-
655
614
  query = (
656
615
  "INSERT INTO node "
657
616
  "(node_id, online_until, ping_interval, public_key) "
@@ -665,7 +624,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
665
624
  sint64_node_id,
666
625
  time.time() + ping_interval,
667
626
  ping_interval,
668
- public_key,
627
+ b"", # Initialize with an empty public key
669
628
  ),
670
629
  )
671
630
  except sqlite3.IntegrityError:
@@ -675,7 +634,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
675
634
  # Note: we need to return the uint64 value of the node_id
676
635
  return uint64_node_id
677
636
 
678
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
637
+ def delete_node(self, node_id: int) -> None:
679
638
  """Delete a node."""
680
639
  # Convert the uint64 value to sint64 for SQLite
681
640
  sint64_node_id = convert_uint64_to_sint64(node_id)
@@ -683,10 +642,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
683
642
  query = "DELETE FROM node WHERE node_id = ?"
684
643
  params = (sint64_node_id,)
685
644
 
686
- if public_key is not None:
687
- query += " AND public_key = ?"
688
- params += (public_key,) # type: ignore
689
-
690
645
  if self.conn is None:
691
646
  raise AttributeError("LinkState is not initialized.")
692
647
 
@@ -694,7 +649,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
694
649
  with self.conn:
695
650
  rows = self.conn.execute(query, params)
696
651
  if rows.rowcount < 1:
697
- raise ValueError("Public key or node_id not found")
652
+ raise ValueError(f"Node {node_id} not found")
698
653
  except KeyError as exc:
699
654
  log(ERROR, {"query": query, "data": params, "exception": exc})
700
655
 
@@ -722,6 +677,41 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
722
677
  result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
723
678
  return result
724
679
 
680
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
681
+ """Set `public_key` for the specified `node_id`."""
682
+ # Convert the uint64 value to sint64 for SQLite
683
+ sint64_node_id = convert_uint64_to_sint64(node_id)
684
+
685
+ # Check if the node exists in the `node` table
686
+ query = "SELECT 1 FROM node WHERE node_id = ?"
687
+ if not self.query(query, (sint64_node_id,)):
688
+ raise ValueError(f"Node {node_id} not found")
689
+
690
+ # Check if the public key is already in use in the `node` table
691
+ query = "SELECT 1 FROM node WHERE public_key = ?"
692
+ if self.query(query, (public_key,)):
693
+ raise ValueError("Public key already in use")
694
+
695
+ # Update the `node` table to set the public key for the given node ID
696
+ query = "UPDATE node SET public_key = ? WHERE node_id = ?"
697
+ self.query(query, (public_key, sint64_node_id))
698
+
699
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
700
+ """Get `public_key` for the specified `node_id`."""
701
+ # Convert the uint64 value to sint64 for SQLite
702
+ sint64_node_id = convert_uint64_to_sint64(node_id)
703
+
704
+ # Query the public key for the given node_id
705
+ query = "SELECT public_key FROM node WHERE node_id = ?"
706
+ rows = self.query(query, (sint64_node_id,))
707
+
708
+ # If no result is found, return None
709
+ if not rows:
710
+ raise ValueError(f"Node {node_id} not found")
711
+
712
+ # Return the public key if it is not empty, otherwise return None
713
+ return rows[0]["public_key"] or None
714
+
725
715
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
726
716
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
727
717
  query = "SELECT node_id FROM node WHERE public_key = :public_key;"
@@ -982,17 +972,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
982
972
  """Acknowledge a ping received from a node, serving as a heartbeat."""
983
973
  sint64_node_id = convert_uint64_to_sint64(node_id)
984
974
 
985
- # Update `online_until` and `ping_interval` for the given `node_id`
986
- query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
987
- try:
988
- self.query(
989
- query, (time.time() + ping_interval, ping_interval, sint64_node_id)
990
- )
991
- return True
992
- except sqlite3.IntegrityError:
993
- log(ERROR, "`node_id` does not exist.")
975
+ # Check if the node exists in the `node` table
976
+ query = "SELECT 1 FROM node WHERE node_id = ?"
977
+ if not self.query(query, (sint64_node_id,)):
994
978
  return False
995
979
 
980
+ # Update `online_until` and `ping_interval` for the given `node_id`
981
+ query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
982
+ self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
983
+ return True
984
+
996
985
  def get_serverapp_context(self, run_id: int) -> Optional[Context]:
997
986
  """Get the context for the specified `run_id`."""
998
987
  # Retrieve context if any
@@ -1105,9 +1094,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
1105
1094
  "task_id": task_msg.task_id,
1106
1095
  "group_id": task_msg.group_id,
1107
1096
  "run_id": task_msg.run_id,
1108
- "producer_anonymous": task_msg.task.producer.anonymous,
1109
1097
  "producer_node_id": task_msg.task.producer.node_id,
1110
- "consumer_anonymous": task_msg.task.consumer.anonymous,
1111
1098
  "consumer_node_id": task_msg.task.consumer.node_id,
1112
1099
  "created_at": task_msg.task.created_at,
1113
1100
  "delivered_at": task_msg.task.delivered_at,
@@ -1126,9 +1113,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
1126
1113
  "task_id": task_msg.task_id,
1127
1114
  "group_id": task_msg.group_id,
1128
1115
  "run_id": task_msg.run_id,
1129
- "producer_anonymous": task_msg.task.producer.anonymous,
1130
1116
  "producer_node_id": task_msg.task.producer.node_id,
1131
- "consumer_anonymous": task_msg.task.consumer.anonymous,
1132
1117
  "consumer_node_id": task_msg.task.consumer.node_id,
1133
1118
  "created_at": task_msg.task.created_at,
1134
1119
  "delivered_at": task_msg.task.delivered_at,
@@ -1153,11 +1138,9 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
1153
1138
  task=Task(
1154
1139
  producer=Node(
1155
1140
  node_id=task_dict["producer_node_id"],
1156
- anonymous=task_dict["producer_anonymous"],
1157
1141
  ),
1158
1142
  consumer=Node(
1159
1143
  node_id=task_dict["consumer_node_id"],
1160
- anonymous=task_dict["consumer_anonymous"],
1161
1144
  ),
1162
1145
  created_at=task_dict["created_at"],
1163
1146
  delivered_at=task_dict["delivered_at"],
@@ -1183,11 +1166,9 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1183
1166
  task=Task(
1184
1167
  producer=Node(
1185
1168
  node_id=task_dict["producer_node_id"],
1186
- anonymous=task_dict["producer_anonymous"],
1187
1169
  ),
1188
1170
  consumer=Node(
1189
1171
  node_id=task_dict["consumer_node_id"],
1190
- anonymous=task_dict["consumer_anonymous"],
1191
1172
  ),
1192
1173
  created_at=task_dict["created_at"],
1193
1174
  delivered_at=task_dict["delivered_at"],
@@ -21,7 +21,7 @@ from typing import Optional, Union
21
21
  from uuid import UUID, uuid4
22
22
 
23
23
  from flwr.common import ConfigsRecord, Context, log, now, serde
24
- from flwr.common.constant import ErrorCode, Status, SubStatus
24
+ from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
25
25
  from flwr.common.typing import RunStatus
26
26
 
27
27
  # pylint: disable=E0611
@@ -60,9 +60,19 @@ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
60
60
  )
61
61
 
62
62
 
63
- def generate_rand_int_from_bytes(num_bytes: int) -> int:
64
- """Generate a random unsigned integer from `num_bytes` bytes."""
65
- return int.from_bytes(urandom(num_bytes), "little", signed=False)
63
+ def generate_rand_int_from_bytes(
64
+ num_bytes: int, exclude: Optional[list[int]] = None
65
+ ) -> int:
66
+ """Generate a random unsigned integer from `num_bytes` bytes.
67
+
68
+ If `exclude` is set, this function guarantees such number is not returned.
69
+ """
70
+ num = int.from_bytes(urandom(num_bytes), "little", signed=False)
71
+
72
+ if exclude:
73
+ while num in exclude:
74
+ num = int.from_bytes(urandom(num_bytes), "little", signed=False)
75
+ return num
66
76
 
67
77
 
68
78
  def convert_uint64_to_sint64(u: int) -> int:
@@ -246,8 +256,8 @@ def create_taskres_for_unavailable_taskins(taskins_id: Union[str, UUID]) -> Task
246
256
  run_id=0, # Unknown run ID
247
257
  task=Task(
248
258
  # This function is only called by SuperLink, and thus it's the producer.
249
- producer=Node(node_id=0, anonymous=False),
250
- consumer=Node(node_id=0, anonymous=False),
259
+ producer=Node(node_id=SUPERLINK_NODE_ID),
260
+ consumer=Node(node_id=SUPERLINK_NODE_ID),
251
261
  created_at=current_time,
252
262
  ttl=0,
253
263
  ancestry=[str(taskins_id)],
@@ -285,8 +295,8 @@ def create_taskres_for_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
285
295
  run_id=ref_taskins.run_id,
286
296
  task=Task(
287
297
  # This function is only called by SuperLink, and thus it's the producer.
288
- producer=Node(node_id=0, anonymous=False),
289
- consumer=Node(node_id=0, anonymous=False),
298
+ producer=Node(node_id=SUPERLINK_NODE_ID),
299
+ consumer=Node(node_id=SUPERLINK_NODE_ID),
290
300
  created_at=current_time,
291
301
  ttl=ttl,
292
302
  ancestry=[ref_taskins.task_id],