flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -19,6 +19,7 @@ from __future__ import annotations
19
19
 
20
20
  RECORD_KEY_STATE = "secaggplus_state"
21
21
  RECORD_KEY_CONFIGS = "secaggplus_configs"
22
+ RATIO_QUANTIZATION_RANGE = 1073741824 # 1 << 30
22
23
 
23
24
 
24
25
  class Stage:
@@ -26,9 +27,9 @@ class Stage:
26
27
 
27
28
  SETUP = "setup"
28
29
  SHARE_KEYS = "share_keys"
29
- COLLECT_MASKED_INPUT = "collect_masked_input"
30
+ COLLECT_MASKED_VECTORS = "collect_masked_vectors"
30
31
  UNMASK = "unmask"
31
- _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)
32
+ _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK)
32
33
 
33
34
  @classmethod
34
35
  def all(cls) -> tuple[str, str, str, str]:
@@ -45,12 +46,12 @@ class Key:
45
46
 
46
47
  STAGE = "stage"
47
48
  SAMPLE_NUMBER = "sample_num"
48
- SECURE_ID = "secure_id"
49
49
  SHARE_NUMBER = "share_num"
50
50
  THRESHOLD = "threshold"
51
51
  CLIPPING_RANGE = "clipping_range"
52
52
  TARGET_RANGE = "target_range"
53
53
  MOD_RANGE = "mod_range"
54
+ MAX_WEIGHT = "max_weight"
54
55
  PUBLIC_KEY_1 = "pk1"
55
56
  PUBLIC_KEY_2 = "pk2"
56
57
  DESTINATION_LIST = "dsts"
@@ -58,9 +59,9 @@ class Key:
58
59
  SOURCE_LIST = "srcs"
59
60
  PARAMETERS = "params"
60
61
  MASKED_PARAMETERS = "masked_params"
61
- ACTIVE_SECURE_ID_LIST = "active_sids"
62
- DEAD_SECURE_ID_LIST = "dead_sids"
63
- SECURE_ID_LIST = "sids"
62
+ ACTIVE_NODE_ID_LIST = "active_nids"
63
+ DEAD_NODE_ID_LIST = "dead_nids"
64
+ NODE_ID_LIST = "nids"
64
65
  SHARE_LIST = "shares"
65
66
 
66
67
  def __new__(cls) -> Key:
@@ -23,16 +23,16 @@ from flwr.common.typing import NDArrayInt
23
23
 
24
24
 
25
25
  def share_keys_plaintext_concat(
26
- source: int, destination: int, b_share: bytes, sk_share: bytes
26
+ src_node_id: int, dst_node_id: int, b_share: bytes, sk_share: bytes
27
27
  ) -> bytes:
28
28
  """Combine arguments to bytes.
29
29
 
30
30
  Parameters
31
31
  ----------
32
- source : int
33
- the secure ID of the source.
34
- destination : int
35
- the secure ID of the destination.
32
+ src_node_id : int
33
+ the node ID of the source.
34
+ dst_node_id : int
35
+ the node ID of the destination.
36
36
  b_share : bytes
37
37
  the private key share of the source sent to the destination.
38
38
  sk_share : bytes
@@ -45,8 +45,8 @@ def share_keys_plaintext_concat(
45
45
  """
46
46
  return b"".join(
47
47
  [
48
- int.to_bytes(source, 4, "little"),
49
- int.to_bytes(destination, 4, "little"),
48
+ int.to_bytes(src_node_id, 8, "little", signed=True),
49
+ int.to_bytes(dst_node_id, 8, "little", signed=True),
50
50
  int.to_bytes(len(b_share), 4, "little"),
51
51
  b_share,
52
52
  sk_share,
@@ -64,21 +64,21 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, by
64
64
 
65
65
  Returns
66
66
  -------
67
- source : int
68
- the secure ID of the source.
69
- destination : int
70
- the secure ID of the destination.
67
+ src_node_id : int
68
+ the node ID of the source.
69
+ dst_node_id : int
70
+ the node ID of the destination.
71
71
  b_share : bytes
72
72
  the private key share of the source sent to the destination.
73
73
  sk_share : bytes
74
74
  the secret key share of the source sent to the destination.
75
75
  """
76
76
  src, dst, mark = (
77
- int.from_bytes(plaintext[:4], "little"),
78
- int.from_bytes(plaintext[4:8], "little"),
79
- int.from_bytes(plaintext[8:12], "little"),
77
+ int.from_bytes(plaintext[:8], "little", signed=True),
78
+ int.from_bytes(plaintext[8:16], "little", signed=True),
79
+ int.from_bytes(plaintext[16:20], "little"),
80
80
  )
81
- ret = (src, dst, plaintext[12 : 12 + mark], plaintext[12 + mark :])
81
+ ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :])
82
82
  return ret
83
83
 
84
84
 
flwr/server/compat/app.py CHANGED
@@ -127,9 +127,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
127
127
  )
128
128
  log(
129
129
  INFO,
130
- "Starting Flower server, config: %s",
130
+ "Starting Flower ServerApp, config: %s",
131
131
  initialized_config,
132
132
  )
133
+ log(INFO, "")
133
134
 
134
135
  # Start the thread updating nodes
135
136
  thread, f_stop = start_update_client_manager_thread(
@@ -15,7 +15,7 @@
15
15
  """Flower driver service client."""
16
16
 
17
17
 
18
- from logging import ERROR, INFO, WARNING
18
+ from logging import DEBUG, ERROR, WARNING
19
19
  from typing import Optional
20
20
 
21
21
  import grpc
@@ -70,19 +70,19 @@ class GrpcDriver:
70
70
  root_certificates=self.root_certificates,
71
71
  )
72
72
  self.stub = DriverStub(self.channel)
73
- log(INFO, "[Driver] Connected to %s", self.driver_service_address)
73
+ log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
74
74
 
75
75
  def disconnect(self) -> None:
76
76
  """Disconnect from the Driver API."""
77
77
  event(EventType.DRIVER_DISCONNECT)
78
78
  if self.channel is None or self.stub is None:
79
- log(WARNING, "Already disconnected")
79
+ log(DEBUG, "Already disconnected")
80
80
  return
81
81
  channel = self.channel
82
82
  self.channel = None
83
83
  self.stub = None
84
84
  channel.close()
85
- log(INFO, "[Driver] Disconnected")
85
+ log(DEBUG, "[Driver] Disconnected")
86
86
 
87
87
  def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
88
88
  """Request for run ID."""
flwr/server/history.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Training history."""
16
16
 
17
17
 
18
+ import pprint
18
19
  from functools import reduce
19
20
  from typing import Dict, List, Tuple
20
21
 
@@ -90,29 +91,35 @@ class History:
90
91
  """
91
92
  rep = ""
92
93
  if self.losses_distributed:
93
- rep += "History (loss, distributed):\n" + reduce(
94
- lambda a, b: a + b,
95
- [
96
- f"\tround {server_round}: {loss}\n"
97
- for server_round, loss in self.losses_distributed
98
- ],
94
+ rep += "History (loss, distributed):\n" + pprint.pformat(
95
+ reduce(
96
+ lambda a, b: a + b,
97
+ [
98
+ f"\tround {server_round}: {loss}\n"
99
+ for server_round, loss in self.losses_distributed
100
+ ],
101
+ )
99
102
  )
100
103
  if self.losses_centralized:
101
- rep += "History (loss, centralized):\n" + reduce(
102
- lambda a, b: a + b,
103
- [
104
- f"\tround {server_round}: {loss}\n"
105
- for server_round, loss in self.losses_centralized
106
- ],
104
+ rep += "History (loss, centralized):\n" + pprint.pformat(
105
+ reduce(
106
+ lambda a, b: a + b,
107
+ [
108
+ f"\tround {server_round}: {loss}\n"
109
+ for server_round, loss in self.losses_centralized
110
+ ],
111
+ )
107
112
  )
108
113
  if self.metrics_distributed_fit:
109
- rep += "History (metrics, distributed, fit):\n" + str(
114
+ rep += "History (metrics, distributed, fit):\n" + pprint.pformat(
110
115
  self.metrics_distributed_fit
111
116
  )
112
117
  if self.metrics_distributed:
113
- rep += "History (metrics, distributed, evaluate):\n" + str(
118
+ rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat(
114
119
  self.metrics_distributed
115
120
  )
116
121
  if self.metrics_centralized:
117
- rep += "History (metrics, centralized):\n" + str(self.metrics_centralized)
122
+ rep += "History (metrics, centralized):\n" + pprint.pformat(
123
+ self.metrics_centralized
124
+ )
118
125
  return rep
@@ -17,15 +17,16 @@
17
17
 
18
18
  import argparse
19
19
  import sys
20
- from logging import DEBUG, WARN
20
+ from logging import DEBUG, INFO, WARN
21
21
  from pathlib import Path
22
22
  from typing import Optional
23
23
 
24
24
  from flwr.common import Context, EventType, RecordSet, event
25
- from flwr.common.logger import log
25
+ from flwr.common.logger import log, update_console_handler
26
+ from flwr.common.object_ref import load_app
26
27
 
27
28
  from .driver.driver import Driver
28
- from .server_app import ServerApp, load_server_app
29
+ from .server_app import LoadServerAppError, ServerApp
29
30
 
30
31
 
31
32
  def run(
@@ -47,7 +48,13 @@ def run(
47
48
  # Load ServerApp if needed
48
49
  def _load() -> ServerApp:
49
50
  if server_app_attr:
50
- server_app: ServerApp = load_server_app(server_app_attr)
51
+ server_app: ServerApp = load_app(server_app_attr, LoadServerAppError)
52
+
53
+ if not isinstance(server_app, ServerApp):
54
+ raise LoadServerAppError(
55
+ f"Attribute {server_app_attr} is not of type {ServerApp}",
56
+ ) from None
57
+
51
58
  if loaded_server_app:
52
59
  server_app = loaded_server_app
53
60
  return server_app
@@ -69,6 +76,12 @@ def run_server_app() -> None:
69
76
 
70
77
  args = _parse_args_run_server_app().parse_args()
71
78
 
79
+ update_console_handler(
80
+ level=DEBUG if args.verbose else INFO,
81
+ timestamps=args.verbose,
82
+ colored=True,
83
+ )
84
+
72
85
  # Obtain certificates
73
86
  if args.insecure:
74
87
  if args.root_certificates is not None:
@@ -146,6 +159,11 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
146
159
  help="Run the server app without HTTPS. By default, the app runs with "
147
160
  "HTTPS enabled. Use this flag only if you understand the risks.",
148
161
  )
162
+ parser.add_argument(
163
+ "--verbose",
164
+ action="store_true",
165
+ help="Set the logging to `DEBUG`.",
166
+ )
149
167
  parser.add_argument(
150
168
  "--root-certificates",
151
169
  metavar="ROOT_CERT",
flwr/server/server.py CHANGED
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  import concurrent.futures
19
+ import io
19
20
  import timeit
20
21
  from logging import INFO, WARN
21
22
  from typing import Dict, List, Optional, Tuple, Union
@@ -83,14 +84,14 @@ class Server:
83
84
  return self._client_manager
84
85
 
85
86
  # pylint: disable=too-many-locals
86
- def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
87
+ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
87
88
  """Run federated averaging for a number of rounds."""
88
89
  history = History()
89
90
 
90
91
  # Initialize parameters
91
- log(INFO, "Initializing global parameters")
92
+ log(INFO, "[INIT]")
92
93
  self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
93
- log(INFO, "Evaluating initial parameters")
94
+ log(INFO, "Evaluating initial global parameters")
94
95
  res = self.strategy.evaluate(0, parameters=self.parameters)
95
96
  if res is not None:
96
97
  log(
@@ -103,10 +104,11 @@ class Server:
103
104
  history.add_metrics_centralized(server_round=0, metrics=res[1])
104
105
 
105
106
  # Run federated learning for num_rounds
106
- log(INFO, "FL starting")
107
107
  start_time = timeit.default_timer()
108
108
 
109
109
  for current_round in range(1, num_rounds + 1):
110
+ log(INFO, "")
111
+ log(INFO, "[ROUND %s]", current_round)
110
112
  # Train model and replace previous global model
111
113
  res_fit = self.fit_round(
112
114
  server_round=current_round,
@@ -152,8 +154,7 @@ class Server:
152
154
  # Bookkeeping
153
155
  end_time = timeit.default_timer()
154
156
  elapsed = end_time - start_time
155
- log(INFO, "FL finished in %s", elapsed)
156
- return history
157
+ return history, elapsed
157
158
 
158
159
  def evaluate_round(
159
160
  self,
@@ -170,12 +171,11 @@ class Server:
170
171
  client_manager=self._client_manager,
171
172
  )
172
173
  if not client_instructions:
173
- log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
174
+ log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
174
175
  return None
175
176
  log(
176
177
  INFO,
177
- "evaluate_round %s: strategy sampled %s clients (out of %s)",
178
- server_round,
178
+ "configure_evaluate: strategy sampled %s clients (out of %s)",
179
179
  len(client_instructions),
180
180
  self._client_manager.num_available(),
181
181
  )
@@ -189,8 +189,7 @@ class Server:
189
189
  )
190
190
  log(
191
191
  INFO,
192
- "evaluate_round %s received %s results and %s failures",
193
- server_round,
192
+ "aggregate_evaluate: received %s results and %s failures",
194
193
  len(results),
195
194
  len(failures),
196
195
  )
@@ -220,12 +219,11 @@ class Server:
220
219
  )
221
220
 
222
221
  if not client_instructions:
223
- log(INFO, "fit_round %s: no clients selected, cancel", server_round)
222
+ log(INFO, "configure_fit: no clients selected, cancel")
224
223
  return None
225
224
  log(
226
225
  INFO,
227
- "fit_round %s: strategy sampled %s clients (out of %s)",
228
- server_round,
226
+ "configure_fit: strategy sampled %s clients (out of %s)",
229
227
  len(client_instructions),
230
228
  self._client_manager.num_available(),
231
229
  )
@@ -239,8 +237,7 @@ class Server:
239
237
  )
240
238
  log(
241
239
  INFO,
242
- "fit_round %s received %s results and %s failures",
243
- server_round,
240
+ "aggregate_fit: received %s results and %s failures",
244
241
  len(results),
245
242
  len(failures),
246
243
  )
@@ -275,7 +272,7 @@ class Server:
275
272
  client_manager=self._client_manager
276
273
  )
277
274
  if parameters is not None:
278
- log(INFO, "Using initial parameters provided by strategy")
275
+ log(INFO, "Using initial global parameters provided by strategy")
279
276
  return parameters
280
277
 
281
278
  # Get initial parameters from one of the clients
@@ -483,12 +480,19 @@ def run_fl(
483
480
  config: ServerConfig,
484
481
  ) -> History:
485
482
  """Train a model on the given server and return the History object."""
486
- hist = server.fit(num_rounds=config.num_rounds, timeout=config.round_timeout)
487
- log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
488
- log(INFO, "app_fit: metrics_distributed_fit %s", str(hist.metrics_distributed_fit))
489
- log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
490
- log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
491
- log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
483
+ hist, elapsed_time = server.fit(
484
+ num_rounds=config.num_rounds, timeout=config.round_timeout
485
+ )
486
+
487
+ log(INFO, "")
488
+ log(INFO, "[SUMMARY]")
489
+ log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
490
+ for idx, line in enumerate(io.StringIO(str(hist))):
491
+ if idx == 0:
492
+ log(INFO, "%s", line.strip("\n"))
493
+ else:
494
+ log(INFO, "\t%s", line.strip("\n"))
495
+ log(INFO, "")
492
496
 
493
497
  # Graceful shutdown
494
498
  server.disconnect_all_clients(timeout=config.round_timeout)
flwr/server/server_app.py CHANGED
@@ -15,8 +15,7 @@
15
15
  """Flower ServerApp."""
16
16
 
17
17
 
18
- import importlib
19
- from typing import Callable, Optional, cast
18
+ from typing import Callable, Optional
20
19
 
21
20
  from flwr.common import Context, RecordSet
22
21
  from flwr.server.strategy import Strategy
@@ -132,48 +131,3 @@ class ServerApp:
132
131
 
133
132
  class LoadServerAppError(Exception):
134
133
  """Error when trying to load `ServerApp`."""
135
-
136
-
137
- def load_server_app(module_attribute_str: str) -> ServerApp:
138
- """Load the `ServerApp` object specified in a module attribute string.
139
-
140
- The module/attribute string should have the form <module>:<attribute>. Valid
141
- examples include `server:app` and `project.package.module:wrapper.app`. It
142
- must refer to a module on the PYTHONPATH, the module needs to have the specified
143
- attribute, and the attribute must be of type `ServerApp`.
144
- """
145
- module_str, _, attributes_str = module_attribute_str.partition(":")
146
- if not module_str:
147
- raise LoadServerAppError(
148
- f"Missing module in {module_attribute_str}",
149
- ) from None
150
- if not attributes_str:
151
- raise LoadServerAppError(
152
- f"Missing attribute in {module_attribute_str}",
153
- ) from None
154
-
155
- # Load module
156
- try:
157
- module = importlib.import_module(module_str)
158
- except ModuleNotFoundError:
159
- raise LoadServerAppError(
160
- f"Unable to load module {module_str}",
161
- ) from None
162
-
163
- # Recursively load attribute
164
- attribute = module
165
- try:
166
- for attribute_str in attributes_str.split("."):
167
- attribute = getattr(attribute, attribute_str)
168
- except AttributeError:
169
- raise LoadServerAppError(
170
- f"Unable to load attribute {attributes_str} from module {module_str}",
171
- ) from None
172
-
173
- # Check type
174
- if not isinstance(attribute, ServerApp):
175
- raise LoadServerAppError(
176
- f"Attribute {attributes_str} is not of type {ServerApp}",
177
- ) from None
178
-
179
- return cast(ServerApp, attribute)
@@ -29,3 +29,12 @@ class ServerConfig:
29
29
 
30
30
  num_rounds: int = 1
31
31
  round_timeout: Optional[float] = None
32
+
33
+ def __repr__(self) -> str:
34
+ """Return the string representation of the ServerConfig."""
35
+ timeout_string = (
36
+ "no round_timeout"
37
+ if self.round_timeout is None
38
+ else f"round_timeout={self.round_timeout}s"
39
+ )
40
+ return f"num_rounds={self.num_rounds}, {timeout_string}"
@@ -84,6 +84,8 @@ class FedAvg(Strategy):
84
84
  Metrics aggregation function, optional.
85
85
  evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
86
86
  Metrics aggregation function, optional.
87
+ inplace : bool (default: True)
88
+ Enable (True) or disable (False) in-place aggregation of model updates.
87
89
  """
88
90
 
89
91
  # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
@@ -21,9 +21,10 @@ import traceback
21
21
  from logging import DEBUG, ERROR, INFO, WARN
22
22
  from typing import Callable, Dict, List, Optional
23
23
 
24
- from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
24
+ from flwr.client.client_app import ClientApp, LoadClientAppError
25
25
  from flwr.client.node_state import NodeState
26
26
  from flwr.common.logger import log
27
+ from flwr.common.object_ref import load_app
27
28
  from flwr.common.serde import message_from_taskins, message_to_taskres
28
29
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
29
30
  from flwr.server.superlink.state import StateFactory
@@ -305,7 +306,13 @@ def start_vce(
305
306
  def _load() -> ClientApp:
306
307
 
307
308
  if client_app_attr:
308
- app: ClientApp = load_client_app(client_app_attr)
309
+ app: ClientApp = load_app(client_app_attr, LoadClientAppError)
310
+
311
+ if not isinstance(app, ClientApp):
312
+ raise LoadClientAppError(
313
+ f"Attribute {client_app_attr} is not of type {ClientApp}",
314
+ ) from None
315
+
309
316
  if client_app:
310
317
  app = client_app
311
318
  return app
@@ -122,7 +122,8 @@ class InMemoryState(State):
122
122
  task_res.task_id = str(task_id)
123
123
  task_res.task.created_at = created_at.isoformat()
124
124
  task_res.task.ttl = ttl.isoformat()
125
- self.task_res_store[task_id] = task_res
125
+ with self.lock:
126
+ self.task_res_store[task_id] = task_res
126
127
 
127
128
  # Return the new task_id
128
129
  return task_id
@@ -132,46 +133,47 @@ class InMemoryState(State):
132
133
  if limit is not None and limit < 1:
133
134
  raise AssertionError("`limit` must be >= 1")
134
135
 
135
- # Find TaskRes that were not delivered yet
136
- task_res_list: List[TaskRes] = []
137
- for _, task_res in self.task_res_store.items():
138
- if (
139
- UUID(task_res.task.ancestry[0]) in task_ids
140
- and task_res.task.delivered_at == ""
141
- ):
142
- task_res_list.append(task_res)
143
- if limit and len(task_res_list) == limit:
144
- break
136
+ with self.lock:
137
+ # Find TaskRes that were not delivered yet
138
+ task_res_list: List[TaskRes] = []
139
+ for _, task_res in self.task_res_store.items():
140
+ if (
141
+ UUID(task_res.task.ancestry[0]) in task_ids
142
+ and task_res.task.delivered_at == ""
143
+ ):
144
+ task_res_list.append(task_res)
145
+ if limit and len(task_res_list) == limit:
146
+ break
145
147
 
146
- # Mark all of them as delivered
147
- delivered_at = now().isoformat()
148
- for task_res in task_res_list:
149
- task_res.task.delivered_at = delivered_at
148
+ # Mark all of them as delivered
149
+ delivered_at = now().isoformat()
150
+ for task_res in task_res_list:
151
+ task_res.task.delivered_at = delivered_at
150
152
 
151
- # Return TaskRes
152
- return task_res_list
153
+ # Return TaskRes
154
+ return task_res_list
153
155
 
154
156
  def delete_tasks(self, task_ids: Set[UUID]) -> None:
155
157
  """Delete all delivered TaskIns/TaskRes pairs."""
156
158
  task_ins_to_be_deleted: Set[UUID] = set()
157
159
  task_res_to_be_deleted: Set[UUID] = set()
158
160
 
159
- for task_ins_id in task_ids:
160
- # Find the task_id of the matching task_res
161
- for task_res_id, task_res in self.task_res_store.items():
162
- if UUID(task_res.task.ancestry[0]) != task_ins_id:
163
- continue
164
- if task_res.task.delivered_at == "":
165
- continue
166
-
167
- task_ins_to_be_deleted.add(task_ins_id)
168
- task_res_to_be_deleted.add(task_res_id)
169
-
170
- for task_id in task_ins_to_be_deleted:
171
- with self.lock:
161
+ with self.lock:
162
+ for task_ins_id in task_ids:
163
+ # Find the task_id of the matching task_res
164
+ for task_res_id, task_res in self.task_res_store.items():
165
+ if UUID(task_res.task.ancestry[0]) != task_ins_id:
166
+ continue
167
+ if task_res.task.delivered_at == "":
168
+ continue
169
+
170
+ task_ins_to_be_deleted.add(task_ins_id)
171
+ task_res_to_be_deleted.add(task_res_id)
172
+
173
+ for task_id in task_ins_to_be_deleted:
172
174
  del self.task_ins_store[task_id]
173
- for task_id in task_res_to_be_deleted:
174
- del self.task_res_store[task_id]
175
+ for task_id in task_res_to_be_deleted:
176
+ del self.task_res_store[task_id]
175
177
 
176
178
  def num_task_ins(self) -> int:
177
179
  """Calculate the number of task_ins in store.
@@ -16,7 +16,10 @@
16
16
 
17
17
 
18
18
  from .default_workflows import DefaultWorkflow
19
+ from .secure_aggregation import SecAggPlusWorkflow, SecAggWorkflow
19
20
 
20
21
  __all__ = [
21
22
  "DefaultWorkflow",
23
+ "SecAggPlusWorkflow",
24
+ "SecAggWorkflow",
22
25
  ]
@@ -0,0 +1,32 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Constants for default workflows."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ MAIN_CONFIGS_RECORD = "config"
21
+ MAIN_PARAMS_RECORD = "parameters"
22
+
23
+
24
+ class Key:
25
+ """Constants for default workflows."""
26
+
27
+ CURRENT_ROUND = "current_round"
28
+ START_TIME = "start_time"
29
+
30
+ def __new__(cls) -> Key:
31
+ """Prevent instantiation."""
32
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")