flwr-nightly 1.8.0.dev20240310__py3-none-any.whl → 1.8.0.dev20240312__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flwr/cli/new/new.py CHANGED
@@ -22,7 +22,7 @@ from typing import Dict, Optional
22
22
  import typer
23
23
  from typing_extensions import Annotated
24
24
 
25
- from ..utils import prompt_options
25
+ from ..utils import prompt_options, prompt_text
26
26
 
27
27
 
28
28
  class MlFramework(str, Enum):
@@ -72,9 +72,9 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
72
72
 
73
73
  def new(
74
74
  project_name: Annotated[
75
- str,
75
+ Optional[str],
76
76
  typer.Argument(metavar="project_name", help="The name of the project"),
77
- ],
77
+ ] = None,
78
78
  framework: Annotated[
79
79
  Optional[MlFramework],
80
80
  typer.Option(case_sensitive=False, help="The ML framework to use"),
@@ -83,6 +83,9 @@ def new(
83
83
  """Create new Flower project."""
84
84
  print(f"Creating Flower project {project_name}...")
85
85
 
86
+ if project_name is None:
87
+ project_name = prompt_text("Please provide project name")
88
+
86
89
  if framework is not None:
87
90
  framework_str = str(framework.value)
88
91
  else:
flwr/cli/utils.py CHANGED
@@ -14,11 +14,24 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface utils."""
16
16
 
17
- from typing import List
17
+ from typing import List, cast
18
18
 
19
19
  import typer
20
20
 
21
21
 
22
def prompt_text(text: str) -> str:
    """Ask the user to enter a non-empty line of text.

    The prompt is shown again (after an error message) until the user enters
    something other than whitespace.

    Parameters
    ----------
    text : str
        The question displayed to the user.

    Returns
    -------
    str
        The entered text with surrounding whitespace removed.
    """
    while True:
        result = typer.prompt(
            typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True)
        )
        # A whitespace-only entry (e.g. "   ") passes a plain length check but
        # is not a usable value (such as a project name), so reject it too.
        stripped = cast(str, result).strip()
        if stripped:
            return stripped
        print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True))
33
+
34
+
22
35
  def prompt_options(text: str, options: List[str]) -> str:
23
36
  """Ask user to select one of the given options and return the selected item."""
24
37
  # Turn options into a list with index as in " [ 0] quickstart-pytorch"
flwr/client/app.py CHANGED
@@ -456,7 +456,19 @@ def _start_client_internal(
456
456
  time.sleep(3) # Wait for 3s before asking again
457
457
  continue
458
458
 
459
- log(INFO, "Received message")
459
+ log(INFO, "")
460
+ log(
461
+ INFO,
462
+ "[RUN %s, ROUND %s]",
463
+ message.metadata.run_id,
464
+ message.metadata.group_id,
465
+ )
466
+ log(
467
+ INFO,
468
+ "Received: %s message %s",
469
+ message.metadata.message_type,
470
+ message.metadata.message_id,
471
+ )
460
472
 
461
473
  # Handle control message
462
474
  out_message, sleep_duration = handle_control_message(message)
@@ -484,7 +496,18 @@ def _start_client_internal(
484
496
 
485
497
  # Send
486
498
  send(out_message)
487
- log(INFO, "Sent reply")
499
+ log(
500
+ INFO,
501
+ "[RUN %s, ROUND %s]",
502
+ out_message.metadata.run_id,
503
+ out_message.metadata.group_id,
504
+ )
505
+ log(
506
+ INFO,
507
+ "Sent: %s reply to message %s",
508
+ out_message.metadata.message_type,
509
+ message.metadata.message_id,
510
+ )
488
511
 
489
512
  # Unregister node
490
513
  if delete_node is not None:
@@ -17,7 +17,7 @@
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
19
19
  from .localdp_mod import LocalDpMod
20
- from .secure_aggregation.secaggplus_mod import secaggplus_mod
20
+ from .secure_aggregation import secagg_mod, secaggplus_mod
21
21
  from .utils import make_ffn
22
22
 
23
23
  __all__ = [
@@ -25,5 +25,6 @@ __all__ = [
25
25
  "fixedclipping_mod",
26
26
  "LocalDpMod",
27
27
  "make_ffn",
28
+ "secagg_mod",
28
29
  "secaggplus_mod",
29
30
  ]
@@ -15,8 +15,10 @@
15
15
  """Secure Aggregation mods."""
16
16
 
17
17
 
18
+ from .secagg_mod import secagg_mod
18
19
  from .secaggplus_mod import secaggplus_mod
19
20
 
20
21
  __all__ = [
22
+ "secagg_mod",
21
23
  "secaggplus_mod",
22
24
  ]
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Modifier for the SecAgg protocol."""
16
+
17
+
18
+ from flwr.client.typing import ClientAppCallable
19
+ from flwr.common import Context, Message
20
+
21
+ from .secaggplus_mod import secaggplus_mod
22
+
23
+
24
def secagg_mod(
    msg: Message,
    ctxt: Context,
    call_next: ClientAppCallable,
) -> Message:
    """Handle incoming message and return results, following the SecAgg protocol.

    The SecAgg protocol is served by the SecAgg+ implementation: the message,
    the context and the next callable are forwarded unchanged to
    ``secaggplus_mod``, and its reply is returned as-is.

    Parameters
    ----------
    msg : Message
        The incoming message.
    ctxt : Context
        The client context, passed through to ``secaggplus_mod``.
    call_next : ClientAppCallable
        The next callable in the mod chain.

    Returns
    -------
    Message
        The reply produced by ``secaggplus_mod``.
    """
    reply = secaggplus_mod(msg, ctxt, call_next)
    return reply
@@ -12,19 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Message handler for the SecAgg+ protocol."""
15
+ """Modifier for the SecAgg+ protocol."""
16
16
 
17
17
 
18
18
  import os
19
19
  from dataclasses import dataclass, field
20
- from logging import INFO, WARNING
21
- from typing import Any, Callable, Dict, List, Tuple, cast
20
+ from logging import DEBUG, WARNING
21
+ from typing import Any, Dict, List, Tuple, cast
22
22
 
23
23
  from flwr.client.typing import ClientAppCallable
24
24
  from flwr.common import (
25
25
  ConfigsRecord,
26
26
  Context,
27
27
  Message,
28
+ Parameters,
28
29
  RecordSet,
29
30
  ndarray_to_bytes,
30
31
  parameters_to_ndarrays,
@@ -62,7 +63,7 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
62
63
  share_keys_plaintext_concat,
63
64
  share_keys_plaintext_separate,
64
65
  )
65
- from flwr.common.typing import ConfigsRecordValues, FitRes
66
+ from flwr.common.typing import ConfigsRecordValues
66
67
 
67
68
 
68
69
  @dataclass
@@ -132,18 +133,6 @@ class SecAggPlusState:
132
133
  return ret
133
134
 
134
135
 
135
- def _get_fit_fn(
136
- msg: Message, ctxt: Context, call_next: ClientAppCallable
137
- ) -> Callable[[], FitRes]:
138
- """Get the fit function."""
139
-
140
- def fit() -> FitRes:
141
- out_msg = call_next(msg, ctxt)
142
- return compat.recordset_to_fitres(out_msg.content, keep_input=False)
143
-
144
- return fit
145
-
146
-
147
136
  def secaggplus_mod(
148
137
  msg: Message,
149
138
  ctxt: Context,
@@ -173,25 +162,32 @@ def secaggplus_mod(
173
162
  check_configs(state.current_stage, configs)
174
163
 
175
164
  # Execute
165
+ out_content = RecordSet()
176
166
  if state.current_stage == Stage.SETUP:
177
167
  state.nid = msg.metadata.dst_node_id
178
168
  res = _setup(state, configs)
179
169
  elif state.current_stage == Stage.SHARE_KEYS:
180
170
  res = _share_keys(state, configs)
181
- elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
182
- fit = _get_fit_fn(msg, ctxt, call_next)
183
- res = _collect_masked_input(state, configs, fit)
171
+ elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
172
+ out_msg = call_next(msg, ctxt)
173
+ out_content = out_msg.content
174
+ fitres = compat.recordset_to_fitres(out_content, keep_input=True)
175
+ res = _collect_masked_vectors(
176
+ state, configs, fitres.num_examples, fitres.parameters
177
+ )
178
+ for p_record in out_content.parameters_records.values():
179
+ p_record.clear()
184
180
  elif state.current_stage == Stage.UNMASK:
185
181
  res = _unmask(state, configs)
186
182
  else:
187
- raise ValueError(f"Unknown secagg stage: {state.current_stage}")
183
+ raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")
188
184
 
189
185
  # Save state
190
186
  ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())
191
187
 
192
188
  # Return message
193
- content = RecordSet(configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)})
194
- return msg.create_reply(content, ttl="")
189
+ out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
190
+ return msg.create_reply(out_content, ttl="")
195
191
 
196
192
 
197
193
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
@@ -199,7 +195,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
199
195
  # Check the existence of Config.STAGE
200
196
  if Key.STAGE not in configs:
201
197
  raise KeyError(
202
- f"The required key '{Key.STAGE}' is missing from the input `named_values`."
198
+ f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
203
199
  )
204
200
 
205
201
  # Check the value type of the Config.STAGE
@@ -215,7 +211,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
215
211
  if current_stage != Stage.UNMASK:
216
212
  log(WARNING, "Restart from the setup stage")
217
213
  # If stage is not "setup",
218
- # the stage from `named_values` should be the expected next stage
214
+ # the stage from configs should be the expected next stage
219
215
  else:
220
216
  stages = Stage.all()
221
217
  expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
@@ -229,7 +225,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
229
225
  # pylint: disable-next=too-many-branches
230
226
  def check_configs(stage: str, configs: ConfigsRecord) -> None:
231
227
  """Check the validity of the configs."""
232
- # Check `named_values` for the setup stage
228
+ # Check configs for the setup stage
233
229
  if stage == Stage.SETUP:
234
230
  key_type_pairs = [
235
231
  (Key.SAMPLE_NUMBER, int),
@@ -243,7 +239,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
243
239
  if key not in configs:
244
240
  raise KeyError(
245
241
  f"Stage {Stage.SETUP}: the required key '{key}' is "
246
- "missing from the input `named_values`."
242
+ "missing from the ConfigsRecord."
247
243
  )
248
244
  # Bool is a subclass of int in Python,
249
245
  # so `isinstance(v, int)` will return True even if v is a boolean.
@@ -266,7 +262,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
266
262
  f"Stage {Stage.SHARE_KEYS}: "
267
263
  f"the value for the key '{key}' must be a list of two bytes."
268
264
  )
269
- elif stage == Stage.COLLECT_MASKED_INPUT:
265
+ elif stage == Stage.COLLECT_MASKED_VECTORS:
270
266
  key_type_pairs = [
271
267
  (Key.CIPHERTEXT_LIST, bytes),
272
268
  (Key.SOURCE_LIST, int),
@@ -274,9 +270,9 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
274
270
  for key, expected_type in key_type_pairs:
275
271
  if key not in configs:
276
272
  raise KeyError(
277
- f"Stage {Stage.COLLECT_MASKED_INPUT}: "
273
+ f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
278
274
  f"the required key '{key}' is "
279
- "missing from the input `named_values`."
275
+ "missing from the ConfigsRecord."
280
276
  )
281
277
  if not isinstance(configs[key], list) or any(
282
278
  elm
@@ -285,7 +281,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
285
281
  if type(elm) is not expected_type
286
282
  ):
287
283
  raise TypeError(
288
- f"Stage {Stage.COLLECT_MASKED_INPUT}: "
284
+ f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
289
285
  f"the value for the key '{key}' "
290
286
  f"must be of type List[{expected_type.__name__}]"
291
287
  )
@@ -299,7 +295,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
299
295
  raise KeyError(
300
296
  f"Stage {Stage.UNMASK}: "
301
297
  f"the required key '{key}' is "
302
- "missing from the input `named_values`."
298
+ "missing from the ConfigsRecord."
303
299
  )
304
300
  if not isinstance(configs[key], list) or any(
305
301
  elm
@@ -322,7 +318,7 @@ def _setup(
322
318
  # Assigning parameter values to object fields
323
319
  sec_agg_param_dict = configs
324
320
  state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
325
- log(INFO, "Node %d: starting stage 0...", state.nid)
321
+ log(DEBUG, "Node %d: starting stage 0...", state.nid)
326
322
 
327
323
  state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
328
324
  state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
@@ -347,7 +343,7 @@ def _setup(
347
343
 
348
344
  state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
349
345
  state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
350
- log(INFO, "Node %d: stage 0 completes. uploading public keys...", state.nid)
346
+ log(DEBUG, "Node %d: stage 0 completes. uploading public keys...", state.nid)
351
347
  return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
352
348
 
353
349
 
@@ -357,7 +353,7 @@ def _share_keys(
357
353
  ) -> Dict[str, ConfigsRecordValues]:
358
354
  named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
359
355
  key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
360
- log(INFO, "Node %d: starting stage 1...", state.nid)
356
+ log(DEBUG, "Node %d: starting stage 1...", state.nid)
361
357
  state.public_keys_dict = key_dict
362
358
 
363
359
  # Check if the size is larger than threshold
@@ -409,17 +405,18 @@ def _share_keys(
409
405
  dsts.append(nid)
410
406
  ciphertexts.append(ciphertext)
411
407
 
412
- log(INFO, "Node %d: stage 1 completes. uploading key shares...", state.nid)
408
+ log(DEBUG, "Node %d: stage 1 completes. uploading key shares...", state.nid)
413
409
  return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
414
410
 
415
411
 
416
412
  # pylint: disable-next=too-many-locals
417
- def _collect_masked_input(
413
+ def _collect_masked_vectors(
418
414
  state: SecAggPlusState,
419
415
  configs: ConfigsRecord,
420
- fit: Callable[[], FitRes],
416
+ num_examples: int,
417
+ updated_parameters: Parameters,
421
418
  ) -> Dict[str, ConfigsRecordValues]:
422
- log(INFO, "Node %d: starting stage 2...", state.nid)
419
+ log(DEBUG, "Node %d: starting stage 2...", state.nid)
423
420
  available_clients: List[int] = []
424
421
  ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
425
422
  srcs = cast(List[int], configs[Key.SOURCE_LIST])
@@ -447,26 +444,20 @@ def _collect_masked_input(
447
444
  state.rd_seed_share_dict[src] = rd_seed_share
448
445
  state.sk1_share_dict[src] = sk1_share
449
446
 
450
- # Fit client
451
- fit_res = fit()
452
- if len(fit_res.metrics) > 0:
453
- log(
454
- WARNING,
455
- "The metrics in FitRes will not be preserved or sent to the server.",
456
- )
457
- ratio = fit_res.num_examples / state.max_weight
447
+ # Fit
448
+ ratio = num_examples / state.max_weight
458
449
  if ratio > 1:
459
450
  log(
460
451
  WARNING,
461
452
  "Potential overflow warning: the provided weight (%s) exceeds the specified"
462
453
  " max_weight (%s). This may lead to overflow issues.",
463
- fit_res.num_examples,
454
+ num_examples,
464
455
  state.max_weight,
465
456
  )
466
457
  q_ratio = round(ratio * state.target_range)
467
458
  dq_ratio = q_ratio / state.target_range
468
459
 
469
- parameters = parameters_to_ndarrays(fit_res.parameters)
460
+ parameters = parameters_to_ndarrays(updated_parameters)
470
461
  parameters = parameters_multiply(parameters, dq_ratio)
471
462
 
472
463
  # Quantize parameter update (vector)
@@ -500,7 +491,7 @@ def _collect_masked_input(
500
491
 
501
492
  # Take mod of final weight update vector and return to server
502
493
  quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
503
- log(INFO, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
494
+ log(DEBUG, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
504
495
  return {
505
496
  Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
506
497
  }
@@ -509,7 +500,7 @@ def _collect_masked_input(
509
500
  def _unmask(
510
501
  state: SecAggPlusState, configs: ConfigsRecord
511
502
  ) -> Dict[str, ConfigsRecordValues]:
512
- log(INFO, "Node %d: starting stage 3...", state.nid)
503
+ log(DEBUG, "Node %d: starting stage 3...", state.nid)
513
504
 
514
505
  active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST])
515
506
  dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST])
@@ -523,5 +514,5 @@ def _unmask(
523
514
  shares += [state.rd_seed_share_dict[nid] for nid in active_nids]
524
515
  shares += [state.sk1_share_dict[nid] for nid in dead_nids]
525
516
 
526
- log(INFO, "Node %d: stage 3 completes. uploading key shares...", state.nid)
517
+ log(DEBUG, "Node %d: stage 3 completes. uploading key shares...", state.nid)
527
518
  return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}
flwr/common/logger.py CHANGED
@@ -168,11 +168,10 @@ def warn_experimental_feature(name: str) -> None:
168
168
  """Warn the user when they use an experimental feature."""
169
169
  log(
170
170
  WARN,
171
- """
172
- EXPERIMENTAL FEATURE: %s
171
+ """EXPERIMENTAL FEATURE: %s
173
172
 
174
- This is an experimental feature. It could change significantly or be removed
175
- entirely in future versions of Flower.
173
+ This is an experimental feature. It could change significantly or be removed
174
+ entirely in future versions of Flower.
176
175
  """,
177
176
  name,
178
177
  )
@@ -182,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None:
182
181
  """Warn the user when they use a deprecated feature."""
183
182
  log(
184
183
  WARN,
185
- """
186
- DEPRECATED FEATURE: %s
184
+ """DEPRECATED FEATURE: %s
187
185
 
188
- This is a deprecated feature. It will be removed
189
- entirely in future versions of Flower.
186
+ This is a deprecated feature. It will be removed
187
+ entirely in future versions of Flower.
190
188
  """,
191
189
  name,
192
190
  )
@@ -0,0 +1,41 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Validates the project's name property."""
16
+
17
+ import re
18
+
19
+
20
def validate_project_name(name: str, max_length: int = 40) -> bool:
    """Validate the project name against PEP 621 and PEP 503 specifications.

    A name is accepted only if it:
    - consists of lowercase ASCII letters, digits, hyphens (recommended) and
      underscores; no spaces or other special characters
    - starts and ends with a letter or digit, as required by the PEP 508
      name grammar that PEP 621's ``name`` field refers to
    - is at most ``max_length`` characters long (40 by default)

    Parameters
    ----------
    name : str
        The project name to validate.
    max_length : int (default: 40)
        Maximum number of characters allowed in the name.

    Returns
    -------
    bool
        True if the name is valid, False otherwise.
    """
    if not name or len(name) > max_length:
        return False
    # `fullmatch` is used instead of `match` with `$`, because `$` also
    # matches just before a trailing newline (so "abc\n" would pass).
    return re.fullmatch(r"[a-z0-9]([a-z0-9_-]*[a-z0-9])?", name) is not None
@@ -27,9 +27,9 @@ class Stage:
27
27
 
28
28
  SETUP = "setup"
29
29
  SHARE_KEYS = "share_keys"
30
- COLLECT_MASKED_INPUT = "collect_masked_input"
30
+ COLLECT_MASKED_VECTORS = "collect_masked_vectors"
31
31
  UNMASK = "unmask"
32
- _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)
32
+ _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK)
33
33
 
34
34
  @classmethod
35
35
  def all(cls) -> tuple[str, str, str, str]:
@@ -122,7 +122,8 @@ class InMemoryState(State):
122
122
  task_res.task_id = str(task_id)
123
123
  task_res.task.created_at = created_at.isoformat()
124
124
  task_res.task.ttl = ttl.isoformat()
125
- self.task_res_store[task_id] = task_res
125
+ with self.lock:
126
+ self.task_res_store[task_id] = task_res
126
127
 
127
128
  # Return the new task_id
128
129
  return task_id
@@ -132,46 +133,47 @@ class InMemoryState(State):
132
133
  if limit is not None and limit < 1:
133
134
  raise AssertionError("`limit` must be >= 1")
134
135
 
135
- # Find TaskRes that were not delivered yet
136
- task_res_list: List[TaskRes] = []
137
- for _, task_res in self.task_res_store.items():
138
- if (
139
- UUID(task_res.task.ancestry[0]) in task_ids
140
- and task_res.task.delivered_at == ""
141
- ):
142
- task_res_list.append(task_res)
143
- if limit and len(task_res_list) == limit:
144
- break
136
+ with self.lock:
137
+ # Find TaskRes that were not delivered yet
138
+ task_res_list: List[TaskRes] = []
139
+ for _, task_res in self.task_res_store.items():
140
+ if (
141
+ UUID(task_res.task.ancestry[0]) in task_ids
142
+ and task_res.task.delivered_at == ""
143
+ ):
144
+ task_res_list.append(task_res)
145
+ if limit and len(task_res_list) == limit:
146
+ break
145
147
 
146
- # Mark all of them as delivered
147
- delivered_at = now().isoformat()
148
- for task_res in task_res_list:
149
- task_res.task.delivered_at = delivered_at
148
+ # Mark all of them as delivered
149
+ delivered_at = now().isoformat()
150
+ for task_res in task_res_list:
151
+ task_res.task.delivered_at = delivered_at
150
152
 
151
- # Return TaskRes
152
- return task_res_list
153
+ # Return TaskRes
154
+ return task_res_list
153
155
 
154
156
  def delete_tasks(self, task_ids: Set[UUID]) -> None:
155
157
  """Delete all delivered TaskIns/TaskRes pairs."""
156
158
  task_ins_to_be_deleted: Set[UUID] = set()
157
159
  task_res_to_be_deleted: Set[UUID] = set()
158
160
 
159
- for task_ins_id in task_ids:
160
- # Find the task_id of the matching task_res
161
- for task_res_id, task_res in self.task_res_store.items():
162
- if UUID(task_res.task.ancestry[0]) != task_ins_id:
163
- continue
164
- if task_res.task.delivered_at == "":
165
- continue
166
-
167
- task_ins_to_be_deleted.add(task_ins_id)
168
- task_res_to_be_deleted.add(task_res_id)
169
-
170
- for task_id in task_ins_to_be_deleted:
171
- with self.lock:
161
+ with self.lock:
162
+ for task_ins_id in task_ids:
163
+ # Find the task_id of the matching task_res
164
+ for task_res_id, task_res in self.task_res_store.items():
165
+ if UUID(task_res.task.ancestry[0]) != task_ins_id:
166
+ continue
167
+ if task_res.task.delivered_at == "":
168
+ continue
169
+
170
+ task_ins_to_be_deleted.add(task_ins_id)
171
+ task_res_to_be_deleted.add(task_res_id)
172
+
173
+ for task_id in task_ins_to_be_deleted:
172
174
  del self.task_ins_store[task_id]
173
- for task_id in task_res_to_be_deleted:
174
- del self.task_res_store[task_id]
175
+ for task_id in task_res_to_be_deleted:
176
+ del self.task_res_store[task_id]
175
177
 
176
178
  def num_task_ins(self) -> int:
177
179
  """Calculate the number of task_ins in store.
@@ -16,9 +16,10 @@
16
16
 
17
17
 
18
18
  from .default_workflows import DefaultWorkflow
19
- from .secure_aggregation import SecAggPlusWorkflow
19
+ from .secure_aggregation import SecAggPlusWorkflow, SecAggWorkflow
20
20
 
21
21
  __all__ = [
22
22
  "DefaultWorkflow",
23
23
  "SecAggPlusWorkflow",
24
+ "SecAggWorkflow",
24
25
  ]