flwr 1.14.0__py3-none-any.whl → 1.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. flwr/cli/auth_plugin/__init__.py +31 -0
  2. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  3. flwr/cli/cli_user_auth_interceptor.py +6 -2
  4. flwr/cli/config_utils.py +24 -147
  5. flwr/cli/constant.py +27 -0
  6. flwr/cli/install.py +1 -1
  7. flwr/cli/log.py +18 -3
  8. flwr/cli/login/login.py +43 -8
  9. flwr/cli/ls.py +14 -5
  10. flwr/cli/new/templates/app/README.md.tpl +3 -2
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  13. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  14. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  20. flwr/cli/run/run.py +21 -11
  21. flwr/cli/stop.py +13 -4
  22. flwr/cli/utils.py +54 -40
  23. flwr/client/app.py +36 -48
  24. flwr/client/clientapp/app.py +19 -25
  25. flwr/client/clientapp/utils.py +1 -1
  26. flwr/client/grpc_client/connection.py +1 -12
  27. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  28. flwr/client/grpc_rere_client/connection.py +46 -36
  29. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  30. flwr/client/message_handler/task_handler.py +0 -17
  31. flwr/client/rest_client/connection.py +34 -26
  32. flwr/client/supernode/app.py +18 -72
  33. flwr/common/args.py +25 -47
  34. flwr/common/auth_plugin/auth_plugin.py +34 -23
  35. flwr/common/config.py +166 -16
  36. flwr/common/constant.py +24 -9
  37. flwr/common/differential_privacy.py +2 -1
  38. flwr/common/exit/__init__.py +24 -0
  39. flwr/common/exit/exit.py +99 -0
  40. flwr/common/exit/exit_code.py +93 -0
  41. flwr/common/exit_handlers.py +32 -30
  42. flwr/common/grpc.py +167 -4
  43. flwr/common/logger.py +26 -7
  44. flwr/common/object_ref.py +0 -14
  45. flwr/common/record/recordset.py +1 -1
  46. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  47. flwr/common/serde.py +6 -4
  48. flwr/common/typing.py +20 -0
  49. flwr/proto/clientappio_pb2.py +1 -1
  50. flwr/proto/error_pb2.py +1 -1
  51. flwr/proto/exec_pb2.py +13 -25
  52. flwr/proto/exec_pb2.pyi +27 -54
  53. flwr/proto/fab_pb2.py +1 -1
  54. flwr/proto/fleet_pb2.py +31 -31
  55. flwr/proto/fleet_pb2.pyi +23 -23
  56. flwr/proto/fleet_pb2_grpc.py +30 -30
  57. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  58. flwr/proto/grpcadapter_pb2.py +1 -1
  59. flwr/proto/log_pb2.py +1 -1
  60. flwr/proto/message_pb2.py +1 -1
  61. flwr/proto/node_pb2.py +3 -3
  62. flwr/proto/node_pb2.pyi +1 -4
  63. flwr/proto/recordset_pb2.py +1 -1
  64. flwr/proto/run_pb2.py +1 -1
  65. flwr/proto/serverappio_pb2.py +24 -25
  66. flwr/proto/serverappio_pb2.pyi +26 -32
  67. flwr/proto/serverappio_pb2_grpc.py +28 -28
  68. flwr/proto/serverappio_pb2_grpc.pyi +16 -16
  69. flwr/proto/simulationio_pb2.py +1 -1
  70. flwr/proto/task_pb2.py +1 -1
  71. flwr/proto/transport_pb2.py +1 -1
  72. flwr/server/app.py +116 -128
  73. flwr/server/compat/app_utils.py +0 -1
  74. flwr/server/compat/driver_client_proxy.py +1 -2
  75. flwr/server/driver/grpc_driver.py +32 -27
  76. flwr/server/driver/inmemory_driver.py +2 -1
  77. flwr/server/serverapp/app.py +12 -10
  78. flwr/server/superlink/driver/serverappio_grpc.py +1 -1
  79. flwr/server/superlink/driver/serverappio_servicer.py +74 -48
  80. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  81. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  82. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
  83. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +110 -168
  84. flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
  85. flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
  86. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  87. flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
  88. flwr/server/superlink/linkstate/linkstate.py +17 -38
  89. flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
  90. flwr/server/superlink/linkstate/utils.py +18 -8
  91. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  92. flwr/server/utils/validator.py +9 -34
  93. flwr/simulation/app.py +4 -6
  94. flwr/simulation/legacy_app.py +4 -2
  95. flwr/simulation/run_simulation.py +1 -1
  96. flwr/simulation/simulationio_connection.py +2 -1
  97. flwr/superexec/exec_grpc.py +1 -1
  98. flwr/superexec/exec_servicer.py +23 -2
  99. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/METADATA +8 -8
  100. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/RECORD +103 -97
  101. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/LICENSE +0 -0
  102. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/WHEEL +0 -0
  103. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/entry_points.txt +0 -0
flwr/server/superlink/linkstate/in_memory_linkstate.py
@@ -28,6 +28,7 @@ from flwr.common.constant import (
28
28
  MESSAGE_TTL_TOLERANCE,
29
29
  NODE_ID_NUM_BYTES,
30
30
  RUN_ID_NUM_BYTES,
31
+ SUPERLINK_NODE_ID,
31
32
  Status,
32
33
  )
33
34
  from flwr.common.record import ConfigsRecord
@@ -62,6 +63,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
62
63
  # Map node_id to (online_until, ping_interval)
63
64
  self.node_ids: dict[int, tuple[float, float]] = {}
64
65
  self.public_key_to_node_id: dict[bytes, int] = {}
66
+ self.node_id_to_public_key: dict[int, bytes] = {}
65
67
 
66
68
  # Map run_id to RunRecord
67
69
  self.run_ids: dict[int, RunRecord] = {}
@@ -72,8 +74,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
72
74
  self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
73
75
 
74
76
  self.node_public_keys: set[bytes] = set()
75
- self.server_public_key: Optional[bytes] = None
76
- self.server_private_key: Optional[bytes] = None
77
77
 
78
78
  self.lock = threading.RLock()
79
79
 
@@ -89,7 +89,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
89
89
  log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
90
90
  return None
91
91
  # Validate source node ID
92
- if task_ins.task.producer.node_id != 0:
92
+ if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
93
93
  log(
94
94
  ERROR,
95
95
  "Invalid source node ID for TaskIns: %s",
@@ -97,14 +97,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
97
97
  )
98
98
  return None
99
99
  # Validate destination node ID
100
- if not task_ins.task.consumer.anonymous:
101
- if task_ins.task.consumer.node_id not in self.node_ids:
102
- log(
103
- ERROR,
104
- "Invalid destination node ID for TaskIns: %s",
105
- task_ins.task.consumer.node_id,
106
- )
107
- return None
100
+ if task_ins.task.consumer.node_id not in self.node_ids:
101
+ log(
102
+ ERROR,
103
+ "Invalid destination node ID for TaskIns: %s",
104
+ task_ins.task.consumer.node_id,
105
+ )
106
+ return None
108
107
 
109
108
  # Create task_id
110
109
  task_id = uuid4()
@@ -117,9 +116,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
117
116
  # Return the new task_id
118
117
  return task_id
119
118
 
120
- def get_task_ins(
121
- self, node_id: Optional[int], limit: Optional[int]
122
- ) -> list[TaskIns]:
119
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
123
120
  """Get all TaskIns that have not been delivered yet."""
124
121
  if limit is not None and limit < 1:
125
122
  raise AssertionError("`limit` must be >= 1")
@@ -129,17 +126,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
129
126
  current_time = time.time()
130
127
  with self.lock:
131
128
  for _, task_ins in self.task_ins_store.items():
132
- # pylint: disable=too-many-boolean-expressions
133
129
  if (
134
- node_id is not None # Not anonymous
135
- and task_ins.task.consumer.anonymous is False
136
- and task_ins.task.consumer.node_id == node_id
137
- and task_ins.task.delivered_at == ""
138
- and task_ins.task.created_at + task_ins.task.ttl > current_time
139
- ) or (
140
- node_id is None # Anonymous
141
- and task_ins.task.consumer.anonymous is True
142
- and task_ins.task.consumer.node_id == 0
130
+ task_ins.task.consumer.node_id == node_id
143
131
  and task_ins.task.delivered_at == ""
144
132
  and task_ins.task.created_at + task_ins.task.ttl > current_time
145
133
  ):
@@ -173,9 +161,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
173
161
  if (
174
162
  task_ins
175
163
  and task_res
176
- and not (
177
- task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
178
- )
179
164
  and task_ins.task.consumer.node_id != task_res.task.producer.node_id
180
165
  ):
181
166
  return None
@@ -306,45 +291,30 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
306
291
  """
307
292
  return len(self.task_res_store)
308
293
 
309
- def create_node(
310
- self, ping_interval: float, public_key: Optional[bytes] = None
311
- ) -> int:
294
+ def create_node(self, ping_interval: float) -> int:
312
295
  """Create, store in the link state, and return `node_id`."""
313
296
  # Sample a random int64 as node_id
314
- node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
297
+ node_id = generate_rand_int_from_bytes(
298
+ NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
299
+ )
315
300
 
316
301
  with self.lock:
317
302
  if node_id in self.node_ids:
318
303
  log(ERROR, "Unexpected node registration failure.")
319
304
  return 0
320
305
 
321
- if public_key is not None:
322
- if (
323
- public_key in self.public_key_to_node_id
324
- or node_id in self.public_key_to_node_id.values()
325
- ):
326
- log(ERROR, "Unexpected node registration failure.")
327
- return 0
328
-
329
- self.public_key_to_node_id[public_key] = node_id
330
-
331
306
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
332
307
  return node_id
333
308
 
334
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
309
+ def delete_node(self, node_id: int) -> None:
335
310
  """Delete a node."""
336
311
  with self.lock:
337
312
  if node_id not in self.node_ids:
338
313
  raise ValueError(f"Node {node_id} not found")
339
314
 
340
- if public_key is not None:
341
- if (
342
- public_key not in self.public_key_to_node_id
343
- or node_id not in self.public_key_to_node_id.values()
344
- ):
345
- raise ValueError("Public key or node_id not found")
346
-
347
- del self.public_key_to_node_id[public_key]
315
+ # Remove node ID <> public key mappings
316
+ if pk := self.node_id_to_public_key.pop(node_id, None):
317
+ del self.public_key_to_node_id[pk]
348
318
 
349
319
  del self.node_ids[node_id]
350
320
 
@@ -366,6 +336,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
366
336
  if online_until > current_time
367
337
  }
368
338
 
339
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
340
+ """Set `public_key` for the specified `node_id`."""
341
+ with self.lock:
342
+ if node_id not in self.node_ids:
343
+ raise ValueError(f"Node {node_id} not found")
344
+
345
+ if public_key in self.public_key_to_node_id:
346
+ raise ValueError("Public key already in use")
347
+
348
+ self.public_key_to_node_id[public_key] = node_id
349
+ self.node_id_to_public_key[node_id] = public_key
350
+
351
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
352
+ """Get `public_key` for the specified `node_id`."""
353
+ with self.lock:
354
+ if node_id not in self.node_ids:
355
+ raise ValueError(f"Node {node_id} not found")
356
+
357
+ return self.node_id_to_public_key.get(node_id)
358
+
369
359
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
370
360
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
371
361
  return self.public_key_to_node_id.get(node_public_key)
@@ -411,36 +401,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
411
401
  log(ERROR, "Unexpected run creation failure.")
412
402
  return 0
413
403
 
414
- def store_server_private_public_key(
415
- self, private_key: bytes, public_key: bytes
416
- ) -> None:
417
- """Store `server_private_key` and `server_public_key` in the link state."""
404
+ def clear_supernode_auth_keys(self) -> None:
405
+ """Clear stored `node_public_keys` in the link state if any."""
418
406
  with self.lock:
419
- if self.server_private_key is None and self.server_public_key is None:
420
- self.server_private_key = private_key
421
- self.server_public_key = public_key
422
- else:
423
- raise RuntimeError("Server private and public key already set")
424
-
425
- def get_server_private_key(self) -> Optional[bytes]:
426
- """Retrieve `server_private_key` in urlsafe bytes."""
427
- return self.server_private_key
428
-
429
- def get_server_public_key(self) -> Optional[bytes]:
430
- """Retrieve `server_public_key` in urlsafe bytes."""
431
- return self.server_public_key
432
-
433
- def clear_supernode_auth_keys_and_credentials(self) -> None:
434
- """Clear stored `node_public_keys` and credentials in the link state if any."""
435
- with self.lock:
436
- self.server_private_key = None
437
- self.server_public_key = None
438
407
  self.node_public_keys.clear()
439
408
 
440
409
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
441
410
  """Store a set of `node_public_keys` in the link state."""
442
411
  with self.lock:
443
- self.node_public_keys = public_keys
412
+ self.node_public_keys.update(public_keys)
444
413
 
445
414
  def store_node_public_key(self, public_key: bytes) -> None:
446
415
  """Store a `node_public_key` in the link state."""
@@ -449,7 +418,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
449
418
 
450
419
  def get_node_public_keys(self) -> set[bytes]:
451
420
  """Retrieve all currently stored `node_public_keys` as a set."""
452
- return self.node_public_keys
421
+ with self.lock:
422
+ return self.node_public_keys.copy()
453
423
 
454
424
  def get_run_ids(self) -> set[int]:
455
425
  """Retrieve all run IDs."""
flwr/server/superlink/linkstate/linkstate.py
@@ -40,20 +40,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
40
40
 
41
41
  Constraints
42
42
  -----------
43
- If `task_ins.task.consumer.anonymous` is `True`, then
44
- `task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
45
-
46
- If `task_ins.task.consumer.anonymous` is `False`, then
47
- `task_ins.task.consumer.node_id` MUST be set (not 0)
43
+ `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
48
44
 
49
45
  If `task_ins.run_id` is invalid, then
50
46
  storing the `task_ins` MUST fail.
51
47
  """
52
48
 
53
49
  @abc.abstractmethod
54
- def get_task_ins(
55
- self, node_id: Optional[int], limit: Optional[int]
56
- ) -> list[TaskIns]:
50
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
57
51
  """Get TaskIns optionally filtered by node_id.
58
52
 
59
53
  Usually, the Fleet API calls this for Nodes planning to work on one or more
@@ -61,15 +55,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
61
55
 
62
56
  Constraints
63
57
  -----------
64
- If `node_id` is not `None`, retrieve all TaskIns where
58
+ Retrieve all TaskIns where
65
59
 
66
60
  1. the `task_ins.task.consumer.node_id` equals `node_id` AND
67
- 2. the `task_ins.task.consumer.anonymous` equals `False` AND
68
- 3. the `task_ins.task.delivered_at` equals `""`.
61
+ 2. the `task_ins.task.delivered_at` equals `""`.
69
62
 
70
- If `node_id` is `None`, retrieve all TaskIns where the
71
- `task_ins.task.consumer.node_id` equals `0` and
72
- `task_ins.task.consumer.anonymous` is set to `True`.
73
63
 
74
64
  If `delivered_at` MUST BE set (not `""`) otherwise the TaskIns MUST not be in
75
65
  the result.
@@ -89,11 +79,8 @@ class LinkState(abc.ABC): # pylint: disable=R0904
89
79
 
90
80
  Constraints
91
81
  -----------
92
- If `task_res.task.consumer.anonymous` is `True`, then
93
- `task_res.task.consumer.node_id` MUST NOT be set (equal 0).
94
82
 
95
- If `task_res.task.consumer.anonymous` is `False`, then
96
- `task_res.task.consumer.node_id` MUST be set (not 0)
83
+ `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
97
84
 
98
85
  If `task_res.run_id` is invalid, then
99
86
  storing the `task_res` MUST fail.
@@ -154,13 +141,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
154
141
  """Get all TaskIns IDs for the given run_id."""
155
142
 
156
143
  @abc.abstractmethod
157
- def create_node(
158
- self, ping_interval: float, public_key: Optional[bytes] = None
159
- ) -> int:
144
+ def create_node(self, ping_interval: float) -> int:
160
145
  """Create, store in the link state, and return `node_id`."""
161
146
 
162
147
  @abc.abstractmethod
163
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
148
+ def delete_node(self, node_id: int) -> None:
164
149
  """Remove `node_id` from the link state."""
165
150
 
166
151
  @abc.abstractmethod
@@ -173,6 +158,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
173
158
  an empty `Set` MUST be returned.
174
159
  """
175
160
 
161
+ @abc.abstractmethod
162
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
163
+ """Set `public_key` for the specified `node_id`."""
164
+
165
+ @abc.abstractmethod
166
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
167
+ """Get `public_key` for the specified `node_id`."""
168
+
176
169
  @abc.abstractmethod
177
170
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
178
171
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
@@ -271,22 +264,8 @@ class LinkState(abc.ABC): # pylint: disable=R0904
271
264
  """
272
265
 
273
266
  @abc.abstractmethod
274
- def store_server_private_public_key(
275
- self, private_key: bytes, public_key: bytes
276
- ) -> None:
277
- """Store `server_private_key` and `server_public_key` in the link state."""
278
-
279
- @abc.abstractmethod
280
- def get_server_private_key(self) -> Optional[bytes]:
281
- """Retrieve `server_private_key` in urlsafe bytes."""
282
-
283
- @abc.abstractmethod
284
- def get_server_public_key(self) -> Optional[bytes]:
285
- """Retrieve `server_public_key` in urlsafe bytes."""
286
-
287
- @abc.abstractmethod
288
- def clear_supernode_auth_keys_and_credentials(self) -> None:
289
- """Clear stored `node_public_keys` and credentials in the link state if any."""
267
+ def clear_supernode_auth_keys(self) -> None:
268
+ """Clear stored `node_public_keys` in the link state if any."""
290
269
 
291
270
  @abc.abstractmethod
292
271
  def store_node_public_keys(self, public_keys: set[bytes]) -> None: