flwr-nightly 1.12.0.dev20240907__py3-none-any.whl → 1.12.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. (The original page linked to further details; that link was not preserved in this copy.)

Files changed (118):
  1. flwr/cli/build.py +1 -2
  2. flwr/cli/config_utils.py +10 -10
  3. flwr/cli/install.py +1 -2
  4. flwr/cli/new/new.py +26 -40
  5. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
  6. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
  7. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
  8. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
  9. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
  10. flwr/cli/run/run.py +6 -7
  11. flwr/cli/utils.py +2 -2
  12. flwr/client/app.py +14 -14
  13. flwr/client/client_app.py +5 -5
  14. flwr/client/clientapp/app.py +2 -2
  15. flwr/client/dpfedavg_numpy_client.py +6 -7
  16. flwr/client/grpc_adapter_client/connection.py +4 -3
  17. flwr/client/grpc_client/connection.py +4 -3
  18. flwr/client/grpc_rere_client/client_interceptor.py +5 -5
  19. flwr/client/grpc_rere_client/connection.py +5 -4
  20. flwr/client/grpc_rere_client/grpc_adapter.py +2 -2
  21. flwr/client/message_handler/message_handler.py +3 -3
  22. flwr/client/mod/secure_aggregation/secaggplus_mod.py +25 -25
  23. flwr/client/mod/utils.py +1 -3
  24. flwr/client/node_state.py +2 -2
  25. flwr/client/numpy_client.py +8 -8
  26. flwr/client/rest_client/connection.py +5 -4
  27. flwr/client/supernode/app.py +7 -8
  28. flwr/common/address.py +2 -2
  29. flwr/common/config.py +8 -8
  30. flwr/common/constant.py +12 -1
  31. flwr/common/differential_privacy.py +2 -2
  32. flwr/common/dp.py +1 -3
  33. flwr/common/exit_handlers.py +3 -3
  34. flwr/common/grpc.py +2 -1
  35. flwr/common/logger.py +3 -3
  36. flwr/common/object_ref.py +3 -3
  37. flwr/common/record/configsrecord.py +3 -3
  38. flwr/common/record/metricsrecord.py +3 -3
  39. flwr/common/record/parametersrecord.py +3 -2
  40. flwr/common/record/recordset.py +1 -1
  41. flwr/common/record/typeddict.py +23 -10
  42. flwr/common/recordset_compat.py +7 -5
  43. flwr/common/retry_invoker.py +6 -17
  44. flwr/common/secure_aggregation/crypto/shamir.py +10 -10
  45. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +2 -2
  46. flwr/common/secure_aggregation/ndarrays_arithmetic.py +16 -16
  47. flwr/common/secure_aggregation/quantization.py +7 -7
  48. flwr/common/secure_aggregation/secaggplus_utils.py +3 -5
  49. flwr/common/serde.py +11 -9
  50. flwr/common/telemetry.py +5 -5
  51. flwr/common/typing.py +19 -19
  52. flwr/common/version.py +2 -3
  53. flwr/server/app.py +18 -18
  54. flwr/server/client_manager.py +6 -6
  55. flwr/server/compat/app_utils.py +2 -3
  56. flwr/server/driver/driver.py +3 -2
  57. flwr/server/driver/grpc_driver.py +7 -7
  58. flwr/server/driver/inmemory_driver.py +5 -4
  59. flwr/server/history.py +8 -9
  60. flwr/server/run_serverapp.py +5 -6
  61. flwr/server/server.py +36 -36
  62. flwr/server/strategy/aggregate.py +13 -13
  63. flwr/server/strategy/bulyan.py +8 -8
  64. flwr/server/strategy/dp_adaptive_clipping.py +20 -20
  65. flwr/server/strategy/dp_fixed_clipping.py +19 -19
  66. flwr/server/strategy/dpfedavg_adaptive.py +6 -6
  67. flwr/server/strategy/dpfedavg_fixed.py +10 -10
  68. flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
  69. flwr/server/strategy/fedadagrad.py +8 -8
  70. flwr/server/strategy/fedadam.py +8 -8
  71. flwr/server/strategy/fedavg.py +16 -16
  72. flwr/server/strategy/fedavg_android.py +16 -16
  73. flwr/server/strategy/fedavgm.py +8 -8
  74. flwr/server/strategy/fedmedian.py +4 -4
  75. flwr/server/strategy/fedopt.py +5 -5
  76. flwr/server/strategy/fedprox.py +6 -6
  77. flwr/server/strategy/fedtrimmedavg.py +8 -8
  78. flwr/server/strategy/fedxgb_bagging.py +11 -11
  79. flwr/server/strategy/fedxgb_cyclic.py +9 -9
  80. flwr/server/strategy/fedxgb_nn_avg.py +5 -5
  81. flwr/server/strategy/fedyogi.py +8 -8
  82. flwr/server/strategy/krum.py +8 -8
  83. flwr/server/strategy/qfedavg.py +15 -15
  84. flwr/server/strategy/strategy.py +10 -10
  85. flwr/server/superlink/driver/driver_grpc.py +2 -2
  86. flwr/server/superlink/driver/driver_servicer.py +6 -6
  87. flwr/server/superlink/ffs/disk_ffs.py +4 -4
  88. flwr/server/superlink/ffs/ffs.py +4 -4
  89. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -2
  90. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +2 -1
  91. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +2 -1
  92. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +9 -8
  93. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +5 -3
  94. flwr/server/superlink/fleet/message_handler/message_handler.py +2 -2
  95. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
  96. flwr/server/superlink/fleet/vce/backend/__init__.py +2 -3
  97. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  98. flwr/server/superlink/fleet/vce/backend/raybackend.py +26 -17
  99. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  100. flwr/server/superlink/state/in_memory_state.py +18 -18
  101. flwr/server/superlink/state/sqlite_state.py +22 -21
  102. flwr/server/superlink/state/state.py +7 -7
  103. flwr/server/utils/tensorboard.py +4 -4
  104. flwr/server/utils/validator.py +2 -2
  105. flwr/server/workflow/default_workflows.py +5 -5
  106. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +22 -22
  107. flwr/simulation/app.py +8 -8
  108. flwr/simulation/ray_transport/ray_actor.py +23 -23
  109. flwr/simulation/run_simulation.py +16 -4
  110. flwr/superexec/app.py +4 -4
  111. flwr/superexec/deployment.py +2 -2
  112. flwr/superexec/exec_grpc.py +2 -2
  113. flwr/superexec/exec_servicer.py +3 -2
  114. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/METADATA +4 -6
  115. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/RECORD +118 -118
  116. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/LICENSE +0 -0
  117. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/WHEEL +0 -0
  118. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/entry_points.txt +0 -0
@@ -16,7 +16,7 @@
16
16
 
17
17
  import sys
18
18
  from logging import DEBUG, ERROR
19
- from typing import Callable, Dict, Optional, Tuple, Union
19
+ from typing import Callable, Optional, Union
20
20
 
21
21
  import ray
22
22
 
@@ -31,8 +31,8 @@ from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
31
31
 
32
32
  from .backend import Backend, BackendConfig
33
33
 
34
- ClientResourcesDict = Dict[str, Union[int, float]]
35
- ActorArgsDict = Dict[str, Union[int, float, Callable[[], None]]]
34
+ ClientResourcesDict = dict[str, Union[int, float]]
35
+ ActorArgsDict = dict[str, Union[int, float, Callable[[], None]]]
36
36
 
37
37
 
38
38
  class RayBackend(Backend):
@@ -52,16 +52,11 @@ class RayBackend(Backend):
52
52
 
53
53
  # Validate client resources
54
54
  self.client_resources_key = "client_resources"
55
- client_resources = self._validate_client_resources(config=backend_config)
55
+ self.client_resources = self._validate_client_resources(config=backend_config)
56
56
 
57
- # Create actor pool
58
- actor_kwargs = self._validate_actor_arguments(config=backend_config)
59
-
60
- self.pool = BasicActorPool(
61
- actor_type=ClientAppActor,
62
- client_resources=client_resources,
63
- actor_kwargs=actor_kwargs,
64
- )
57
+ # Valide actor resources
58
+ self.actor_kwargs = self._validate_actor_arguments(config=backend_config)
59
+ self.pool: Optional[BasicActorPool] = None
65
60
 
66
61
  self.app_fn: Optional[Callable[[], ClientApp]] = None
67
62
 
@@ -106,7 +101,7 @@ class RayBackend(Backend):
106
101
  def init_ray(self, backend_config: BackendConfig) -> None:
107
102
  """Intialises Ray if not already initialised."""
108
103
  if not ray.is_initialized():
109
- ray_init_args: Dict[
104
+ ray_init_args: dict[
110
105
  str,
111
106
  ConfigsRecordValues,
112
107
  ] = {}
@@ -122,14 +117,24 @@ class RayBackend(Backend):
122
117
  @property
123
118
  def num_workers(self) -> int:
124
119
  """Return number of actors in pool."""
125
- return self.pool.num_actors
120
+ return self.pool.num_actors if self.pool else 0
126
121
 
127
122
  def is_worker_idle(self) -> bool:
128
123
  """Report whether the pool has idle actors."""
129
- return self.pool.is_actor_available()
124
+ return self.pool.is_actor_available() if self.pool else False
130
125
 
131
126
  def build(self, app_fn: Callable[[], ClientApp]) -> None:
132
127
  """Build pool of Ray actors that this backend will submit jobs to."""
128
+ # Create Actor Pool
129
+ try:
130
+ self.pool = BasicActorPool(
131
+ actor_type=ClientAppActor,
132
+ client_resources=self.client_resources,
133
+ actor_kwargs=self.actor_kwargs,
134
+ )
135
+ except Exception as ex:
136
+ raise ex
137
+
133
138
  self.pool.add_actors_to_pool(self.pool.actors_capacity)
134
139
  # Set ClientApp callable that ray actors will use
135
140
  self.app_fn = app_fn
@@ -139,13 +144,16 @@ class RayBackend(Backend):
139
144
  self,
140
145
  message: Message,
141
146
  context: Context,
142
- ) -> Tuple[Message, Context]:
147
+ ) -> tuple[Message, Context]:
143
148
  """Run ClientApp that process a given message.
144
149
 
145
150
  Return output message and updated context.
146
151
  """
147
152
  partition_id = context.node_config[PARTITION_ID_KEY]
148
153
 
154
+ if self.pool is None:
155
+ raise ValueError("The actor pool is empty, unfit to process messages.")
156
+
149
157
  if self.app_fn is None:
150
158
  raise ValueError(
151
159
  "Unspecified function to load a `ClientApp`. "
@@ -179,6 +187,7 @@ class RayBackend(Backend):
179
187
 
180
188
  def terminate(self) -> None:
181
189
  """Terminate all actors in actor pool."""
182
- self.pool.terminate_all_actors()
190
+ if self.pool:
191
+ self.pool.terminate_all_actors()
183
192
  ray.shutdown()
184
193
  log(DEBUG, "Terminated %s", self.__class__.__name__)
@@ -24,7 +24,7 @@ from logging import DEBUG, ERROR, INFO, WARN
24
24
  from pathlib import Path
25
25
  from queue import Empty, Queue
26
26
  from time import sleep
27
- from typing import Callable, Dict, Optional
27
+ from typing import Callable, Optional
28
28
 
29
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
30
30
  from flwr.client.clientapp.utils import get_load_client_app_fn
@@ -44,7 +44,7 @@ from flwr.server.superlink.state import State, StateFactory
44
44
 
45
45
  from .backend import Backend, error_messages_backends, supported_backends
46
46
 
47
- NodeToPartitionMapping = Dict[int, int]
47
+ NodeToPartitionMapping = dict[int, int]
48
48
 
49
49
 
50
50
  def _register_nodes(
@@ -64,9 +64,9 @@ def _register_node_states(
64
64
  nodes_mapping: NodeToPartitionMapping,
65
65
  run: Run,
66
66
  app_dir: Optional[str] = None,
67
- ) -> Dict[int, NodeState]:
67
+ ) -> dict[int, NodeState]:
68
68
  """Create NodeState objects and pre-register the context for the run."""
69
- node_states: Dict[int, NodeState] = {}
69
+ node_states: dict[int, NodeState] = {}
70
70
  num_partitions = len(set(nodes_mapping.values()))
71
71
  for node_id, partition_id in nodes_mapping.items():
72
72
  node_states[node_id] = NodeState(
@@ -89,7 +89,7 @@ def _register_node_states(
89
89
  def worker(
90
90
  taskins_queue: "Queue[TaskIns]",
91
91
  taskres_queue: "Queue[TaskRes]",
92
- node_states: Dict[int, NodeState],
92
+ node_states: dict[int, NodeState],
93
93
  backend: Backend,
94
94
  f_stop: threading.Event,
95
95
  ) -> None:
@@ -177,7 +177,7 @@ def run_api(
177
177
  backend_fn: Callable[[], Backend],
178
178
  nodes_mapping: NodeToPartitionMapping,
179
179
  state_factory: StateFactory,
180
- node_states: Dict[int, NodeState],
180
+ node_states: dict[int, NodeState],
181
181
  f_stop: threading.Event,
182
182
  ) -> None:
183
183
  """Run the VCE."""
@@ -18,7 +18,7 @@
18
18
  import threading
19
19
  import time
20
20
  from logging import ERROR
21
- from typing import Dict, List, Optional, Set, Tuple
21
+ from typing import Optional
22
22
  from uuid import UUID, uuid4
23
23
 
24
24
  from flwr.common import log, now
@@ -37,15 +37,15 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
37
37
  def __init__(self) -> None:
38
38
 
39
39
  # Map node_id to (online_until, ping_interval)
40
- self.node_ids: Dict[int, Tuple[float, float]] = {}
41
- self.public_key_to_node_id: Dict[bytes, int] = {}
40
+ self.node_ids: dict[int, tuple[float, float]] = {}
41
+ self.public_key_to_node_id: dict[bytes, int] = {}
42
42
 
43
43
  # Map run_id to (fab_id, fab_version)
44
- self.run_ids: Dict[int, Run] = {}
45
- self.task_ins_store: Dict[UUID, TaskIns] = {}
46
- self.task_res_store: Dict[UUID, TaskRes] = {}
44
+ self.run_ids: dict[int, Run] = {}
45
+ self.task_ins_store: dict[UUID, TaskIns] = {}
46
+ self.task_res_store: dict[UUID, TaskRes] = {}
47
47
 
48
- self.node_public_keys: Set[bytes] = set()
48
+ self.node_public_keys: set[bytes] = set()
49
49
  self.server_public_key: Optional[bytes] = None
50
50
  self.server_private_key: Optional[bytes] = None
51
51
 
@@ -76,13 +76,13 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
76
76
 
77
77
  def get_task_ins(
78
78
  self, node_id: Optional[int], limit: Optional[int]
79
- ) -> List[TaskIns]:
79
+ ) -> list[TaskIns]:
80
80
  """Get all TaskIns that have not been delivered yet."""
81
81
  if limit is not None and limit < 1:
82
82
  raise AssertionError("`limit` must be >= 1")
83
83
 
84
84
  # Find TaskIns for node_id that were not delivered yet
85
- task_ins_list: List[TaskIns] = []
85
+ task_ins_list: list[TaskIns] = []
86
86
  with self.lock:
87
87
  for _, task_ins in self.task_ins_store.items():
88
88
  # pylint: disable=too-many-boolean-expressions
@@ -133,15 +133,15 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
133
133
  # Return the new task_id
134
134
  return task_id
135
135
 
136
- def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
136
+ def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
137
137
  """Get all TaskRes that have not been delivered yet."""
138
138
  if limit is not None and limit < 1:
139
139
  raise AssertionError("`limit` must be >= 1")
140
140
 
141
141
  with self.lock:
142
142
  # Find TaskRes that were not delivered yet
143
- task_res_list: List[TaskRes] = []
144
- replied_task_ids: Set[UUID] = set()
143
+ task_res_list: list[TaskRes] = []
144
+ replied_task_ids: set[UUID] = set()
145
145
  for _, task_res in self.task_res_store.items():
146
146
  reply_to = UUID(task_res.task.ancestry[0])
147
147
  if reply_to in task_ids and task_res.task.delivered_at == "":
@@ -175,10 +175,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
175
175
  # Return TaskRes
176
176
  return task_res_list
177
177
 
178
- def delete_tasks(self, task_ids: Set[UUID]) -> None:
178
+ def delete_tasks(self, task_ids: set[UUID]) -> None:
179
179
  """Delete all delivered TaskIns/TaskRes pairs."""
180
- task_ins_to_be_deleted: Set[UUID] = set()
181
- task_res_to_be_deleted: Set[UUID] = set()
180
+ task_ins_to_be_deleted: set[UUID] = set()
181
+ task_res_to_be_deleted: set[UUID] = set()
182
182
 
183
183
  with self.lock:
184
184
  for task_ins_id in task_ids:
@@ -253,7 +253,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
253
253
 
254
254
  del self.node_ids[node_id]
255
255
 
256
- def get_nodes(self, run_id: int) -> Set[int]:
256
+ def get_nodes(self, run_id: int) -> set[int]:
257
257
  """Return all available nodes.
258
258
 
259
259
  Constraints
@@ -318,7 +318,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
318
318
  """Retrieve `server_public_key` in urlsafe bytes."""
319
319
  return self.server_public_key
320
320
 
321
- def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
321
+ def store_node_public_keys(self, public_keys: set[bytes]) -> None:
322
322
  """Store a set of `node_public_keys` in state."""
323
323
  with self.lock:
324
324
  self.node_public_keys = public_keys
@@ -328,7 +328,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
328
328
  with self.lock:
329
329
  self.node_public_keys.add(public_key)
330
330
 
331
- def get_node_public_keys(self) -> Set[bytes]:
331
+ def get_node_public_keys(self) -> set[bytes]:
332
332
  """Retrieve all currently stored `node_public_keys` as a set."""
333
333
  return self.node_public_keys
334
334
 
@@ -19,8 +19,9 @@ import json
19
19
  import re
20
20
  import sqlite3
21
21
  import time
22
+ from collections.abc import Sequence
22
23
  from logging import DEBUG, ERROR
23
- from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
24
+ from typing import Any, Optional, Union, cast
24
25
  from uuid import UUID, uuid4
25
26
 
26
27
  from flwr.common import log, now
@@ -110,7 +111,7 @@ CREATE TABLE IF NOT EXISTS task_res(
110
111
  );
111
112
  """
112
113
 
113
- DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
114
+ DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
114
115
 
115
116
 
116
117
  class SqliteState(State): # pylint: disable=R0904
@@ -131,7 +132,7 @@ class SqliteState(State): # pylint: disable=R0904
131
132
  self.database_path = database_path
132
133
  self.conn: Optional[sqlite3.Connection] = None
133
134
 
134
- def initialize(self, log_queries: bool = False) -> List[Tuple[str]]:
135
+ def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
135
136
  """Create tables if they don't exist yet.
136
137
 
137
138
  Parameters
@@ -162,7 +163,7 @@ class SqliteState(State): # pylint: disable=R0904
162
163
  self,
163
164
  query: str,
164
165
  data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
165
- ) -> List[Dict[str, Any]]:
166
+ ) -> list[dict[str, Any]]:
166
167
  """Execute a SQL query."""
167
168
  if self.conn is None:
168
169
  raise AttributeError("State is not initialized.")
@@ -237,7 +238,7 @@ class SqliteState(State): # pylint: disable=R0904
237
238
 
238
239
  def get_task_ins(
239
240
  self, node_id: Optional[int], limit: Optional[int]
240
- ) -> List[TaskIns]:
241
+ ) -> list[TaskIns]:
241
242
  """Get undelivered TaskIns for one node (either anonymous or with ID).
242
243
 
243
244
  Usually, the Fleet API calls this for Nodes planning to work on one or more
@@ -271,7 +272,7 @@ class SqliteState(State): # pylint: disable=R0904
271
272
  )
272
273
  raise AssertionError(msg)
273
274
 
274
- data: Dict[str, Union[str, int]] = {}
275
+ data: dict[str, Union[str, int]] = {}
275
276
 
276
277
  if node_id is None:
277
278
  # Retrieve all anonymous Tasks
@@ -367,7 +368,7 @@ class SqliteState(State): # pylint: disable=R0904
367
368
  return task_id
368
369
 
369
370
  # pylint: disable-next=R0914
370
- def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
371
+ def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
371
372
  """Get TaskRes for task_ids.
372
373
 
373
374
  Usually, the Driver API calls this method to get results for instructions it has
@@ -397,7 +398,7 @@ class SqliteState(State): # pylint: disable=R0904
397
398
  AND delivered_at = ""
398
399
  """
399
400
 
400
- data: Dict[str, Union[str, float, int]] = {}
401
+ data: dict[str, Union[str, float, int]] = {}
401
402
 
402
403
  if limit is not None:
403
404
  query += " LIMIT :limit"
@@ -435,7 +436,7 @@ class SqliteState(State): # pylint: disable=R0904
435
436
  # 1. Query: Fetch consumer_node_id of remaining task_ids
436
437
  # Assume the ancestry field only contains one element
437
438
  data.clear()
438
- replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
439
+ replied_task_ids: set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
439
440
  remaining_task_ids = task_ids - replied_task_ids
440
441
  placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
441
442
  query = f"""
@@ -499,10 +500,10 @@ class SqliteState(State): # pylint: disable=R0904
499
500
  """
500
501
  query = "SELECT count(*) AS num FROM task_res;"
501
502
  rows = self.query(query)
502
- result: Dict[str, int] = rows[0]
503
+ result: dict[str, int] = rows[0]
503
504
  return result["num"]
504
505
 
505
- def delete_tasks(self, task_ids: Set[UUID]) -> None:
506
+ def delete_tasks(self, task_ids: set[UUID]) -> None:
506
507
  """Delete all delivered TaskIns/TaskRes pairs."""
507
508
  ids = list(task_ids)
508
509
  if len(ids) == 0:
@@ -588,7 +589,7 @@ class SqliteState(State): # pylint: disable=R0904
588
589
  except KeyError as exc:
589
590
  log(ERROR, {"query": query, "data": params, "exception": exc})
590
591
 
591
- def get_nodes(self, run_id: int) -> Set[int]:
592
+ def get_nodes(self, run_id: int) -> set[int]:
592
593
  """Retrieve all currently stored node IDs as a set.
593
594
 
594
595
  Constraints
@@ -604,7 +605,7 @@ class SqliteState(State): # pylint: disable=R0904
604
605
  # Get nodes
605
606
  query = "SELECT node_id FROM node WHERE online_until > ?;"
606
607
  rows = self.query(query, (time.time(),))
607
- result: Set[int] = {row["node_id"] for row in rows}
608
+ result: set[int] = {row["node_id"] for row in rows}
608
609
  return result
609
610
 
610
611
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
@@ -684,7 +685,7 @@ class SqliteState(State): # pylint: disable=R0904
684
685
  public_key = None
685
686
  return public_key
686
687
 
687
- def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
688
+ def store_node_public_keys(self, public_keys: set[bytes]) -> None:
688
689
  """Store a set of `node_public_keys` in state."""
689
690
  query = "INSERT INTO public_key (public_key) VALUES (?)"
690
691
  data = [(key,) for key in public_keys]
@@ -695,11 +696,11 @@ class SqliteState(State): # pylint: disable=R0904
695
696
  query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
696
697
  self.query(query, {"public_key": public_key})
697
698
 
698
- def get_node_public_keys(self) -> Set[bytes]:
699
+ def get_node_public_keys(self) -> set[bytes]:
699
700
  """Retrieve all currently stored `node_public_keys` as a set."""
700
701
  query = "SELECT public_key FROM public_key"
701
702
  rows = self.query(query)
702
- result: Set[bytes] = {row["public_key"] for row in rows}
703
+ result: set[bytes] = {row["public_key"] for row in rows}
703
704
  return result
704
705
 
705
706
  def get_run(self, run_id: int) -> Optional[Run]:
@@ -733,7 +734,7 @@ class SqliteState(State): # pylint: disable=R0904
733
734
  def dict_factory(
734
735
  cursor: sqlite3.Cursor,
735
736
  row: sqlite3.Row,
736
- ) -> Dict[str, Any]:
737
+ ) -> dict[str, Any]:
737
738
  """Turn SQLite results into dicts.
738
739
 
739
740
  Less efficent for retrival of large amounts of data but easier to use.
@@ -742,7 +743,7 @@ def dict_factory(
742
743
  return dict(zip(fields, row))
743
744
 
744
745
 
745
- def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
746
+ def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
746
747
  """Transform TaskIns to dict."""
747
748
  result = {
748
749
  "task_id": task_msg.task_id,
@@ -763,7 +764,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
763
764
  return result
764
765
 
765
766
 
766
- def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
767
+ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
767
768
  """Transform TaskRes to dict."""
768
769
  result = {
769
770
  "task_id": task_msg.task_id,
@@ -784,7 +785,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
784
785
  return result
785
786
 
786
787
 
787
- def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
788
+ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
788
789
  """Turn task_dict into protobuf message."""
789
790
  recordset = RecordSet()
790
791
  recordset.ParseFromString(task_dict["recordset"])
@@ -814,7 +815,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
814
815
  return result
815
816
 
816
817
 
817
- def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes:
818
+ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
818
819
  """Turn task_dict into protobuf message."""
819
820
  recordset = RecordSet()
820
821
  recordset.ParseFromString(task_dict["recordset"])
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import abc
19
- from typing import List, Optional, Set
19
+ from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common.typing import Run, UserConfig
@@ -51,7 +51,7 @@ class State(abc.ABC): # pylint: disable=R0904
51
51
  @abc.abstractmethod
52
52
  def get_task_ins(
53
53
  self, node_id: Optional[int], limit: Optional[int]
54
- ) -> List[TaskIns]:
54
+ ) -> list[TaskIns]:
55
55
  """Get TaskIns optionally filtered by node_id.
56
56
 
57
57
  Usually, the Fleet API calls this for Nodes planning to work on one or more
@@ -98,7 +98,7 @@ class State(abc.ABC): # pylint: disable=R0904
98
98
  """
99
99
 
100
100
  @abc.abstractmethod
101
- def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
101
+ def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
102
102
  """Get TaskRes for task_ids.
103
103
 
104
104
  Usually, the Driver API calls this method to get results for instructions it has
@@ -129,7 +129,7 @@ class State(abc.ABC): # pylint: disable=R0904
129
129
  """
130
130
 
131
131
  @abc.abstractmethod
132
- def delete_tasks(self, task_ids: Set[UUID]) -> None:
132
+ def delete_tasks(self, task_ids: set[UUID]) -> None:
133
133
  """Delete all delivered TaskIns/TaskRes pairs."""
134
134
 
135
135
  @abc.abstractmethod
@@ -143,7 +143,7 @@ class State(abc.ABC): # pylint: disable=R0904
143
143
  """Remove `node_id` from state."""
144
144
 
145
145
  @abc.abstractmethod
146
- def get_nodes(self, run_id: int) -> Set[int]:
146
+ def get_nodes(self, run_id: int) -> set[int]:
147
147
  """Retrieve all currently stored node IDs as a set.
148
148
 
149
149
  Constraints
@@ -199,7 +199,7 @@ class State(abc.ABC): # pylint: disable=R0904
199
199
  """Retrieve `server_public_key` in urlsafe bytes."""
200
200
 
201
201
  @abc.abstractmethod
202
- def store_node_public_keys(self, public_keys: Set[bytes]) -> None:
202
+ def store_node_public_keys(self, public_keys: set[bytes]) -> None:
203
203
  """Store a set of `node_public_keys` in state."""
204
204
 
205
205
  @abc.abstractmethod
@@ -207,7 +207,7 @@ class State(abc.ABC): # pylint: disable=R0904
207
207
  """Store a `node_public_key` in state."""
208
208
 
209
209
  @abc.abstractmethod
210
- def get_node_public_keys(self) -> Set[bytes]:
210
+ def get_node_public_keys(self) -> set[bytes]:
211
211
  """Retrieve all currently stored `node_public_keys` as a set."""
212
212
 
213
213
  @abc.abstractmethod
@@ -18,7 +18,7 @@
18
18
  import os
19
19
  from datetime import datetime
20
20
  from logging import WARN
21
- from typing import Callable, Dict, List, Optional, Tuple, Union, cast
21
+ from typing import Callable, Optional, Union, cast
22
22
 
23
23
  from flwr.common import EvaluateRes, Scalar
24
24
  from flwr.common.logger import log
@@ -92,9 +92,9 @@ def tensorboard(logdir: str) -> Callable[[Strategy], Strategy]:
92
92
  def aggregate_evaluate(
93
93
  self,
94
94
  server_round: int,
95
- results: List[Tuple[ClientProxy, EvaluateRes]],
96
- failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
97
- ) -> Tuple[Optional[float], Dict[str, Scalar]]:
95
+ results: list[tuple[ClientProxy, EvaluateRes]],
96
+ failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
97
+ ) -> tuple[Optional[float], dict[str, Scalar]]:
98
98
  """Hooks into aggregate_evaluate for TensorBoard logging purpose."""
99
99
  # Execute decorated function and extract results for logging
100
100
  # They will be returned at the end of this function but also
@@ -15,13 +15,13 @@
15
15
  """Validators."""
16
16
 
17
17
 
18
- from typing import List, Union
18
+ from typing import Union
19
19
 
20
20
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
21
21
 
22
22
 
23
23
  # pylint: disable-next=too-many-branches,too-many-statements
24
- def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str]:
24
+ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str]:
25
25
  """Validate a TaskIns or TaskRes."""
26
26
  validation_errors = []
27
27
 
@@ -18,7 +18,7 @@
18
18
  import io
19
19
  import timeit
20
20
  from logging import INFO, WARN
21
- from typing import List, Optional, Tuple, Union, cast
21
+ from typing import Optional, Union, cast
22
22
 
23
23
  import flwr.common.recordset_compat as compat
24
24
  from flwr.common import (
@@ -276,8 +276,8 @@ def default_fit_workflow( # pylint: disable=R0914
276
276
  )
277
277
 
278
278
  # Aggregate training results
279
- results: List[Tuple[ClientProxy, FitRes]] = []
280
- failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
279
+ results: list[tuple[ClientProxy, FitRes]] = []
280
+ failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = []
281
281
  for msg in messages:
282
282
  if msg.has_content():
283
283
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
@@ -362,8 +362,8 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
362
362
  )
363
363
 
364
364
  # Aggregate the evaluation results
365
- results: List[Tuple[ClientProxy, EvaluateRes]] = []
366
- failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = []
365
+ results: list[tuple[ClientProxy, EvaluateRes]] = []
366
+ failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = []
367
367
  for msg in messages:
368
368
  if msg.has_content():
369
369
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
@@ -18,7 +18,7 @@
18
18
  import random
19
19
  from dataclasses import dataclass, field
20
20
  from logging import DEBUG, ERROR, INFO, WARN
21
- from typing import Dict, List, Optional, Set, Tuple, Union, cast
21
+ from typing import Optional, Union, cast
22
22
 
23
23
  import flwr.common.recordset_compat as compat
24
24
  from flwr.common import (
@@ -65,23 +65,23 @@ from ..constant import Key as WorkflowKey
65
65
  class WorkflowState: # pylint: disable=R0902
66
66
  """The state of the SecAgg+ protocol."""
67
67
 
68
- nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict)
69
- nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
70
- sampled_node_ids: Set[int] = field(default_factory=set)
71
- active_node_ids: Set[int] = field(default_factory=set)
68
+ nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
69
+ nid_to_fitins: dict[int, RecordSet] = field(default_factory=dict)
70
+ sampled_node_ids: set[int] = field(default_factory=set)
71
+ active_node_ids: set[int] = field(default_factory=set)
72
72
  num_shares: int = 0
73
73
  threshold: int = 0
74
74
  clipping_range: float = 0.0
75
75
  quantization_range: int = 0
76
76
  mod_range: int = 0
77
77
  max_weight: float = 0.0
78
- nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict)
79
- nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict)
80
- forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
81
- forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
78
+ nid_to_neighbours: dict[int, set[int]] = field(default_factory=dict)
79
+ nid_to_publickeys: dict[int, list[bytes]] = field(default_factory=dict)
80
+ forward_srcs: dict[int, list[int]] = field(default_factory=dict)
81
+ forward_ciphertexts: dict[int, list[bytes]] = field(default_factory=dict)
82
82
  aggregate_ndarrays: NDArrays = field(default_factory=list)
83
- legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
84
- failures: List[Exception] = field(default_factory=list)
83
+ legacy_results: list[tuple[ClientProxy, FitRes]] = field(default_factory=list)
84
+ failures: list[Exception] = field(default_factory=list)
85
85
 
86
86
 
87
87
  class SecAggPlusWorkflow:
@@ -444,13 +444,13 @@ class SecAggPlusWorkflow:
444
444
  )
445
445
 
446
446
  # Build forward packet list dictionary
447
- srcs: List[int] = []
448
- dsts: List[int] = []
449
- ciphertexts: List[bytes] = []
450
- fwd_ciphertexts: Dict[int, List[bytes]] = {
447
+ srcs: list[int] = []
448
+ dsts: list[int] = []
449
+ ciphertexts: list[bytes] = []
450
+ fwd_ciphertexts: dict[int, list[bytes]] = {
451
451
  nid: [] for nid in state.active_node_ids
452
452
  } # dest node ID -> list of ciphertexts
453
- fwd_srcs: Dict[int, List[int]] = {
453
+ fwd_srcs: dict[int, list[int]] = {
454
454
  nid: [] for nid in state.active_node_ids
455
455
  } # dest node ID -> list of src node IDs
456
456
  for msg in msgs:
@@ -459,8 +459,8 @@ class SecAggPlusWorkflow:
459
459
  continue
460
460
  node_id = msg.metadata.src_node_id
461
461
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
462
- dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
463
- ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST])
462
+ dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
463
+ ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
464
464
  srcs += [node_id] * len(dst_lst)
465
465
  dsts += dst_lst
466
466
  ciphertexts += ctxt_lst
@@ -525,7 +525,7 @@ class SecAggPlusWorkflow:
525
525
  state.failures.append(Exception(msg.error))
526
526
  continue
527
527
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
528
- bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
528
+ bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
529
529
  client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
530
530
  if masked_vector is None:
531
531
  masked_vector = client_masked_vec
@@ -592,7 +592,7 @@ class SecAggPlusWorkflow:
592
592
  )
593
593
 
594
594
  # Build collected shares dict
595
- collected_shares_dict: Dict[int, List[bytes]] = {}
595
+ collected_shares_dict: dict[int, list[bytes]] = {}
596
596
  for nid in state.sampled_node_ids:
597
597
  collected_shares_dict[nid] = []
598
598
  for msg in msgs:
@@ -600,8 +600,8 @@ class SecAggPlusWorkflow:
600
600
  state.failures.append(Exception(msg.error))
601
601
  continue
602
602
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
603
- nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
604
- shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
603
+ nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
604
+ shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
605
605
  for owner_nid, share in zip(nids, shares):
606
606
  collected_shares_dict[owner_nid].append(share)
607
607