flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

Files changed (39)
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
flwr/common/secure_aggregation/secaggplus_constants.py CHANGED
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 RECORD_KEY_STATE = "secaggplus_state"
 RECORD_KEY_CONFIGS = "secaggplus_configs"
+RATIO_QUANTIZATION_RANGE = 1073741824  # 1 << 30
 
 
 class Stage:
@@ -26,9 +27,9 @@ class Stage:
 
     SETUP = "setup"
     SHARE_KEYS = "share_keys"
-    COLLECT_MASKED_INPUT = "collect_masked_input"
+    COLLECT_MASKED_VECTORS = "collect_masked_vectors"
     UNMASK = "unmask"
-    _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)
+    _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK)
 
     @classmethod
     def all(cls) -> tuple[str, str, str, str]:
@@ -45,12 +46,12 @@ class Key:
 
     STAGE = "stage"
     SAMPLE_NUMBER = "sample_num"
-    SECURE_ID = "secure_id"
    SHARE_NUMBER = "share_num"
    THRESHOLD = "threshold"
    CLIPPING_RANGE = "clipping_range"
    TARGET_RANGE = "target_range"
    MOD_RANGE = "mod_range"
+    MAX_WEIGHT = "max_weight"
    PUBLIC_KEY_1 = "pk1"
    PUBLIC_KEY_2 = "pk2"
    DESTINATION_LIST = "dsts"
@@ -58,9 +59,9 @@ class Key:
    SOURCE_LIST = "srcs"
    PARAMETERS = "params"
    MASKED_PARAMETERS = "masked_params"
-    ACTIVE_SECURE_ID_LIST = "active_sids"
-    DEAD_SECURE_ID_LIST = "dead_sids"
-    SECURE_ID_LIST = "sids"
+    ACTIVE_NODE_ID_LIST = "active_nids"
+    DEAD_NODE_ID_LIST = "dead_nids"
+    NODE_ID_LIST = "nids"
    SHARE_LIST = "shares"
 
    def __new__(cls) -> Key:
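The hunks above switch the SecAgg+ constants from protocol-level "secure IDs" to Flower node IDs and rename the third stage. A minimal sketch (not part of the diff) of how the renamed constants surface to calling code, assuming the new nightly is installed; the exact return value of Stage.all() is inferred from the _stages tuple shown above:

from flwr.common.secure_aggregation.secaggplus_constants import Key, Stage

# The stage string sent over the wire changed from "collect_masked_input"
# to "collect_masked_vectors", so peers on the older nightly would not recognize it.
print(Stage.all())       # expected: ('setup', 'share_keys', 'collect_masked_vectors', 'unmask')
print(Key.NODE_ID_LIST)  # 'nids' (replaces Key.SECURE_ID_LIST == 'sids')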
flwr/common/secure_aggregation/secaggplus_utils.py CHANGED
@@ -23,16 +23,16 @@ from flwr.common.typing import NDArrayInt
 
 
 def share_keys_plaintext_concat(
-    source: int, destination: int, b_share: bytes, sk_share: bytes
+    src_node_id: int, dst_node_id: int, b_share: bytes, sk_share: bytes
 ) -> bytes:
     """Combine arguments to bytes.
 
     Parameters
     ----------
-    source : int
-        the secure ID of the source.
-    destination : int
-        the secure ID of the destination.
+    src_node_id : int
+        the node ID of the source.
+    dst_node_id : int
+        the node ID of the destination.
     b_share : bytes
         the private key share of the source sent to the destination.
     sk_share : bytes
@@ -45,8 +45,8 @@ def share_keys_plaintext_concat(
     """
     return b"".join(
         [
-            int.to_bytes(source, 4, "little"),
-            int.to_bytes(destination, 4, "little"),
+            int.to_bytes(src_node_id, 8, "little", signed=True),
+            int.to_bytes(dst_node_id, 8, "little", signed=True),
            int.to_bytes(len(b_share), 4, "little"),
            b_share,
            sk_share,
@@ -64,21 +64,21 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, bytes]:
 
     Returns
     -------
-    source : int
-        the secure ID of the source.
-    destination : int
-        the secure ID of the destination.
+    src_node_id : int
+        the node ID of the source.
+    dst_node_id : int
+        the node ID of the destination.
     b_share : bytes
         the private key share of the source sent to the destination.
     sk_share : bytes
         the secret key share of the source sent to the destination.
     """
     src, dst, mark = (
-        int.from_bytes(plaintext[:4], "little"),
-        int.from_bytes(plaintext[4:8], "little"),
-        int.from_bytes(plaintext[8:12], "little"),
+        int.from_bytes(plaintext[:8], "little", signed=True),
+        int.from_bytes(plaintext[8:16], "little", signed=True),
+        int.from_bytes(plaintext[16:20], "little"),
     )
-    ret = (src, dst, plaintext[12 : 12 + mark], plaintext[12 + mark :])
+    ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :])
     return ret
 
 
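The plaintext layout grows from a 12-byte header (two 4-byte unsigned secure IDs plus the share length) to a 20-byte header (two 8-byte signed node IDs plus the share length), presumably because Flower node IDs are arbitrary signed 64-bit integers. A round-trip sketch (not part of the diff, assumes the new nightly is installed) of the functions shown above:

from flwr.common.secure_aggregation.secaggplus_utils import (
    share_keys_plaintext_concat,
    share_keys_plaintext_separate,
)

# Layout: 8-byte signed src node ID | 8-byte signed dst node ID | 4-byte share length | b_share | sk_share
plaintext = share_keys_plaintext_concat(
    src_node_id=-1234567890123456789,  # signed=True allows negative 64-bit node IDs
    dst_node_id=42,
    b_share=b"b-share-bytes",
    sk_share=b"sk-share-bytes",
)
assert len(plaintext) == 8 + 8 + 4 + len(b"b-share-bytes") + len(b"sk-share-bytes")
assert share_keys_plaintext_separate(plaintext) == (
    -1234567890123456789,
    42,
    b"b-share-bytes",
    b"sk-share-bytes",
)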
flwr/server/compat/app.py CHANGED
@@ -127,9 +127,10 @@ def start_driver(  # pylint: disable=too-many-arguments, too-many-locals
     )
     log(
         INFO,
-        "Starting Flower server, config: %s",
+        "Starting Flower ServerApp, config: %s",
         initialized_config,
     )
+    log(INFO, "")
 
     # Start the thread updating nodes
     thread, f_stop = start_update_client_manager_thread(
flwr/server/driver/grpc_driver.py CHANGED
@@ -15,7 +15,7 @@
 """Flower driver service client."""
 
 
-from logging import ERROR, INFO, WARNING
+from logging import DEBUG, ERROR, WARNING
 from typing import Optional
 
 import grpc
@@ -70,19 +70,19 @@ class GrpcDriver:
             root_certificates=self.root_certificates,
         )
         self.stub = DriverStub(self.channel)
-        log(INFO, "[Driver] Connected to %s", self.driver_service_address)
+        log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
 
     def disconnect(self) -> None:
         """Disconnect from the Driver API."""
         event(EventType.DRIVER_DISCONNECT)
         if self.channel is None or self.stub is None:
-            log(WARNING, "Already disconnected")
+            log(DEBUG, "Already disconnected")
             return
         channel = self.channel
         self.channel = None
         self.stub = None
         channel.close()
-        log(INFO, "[Driver] Disconnected")
+        log(DEBUG, "[Driver] Disconnected")
 
     def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
         """Request for run ID."""
flwr/server/history.py CHANGED
@@ -15,6 +15,7 @@
 """Training history."""
 
 
+import pprint
 from functools import reduce
 from typing import Dict, List, Tuple
 
@@ -90,29 +91,35 @@ class History:
         """
         rep = ""
         if self.losses_distributed:
-            rep += "History (loss, distributed):\n" + reduce(
-                lambda a, b: a + b,
-                [
-                    f"\tround {server_round}: {loss}\n"
-                    for server_round, loss in self.losses_distributed
-                ],
+            rep += "History (loss, distributed):\n" + pprint.pformat(
+                reduce(
+                    lambda a, b: a + b,
+                    [
+                        f"\tround {server_round}: {loss}\n"
+                        for server_round, loss in self.losses_distributed
+                    ],
+                )
             )
         if self.losses_centralized:
-            rep += "History (loss, centralized):\n" + reduce(
-                lambda a, b: a + b,
-                [
-                    f"\tround {server_round}: {loss}\n"
-                    for server_round, loss in self.losses_centralized
-                ],
+            rep += "History (loss, centralized):\n" + pprint.pformat(
+                reduce(
+                    lambda a, b: a + b,
+                    [
+                        f"\tround {server_round}: {loss}\n"
+                        for server_round, loss in self.losses_centralized
+                    ],
+                )
             )
         if self.metrics_distributed_fit:
-            rep += "History (metrics, distributed, fit):\n" + str(
+            rep += "History (metrics, distributed, fit):\n" + pprint.pformat(
                 self.metrics_distributed_fit
             )
         if self.metrics_distributed:
-            rep += "History (metrics, distributed, evaluate):\n" + str(
+            rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat(
                 self.metrics_distributed
             )
         if self.metrics_centralized:
-            rep += "History (metrics, centralized):\n" + str(self.metrics_centralized)
+            rep += "History (metrics, centralized):\n" + pprint.pformat(
+                self.metrics_centralized
+            )
         return rep
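History.__repr__ now formats the nested metrics dictionaries with pprint.pformat instead of str, so long entries wrap across lines instead of producing one very long log line. A standalone, stdlib-only sketch (not part of the diff) of the difference:

import pprint

metrics_distributed = {
    "accuracy": [(1, 0.512), (2, 0.634), (3, 0.701)],
    "loss": [(1, 1.92), (2, 1.41), (3, 1.13)],
}
print(str(metrics_distributed))             # one long line
print(pprint.pformat(metrics_distributed))  # wrapped onto multiple lines at ~80 columns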
flwr/server/run_serverapp.py CHANGED
@@ -17,15 +17,16 @@
 
 import argparse
 import sys
-from logging import DEBUG, WARN
+from logging import DEBUG, INFO, WARN
 from pathlib import Path
 from typing import Optional
 
 from flwr.common import Context, EventType, RecordSet, event
-from flwr.common.logger import log
+from flwr.common.logger import log, update_console_handler
+from flwr.common.object_ref import load_app
 
 from .driver.driver import Driver
-from .server_app import ServerApp, load_server_app
+from .server_app import LoadServerAppError, ServerApp
 
 
 def run(
@@ -47,7 +48,13 @@ def run(
     # Load ServerApp if needed
     def _load() -> ServerApp:
         if server_app_attr:
-            server_app: ServerApp = load_server_app(server_app_attr)
+            server_app: ServerApp = load_app(server_app_attr, LoadServerAppError)
+
+            if not isinstance(server_app, ServerApp):
+                raise LoadServerAppError(
+                    f"Attribute {server_app_attr} is not of type {ServerApp}",
+                ) from None
+
         if loaded_server_app:
             server_app = loaded_server_app
         return server_app
@@ -69,6 +76,12 @@ def run_server_app() -> None:
 
     args = _parse_args_run_server_app().parse_args()
 
+    update_console_handler(
+        level=DEBUG if args.verbose else INFO,
+        timestamps=args.verbose,
+        colored=True,
+    )
+
     # Obtain certificates
     if args.insecure:
         if args.root_certificates is not None:
@@ -146,6 +159,11 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
         help="Run the server app without HTTPS. By default, the app runs with "
         "HTTPS enabled. Use this flag only if you understand the risks.",
     )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        help="Set the logging to `DEBUG`.",
+    )
     parser.add_argument(
         "--root-certificates",
         metavar="ROOT_CERT",
flwr/server/server.py CHANGED
@@ -16,6 +16,7 @@
 
 
 import concurrent.futures
+import io
 import timeit
 from logging import INFO, WARN
 from typing import Dict, List, Optional, Tuple, Union
@@ -83,14 +84,14 @@ class Server:
         return self._client_manager
 
     # pylint: disable=too-many-locals
-    def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
+    def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
         """Run federated averaging for a number of rounds."""
         history = History()
 
         # Initialize parameters
-        log(INFO, "Initializing global parameters")
+        log(INFO, "[INIT]")
         self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
-        log(INFO, "Evaluating initial parameters")
+        log(INFO, "Evaluating initial global parameters")
         res = self.strategy.evaluate(0, parameters=self.parameters)
         if res is not None:
             log(
@@ -103,10 +104,11 @@ class Server:
             history.add_metrics_centralized(server_round=0, metrics=res[1])
 
         # Run federated learning for num_rounds
-        log(INFO, "FL starting")
         start_time = timeit.default_timer()
 
         for current_round in range(1, num_rounds + 1):
+            log(INFO, "")
+            log(INFO, "[ROUND %s]", current_round)
             # Train model and replace previous global model
             res_fit = self.fit_round(
                 server_round=current_round,
@@ -152,8 +154,7 @@ class Server:
         # Bookkeeping
         end_time = timeit.default_timer()
         elapsed = end_time - start_time
-        log(INFO, "FL finished in %s", elapsed)
-        return history
+        return history, elapsed
 
     def evaluate_round(
         self,
@@ -170,12 +171,11 @@ class Server:
             client_manager=self._client_manager,
         )
         if not client_instructions:
-            log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
+            log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
             return None
         log(
             INFO,
-            "evaluate_round %s: strategy sampled %s clients (out of %s)",
-            server_round,
+            "configure_evaluate: strategy sampled %s clients (out of %s)",
             len(client_instructions),
             self._client_manager.num_available(),
         )
@@ -189,8 +189,7 @@ class Server:
         )
         log(
             INFO,
-            "evaluate_round %s received %s results and %s failures",
-            server_round,
+            "aggregate_evaluate: received %s results and %s failures",
             len(results),
             len(failures),
         )
@@ -220,12 +219,11 @@ class Server:
         )
 
         if not client_instructions:
-            log(INFO, "fit_round %s: no clients selected, cancel", server_round)
+            log(INFO, "configure_fit: no clients selected, cancel")
            return None
         log(
             INFO,
-            "fit_round %s: strategy sampled %s clients (out of %s)",
-            server_round,
+            "configure_fit: strategy sampled %s clients (out of %s)",
             len(client_instructions),
             self._client_manager.num_available(),
         )
@@ -239,8 +237,7 @@ class Server:
         )
         log(
             INFO,
-            "fit_round %s received %s results and %s failures",
-            server_round,
+            "aggregate_fit: received %s results and %s failures",
             len(results),
             len(failures),
         )
@@ -275,7 +272,7 @@ class Server:
             client_manager=self._client_manager
         )
         if parameters is not None:
-            log(INFO, "Using initial parameters provided by strategy")
+            log(INFO, "Using initial global parameters provided by strategy")
             return parameters
 
         # Get initial parameters from one of the clients
@@ -483,12 +480,19 @@ def run_fl(
     config: ServerConfig,
 ) -> History:
     """Train a model on the given server and return the History object."""
-    hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout)
-    log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
-    log(INFO, "app_fit: metrics_distributed_fit %s", str(hist.metrics_distributed_fit))
-    log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
-    log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
-    log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
+    hist, elapsed_time = server.fit(
+        num_rounds=config.num_rounds, timeout=config.round_timeout
+    )
+
+    log(INFO, "")
+    log(INFO, "[SUMMARY]")
+    log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
+    for idx, line in enumerate(io.StringIO(str(hist))):
+        if idx == 0:
+            log(INFO, "%s", line.strip("\n"))
+        else:
+            log(INFO, "\t%s", line.strip("\n"))
+    log(INFO, "")
 
     # Graceful shutdown
     server.disconnect_all_clients(timeout=config.round_timeout)
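Note the signature change: Server.fit now returns a (History, elapsed_seconds) tuple instead of a bare History, so code that calls it directly must unpack two values. A sketch of the new call pattern (not part of the diff; it only progresses once clients are connected, and is shown purely to illustrate the unpacking):

from flwr.server import Server, SimpleClientManager
from flwr.server.strategy import FedAvg

server = Server(client_manager=SimpleClientManager(), strategy=FedAvg())
# Before: `history = server.fit(...)`. Now fit() also returns the elapsed wall-clock time.
history, elapsed = server.fit(num_rounds=3, timeout=None)
print(f"3 rounds took {elapsed:.2f}s")
print(history)  # pretty-printed via the History changes above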
flwr/server/server_app.py CHANGED
@@ -15,8 +15,7 @@
 """Flower ServerApp."""
 
 
-import importlib
-from typing import Callable, Optional, cast
+from typing import Callable, Optional
 
 from flwr.common import Context, RecordSet
 from flwr.server.strategy import Strategy
@@ -132,48 +131,3 @@ class ServerApp:
 
 class LoadServerAppError(Exception):
     """Error when trying to load `ServerApp`."""
-
-
-def load_server_app(module_attribute_str: str) -> ServerApp:
-    """Load the `ServerApp` object specified in a module attribute string.
-
-    The module/attribute string should have the form <module>:<attribute>. Valid
-    examples include `server:app` and `project.package.module:wrapper.app`. It
-    must refer to a module on the PYTHONPATH, the module needs to have the specified
-    attribute, and the attribute must be of type `ServerApp`.
-    """
-    module_str, _, attributes_str = module_attribute_str.partition(":")
-    if not module_str:
-        raise LoadServerAppError(
-            f"Missing module in {module_attribute_str}",
-        ) from None
-    if not attributes_str:
-        raise LoadServerAppError(
-            f"Missing attribute in {module_attribute_str}",
-        ) from None
-
-    # Load module
-    try:
-        module = importlib.import_module(module_str)
-    except ModuleNotFoundError:
-        raise LoadServerAppError(
-            f"Unable to load module {module_str}",
-        ) from None
-
-    # Recursively load attribute
-    attribute = module
-    try:
-        for attribute_str in attributes_str.split("."):
-            attribute = getattr(attribute, attribute_str)
-    except AttributeError:
-        raise LoadServerAppError(
-            f"Unable to load attribute {attributes_str} from module {module_str}",
-        ) from None
-
-    # Check type
-    if not isinstance(attribute, ServerApp):
-        raise LoadServerAppError(
-            f"Attribute {attributes_str} is not of type {ServerApp}",
-        ) from None
-
-    return cast(ServerApp, attribute)
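The `<module>:<attribute>` loader is no longer ServerApp-specific: it moves to the new generic load_app in flwr/common/object_ref.py (file 15 above), which takes the error class to raise and leaves the type check to the caller, as run_serverapp.py and vce_api.py now do. A sketch mirroring that call site (the "myproject.server:app" string is a placeholder, not from the diff):

from flwr.common.object_ref import load_app
from flwr.server import ServerApp
from flwr.server.server_app import LoadServerAppError

server_app = load_app("myproject.server:app", LoadServerAppError)
if not isinstance(server_app, ServerApp):
    # The isinstance check now lives at the call site rather than inside the loader.
    raise LoadServerAppError(f"Attribute myproject.server:app is not of type {ServerApp}")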
flwr/server/server_config.py CHANGED
@@ -29,3 +29,12 @@ class ServerConfig:
 
     num_rounds: int = 1
     round_timeout: Optional[float] = None
+
+    def __repr__(self) -> str:
+        """Return the string representation of the ServerConfig."""
+        timeout_string = (
+            "no round_timeout"
+            if self.round_timeout is None
+            else f"round_timeout={self.round_timeout}s"
+        )
+        return f"num_rounds={self.num_rounds}, {timeout_string}"
flwr/server/strategy/fedavg.py CHANGED
@@ -84,6 +84,8 @@ class FedAvg(Strategy):
         Metrics aggregation function, optional.
     evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
         Metrics aggregation function, optional.
+    inplace : bool (default: True)
+        Enable (True) or disable (False) in-place aggregation of model updates.
     """
 
     # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
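This hunk only documents the existing inplace flag; in-place aggregation reuses the received update buffers to reduce peak memory. A sketch of turning it off (not part of the diff; the other keyword arguments shown are standard FedAvg options, and whether you need to disable it depends on whether downstream code still reads the individual client results):

from flwr.server.strategy import FedAvg

# Disable in-place aggregation, trading higher memory use for untouched client results.
strategy = FedAvg(fraction_fit=1.0, min_available_clients=2, inplace=False)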
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -21,9 +21,10 @@ import traceback
 from logging import DEBUG, ERROR, INFO, WARN
 from typing import Callable, Dict, List, Optional
 
-from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
+from flwr.client.client_app import ClientApp, LoadClientAppError
 from flwr.client.node_state import NodeState
 from flwr.common.logger import log
+from flwr.common.object_ref import load_app
 from flwr.common.serde import message_from_taskins, message_to_taskres
 from flwr.proto.task_pb2 import TaskIns  # pylint: disable=E0611
 from flwr.server.superlink.state import StateFactory
@@ -305,7 +306,13 @@ def start_vce(
     def _load() -> ClientApp:
 
         if client_app_attr:
-            app: ClientApp = load_client_app(client_app_attr)
+            app: ClientApp = load_app(client_app_attr, LoadClientAppError)
+
+            if not isinstance(app, ClientApp):
+                raise LoadClientAppError(
+                    f"Attribute {client_app_attr} is not of type {ClientApp}",
+                ) from None
+
         if client_app:
             app = client_app
         return app
flwr/server/superlink/state/in_memory_state.py CHANGED
@@ -122,7 +122,8 @@ class InMemoryState(State):
         task_res.task_id = str(task_id)
         task_res.task.created_at = created_at.isoformat()
         task_res.task.ttl = ttl.isoformat()
-        self.task_res_store[task_id] = task_res
+        with self.lock:
+            self.task_res_store[task_id] = task_res
 
         # Return the new task_id
         return task_id
@@ -132,46 +133,47 @@ class InMemoryState(State):
         if limit is not None and limit < 1:
             raise AssertionError("`limit` must be >= 1")
 
-        # Find TaskRes that were not delivered yet
-        task_res_list: List[TaskRes] = []
-        for _, task_res in self.task_res_store.items():
-            if (
-                UUID(task_res.task.ancestry[0]) in task_ids
-                and task_res.task.delivered_at == ""
-            ):
-                task_res_list.append(task_res)
-            if limit and len(task_res_list) == limit:
-                break
+        with self.lock:
+            # Find TaskRes that were not delivered yet
+            task_res_list: List[TaskRes] = []
+            for _, task_res in self.task_res_store.items():
+                if (
+                    UUID(task_res.task.ancestry[0]) in task_ids
+                    and task_res.task.delivered_at == ""
+                ):
+                    task_res_list.append(task_res)
+                if limit and len(task_res_list) == limit:
+                    break
 
-        # Mark all of them as delivered
-        delivered_at = now().isoformat()
-        for task_res in task_res_list:
-            task_res.task.delivered_at = delivered_at
+            # Mark all of them as delivered
+            delivered_at = now().isoformat()
+            for task_res in task_res_list:
+                task_res.task.delivered_at = delivered_at
 
-        # Return TaskRes
-        return task_res_list
+            # Return TaskRes
+            return task_res_list
 
     def delete_tasks(self, task_ids: Set[UUID]) -> None:
         """Delete all delivered TaskIns/TaskRes pairs."""
         task_ins_to_be_deleted: Set[UUID] = set()
         task_res_to_be_deleted: Set[UUID] = set()
 
-        for task_ins_id in task_ids:
-            # Find the task_id of the matching task_res
-            for task_res_id, task_res in self.task_res_store.items():
-                if UUID(task_res.task.ancestry[0]) != task_ins_id:
-                    continue
-                if task_res.task.delivered_at == "":
-                    continue
-
-                task_ins_to_be_deleted.add(task_ins_id)
-                task_res_to_be_deleted.add(task_res_id)
-
-        for task_id in task_ins_to_be_deleted:
-            with self.lock:
+        with self.lock:
+            for task_ins_id in task_ids:
+                # Find the task_id of the matching task_res
+                for task_res_id, task_res in self.task_res_store.items():
+                    if UUID(task_res.task.ancestry[0]) != task_ins_id:
+                        continue
+                    if task_res.task.delivered_at == "":
+                        continue
+
+                    task_ins_to_be_deleted.add(task_ins_id)
+                    task_res_to_be_deleted.add(task_res_id)
+
+            for task_id in task_ins_to_be_deleted:
                 del self.task_ins_store[task_id]
-        for task_id in task_res_to_be_deleted:
-            del self.task_res_store[task_id]
+            for task_id in task_res_to_be_deleted:
+                del self.task_res_store[task_id]
 
     def num_task_ins(self) -> int:
         """Calculate the number of task_ins in store.
flwr/server/workflow/__init__.py CHANGED
@@ -16,7 +16,10 @@
 
 
 from .default_workflows import DefaultWorkflow
+from .secure_aggregation import SecAggPlusWorkflow, SecAggWorkflow
 
 __all__ = [
     "DefaultWorkflow",
+    "SecAggPlusWorkflow",
+    "SecAggWorkflow",
 ]
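With this re-export, the new secure aggregation workflows (files 32-34 above) become importable from flwr.server.workflow. A hedged sketch of wiring one into the default workflow; the constructor arguments (num_shares, reconstruction_threshold) and the fit_workflow hook are assumptions based on how SecAgg+ is typically configured, so check the added secaggplus_workflow.py for the exact signature:

from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow

# Replace the plain fit phase with SecAgg+-protected aggregation (sketch).
workflow = DefaultWorkflow(
    fit_workflow=SecAggPlusWorkflow(num_shares=3, reconstruction_threshold=2)
)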
flwr/server/workflow/constant.py ADDED
@@ -0,0 +1,32 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constants for default workflows."""
+
+
+from __future__ import annotations
+
+MAIN_CONFIGS_RECORD = "config"
+MAIN_PARAMS_RECORD = "parameters"
+
+
+class Key:
+    """Constants for default workflows."""
+
+    CURRENT_ROUND = "current_round"
+    START_TIME = "start_time"
+
+    def __new__(cls) -> Key:
+        """Prevent instantiation."""
+        raise TypeError(f"{cls.__name__} cannot be instantiated.")
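Like the Stage and Key classes in secaggplus_constants.py, this Key is a pure constants namespace whose overridden __new__ blocks instantiation. A tiny sketch (not part of the diff; the module is internal, with the import path taken from the file list above):

from flwr.server.workflow.constant import Key

print(Key.CURRENT_ROUND)  # "current_round"
try:
    Key()
except TypeError as err:
    print(err)            # "Key cannot be instantiated."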