flwr-nightly 1.8.0.dev20240310__py3-none-any.whl → 1.8.0.dev20240312__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
flwr/cli/new/new.py CHANGED
@@ -22,7 +22,7 @@ from typing import Dict, Optional
22
22
  import typer
23
23
  from typing_extensions import Annotated
24
24
 
25
- from ..utils import prompt_options
25
+ from ..utils import prompt_options, prompt_text
26
26
 
27
27
 
28
28
  class MlFramework(str, Enum):
@@ -72,9 +72,9 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
72
72
 
73
73
  def new(
74
74
  project_name: Annotated[
75
- str,
75
+ Optional[str],
76
76
  typer.Argument(metavar="project_name", help="The name of the project"),
77
- ],
77
+ ] = None,
78
78
  framework: Annotated[
79
79
  Optional[MlFramework],
80
80
  typer.Option(case_sensitive=False, help="The ML framework to use"),
@@ -83,6 +83,9 @@ def new(
83
83
  """Create new Flower project."""
84
84
  print(f"Creating Flower project {project_name}...")
85
85
 
86
+ if project_name is None:
87
+ project_name = prompt_text("Please provide project name")
88
+
86
89
  if framework is not None:
87
90
  framework_str = str(framework.value)
88
91
  else:
flwr/cli/utils.py CHANGED
@@ -14,11 +14,24 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface utils."""
16
16
 
17
- from typing import List
17
+ from typing import List, cast
18
18
 
19
19
  import typer
20
20
 
21
21
 
22
+ def prompt_text(text: str) -> str:
23
+ """Ask user to enter text input."""
24
+ while True:
25
+ result = typer.prompt(
26
+ typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True)
27
+ )
28
+ if len(result) > 0:
29
+ break
30
+ print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True))
31
+
32
+ return cast(str, result)
33
+
34
+
22
35
  def prompt_options(text: str, options: List[str]) -> str:
23
36
  """Ask user to select one of the given options and return the selected item."""
24
37
  # Turn options into a list with index as in " [ 0] quickstart-pytorch"
flwr/client/app.py CHANGED
@@ -456,7 +456,19 @@ def _start_client_internal(
456
456
  time.sleep(3) # Wait for 3s before asking again
457
457
  continue
458
458
 
459
- log(INFO, "Received message")
459
+ log(INFO, "")
460
+ log(
461
+ INFO,
462
+ "[RUN %s, ROUND %s]",
463
+ message.metadata.run_id,
464
+ message.metadata.group_id,
465
+ )
466
+ log(
467
+ INFO,
468
+ "Received: %s message %s",
469
+ message.metadata.message_type,
470
+ message.metadata.message_id,
471
+ )
460
472
 
461
473
  # Handle control message
462
474
  out_message, sleep_duration = handle_control_message(message)
@@ -484,7 +496,18 @@ def _start_client_internal(
484
496
 
485
497
  # Send
486
498
  send(out_message)
487
- log(INFO, "Sent reply")
499
+ log(
500
+ INFO,
501
+ "[RUN %s, ROUND %s]",
502
+ out_message.metadata.run_id,
503
+ out_message.metadata.group_id,
504
+ )
505
+ log(
506
+ INFO,
507
+ "Sent: %s reply to message %s",
508
+ out_message.metadata.message_type,
509
+ message.metadata.message_id,
510
+ )
488
511
 
489
512
  # Unregister node
490
513
  if delete_node is not None:
@@ -17,7 +17,7 @@
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
19
19
  from .localdp_mod import LocalDpMod
20
- from .secure_aggregation.secaggplus_mod import secaggplus_mod
20
+ from .secure_aggregation import secagg_mod, secaggplus_mod
21
21
  from .utils import make_ffn
22
22
 
23
23
  __all__ = [
@@ -25,5 +25,6 @@ __all__ = [
25
25
  "fixedclipping_mod",
26
26
  "LocalDpMod",
27
27
  "make_ffn",
28
+ "secagg_mod",
28
29
  "secaggplus_mod",
29
30
  ]
@@ -15,8 +15,10 @@
15
15
  """Secure Aggregation mods."""
16
16
 
17
17
 
18
+ from .secagg_mod import secagg_mod
18
19
  from .secaggplus_mod import secaggplus_mod
19
20
 
20
21
  __all__ = [
22
+ "secagg_mod",
21
23
  "secaggplus_mod",
22
24
  ]
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Modifier for the SecAgg protocol."""
16
+
17
+
18
+ from flwr.client.typing import ClientAppCallable
19
+ from flwr.common import Context, Message
20
+
21
+ from .secaggplus_mod import secaggplus_mod
22
+
23
+
24
+ def secagg_mod(
25
+ msg: Message,
26
+ ctxt: Context,
27
+ call_next: ClientAppCallable,
28
+ ) -> Message:
29
+ """Handle incoming message and return results, following the SecAgg protocol."""
30
+ return secaggplus_mod(msg, ctxt, call_next)
@@ -12,19 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Message handler for the SecAgg+ protocol."""
15
+ """Modifier for the SecAgg+ protocol."""
16
16
 
17
17
 
18
18
  import os
19
19
  from dataclasses import dataclass, field
20
- from logging import INFO, WARNING
21
- from typing import Any, Callable, Dict, List, Tuple, cast
20
+ from logging import DEBUG, WARNING
21
+ from typing import Any, Dict, List, Tuple, cast
22
22
 
23
23
  from flwr.client.typing import ClientAppCallable
24
24
  from flwr.common import (
25
25
  ConfigsRecord,
26
26
  Context,
27
27
  Message,
28
+ Parameters,
28
29
  RecordSet,
29
30
  ndarray_to_bytes,
30
31
  parameters_to_ndarrays,
@@ -62,7 +63,7 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
62
63
  share_keys_plaintext_concat,
63
64
  share_keys_plaintext_separate,
64
65
  )
65
- from flwr.common.typing import ConfigsRecordValues, FitRes
66
+ from flwr.common.typing import ConfigsRecordValues
66
67
 
67
68
 
68
69
  @dataclass
@@ -132,18 +133,6 @@ class SecAggPlusState:
132
133
  return ret
133
134
 
134
135
 
135
- def _get_fit_fn(
136
- msg: Message, ctxt: Context, call_next: ClientAppCallable
137
- ) -> Callable[[], FitRes]:
138
- """Get the fit function."""
139
-
140
- def fit() -> FitRes:
141
- out_msg = call_next(msg, ctxt)
142
- return compat.recordset_to_fitres(out_msg.content, keep_input=False)
143
-
144
- return fit
145
-
146
-
147
136
  def secaggplus_mod(
148
137
  msg: Message,
149
138
  ctxt: Context,
@@ -173,25 +162,32 @@ def secaggplus_mod(
173
162
  check_configs(state.current_stage, configs)
174
163
 
175
164
  # Execute
165
+ out_content = RecordSet()
176
166
  if state.current_stage == Stage.SETUP:
177
167
  state.nid = msg.metadata.dst_node_id
178
168
  res = _setup(state, configs)
179
169
  elif state.current_stage == Stage.SHARE_KEYS:
180
170
  res = _share_keys(state, configs)
181
- elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
182
- fit = _get_fit_fn(msg, ctxt, call_next)
183
- res = _collect_masked_input(state, configs, fit)
171
+ elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
172
+ out_msg = call_next(msg, ctxt)
173
+ out_content = out_msg.content
174
+ fitres = compat.recordset_to_fitres(out_content, keep_input=True)
175
+ res = _collect_masked_vectors(
176
+ state, configs, fitres.num_examples, fitres.parameters
177
+ )
178
+ for p_record in out_content.parameters_records.values():
179
+ p_record.clear()
184
180
  elif state.current_stage == Stage.UNMASK:
185
181
  res = _unmask(state, configs)
186
182
  else:
187
- raise ValueError(f"Unknown secagg stage: {state.current_stage}")
183
+ raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")
188
184
 
189
185
  # Save state
190
186
  ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())
191
187
 
192
188
  # Return message
193
- content = RecordSet(configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)})
194
- return msg.create_reply(content, ttl="")
189
+ out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
190
+ return msg.create_reply(out_content, ttl="")
195
191
 
196
192
 
197
193
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
@@ -199,7 +195,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
199
195
  # Check the existence of Config.STAGE
200
196
  if Key.STAGE not in configs:
201
197
  raise KeyError(
202
- f"The required key '{Key.STAGE}' is missing from the input `named_values`."
198
+ f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
203
199
  )
204
200
 
205
201
  # Check the value type of the Config.STAGE
@@ -215,7 +211,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
215
211
  if current_stage != Stage.UNMASK:
216
212
  log(WARNING, "Restart from the setup stage")
217
213
  # If stage is not "setup",
218
- # the stage from `named_values` should be the expected next stage
214
+ # the stage from configs should be the expected next stage
219
215
  else:
220
216
  stages = Stage.all()
221
217
  expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
@@ -229,7 +225,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
229
225
  # pylint: disable-next=too-many-branches
230
226
  def check_configs(stage: str, configs: ConfigsRecord) -> None:
231
227
  """Check the validity of the configs."""
232
- # Check `named_values` for the setup stage
228
+ # Check configs for the setup stage
233
229
  if stage == Stage.SETUP:
234
230
  key_type_pairs = [
235
231
  (Key.SAMPLE_NUMBER, int),
@@ -243,7 +239,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
243
239
  if key not in configs:
244
240
  raise KeyError(
245
241
  f"Stage {Stage.SETUP}: the required key '{key}' is "
246
- "missing from the input `named_values`."
242
+ "missing from the ConfigsRecord."
247
243
  )
248
244
  # Bool is a subclass of int in Python,
249
245
  # so `isinstance(v, int)` will return True even if v is a boolean.
@@ -266,7 +262,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
266
262
  f"Stage {Stage.SHARE_KEYS}: "
267
263
  f"the value for the key '{key}' must be a list of two bytes."
268
264
  )
269
- elif stage == Stage.COLLECT_MASKED_INPUT:
265
+ elif stage == Stage.COLLECT_MASKED_VECTORS:
270
266
  key_type_pairs = [
271
267
  (Key.CIPHERTEXT_LIST, bytes),
272
268
  (Key.SOURCE_LIST, int),
@@ -274,9 +270,9 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
274
270
  for key, expected_type in key_type_pairs:
275
271
  if key not in configs:
276
272
  raise KeyError(
277
- f"Stage {Stage.COLLECT_MASKED_INPUT}: "
273
+ f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
278
274
  f"the required key '{key}' is "
279
- "missing from the input `named_values`."
275
+ "missing from the ConfigsRecord."
280
276
  )
281
277
  if not isinstance(configs[key], list) or any(
282
278
  elm
@@ -285,7 +281,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
285
281
  if type(elm) is not expected_type
286
282
  ):
287
283
  raise TypeError(
288
- f"Stage {Stage.COLLECT_MASKED_INPUT}: "
284
+ f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
289
285
  f"the value for the key '{key}' "
290
286
  f"must be of type List[{expected_type.__name__}]"
291
287
  )
@@ -299,7 +295,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
299
295
  raise KeyError(
300
296
  f"Stage {Stage.UNMASK}: "
301
297
  f"the required key '{key}' is "
302
- "missing from the input `named_values`."
298
+ "missing from the ConfigsRecord."
303
299
  )
304
300
  if not isinstance(configs[key], list) or any(
305
301
  elm
@@ -322,7 +318,7 @@ def _setup(
322
318
  # Assigning parameter values to object fields
323
319
  sec_agg_param_dict = configs
324
320
  state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
325
- log(INFO, "Node %d: starting stage 0...", state.nid)
321
+ log(DEBUG, "Node %d: starting stage 0...", state.nid)
326
322
 
327
323
  state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
328
324
  state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
@@ -347,7 +343,7 @@ def _setup(
347
343
 
348
344
  state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
349
345
  state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
350
- log(INFO, "Node %d: stage 0 completes. uploading public keys...", state.nid)
346
+ log(DEBUG, "Node %d: stage 0 completes. uploading public keys...", state.nid)
351
347
  return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
352
348
 
353
349
 
@@ -357,7 +353,7 @@ def _share_keys(
357
353
  ) -> Dict[str, ConfigsRecordValues]:
358
354
  named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
359
355
  key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
360
- log(INFO, "Node %d: starting stage 1...", state.nid)
356
+ log(DEBUG, "Node %d: starting stage 1...", state.nid)
361
357
  state.public_keys_dict = key_dict
362
358
 
363
359
  # Check if the size is larger than threshold
@@ -409,17 +405,18 @@ def _share_keys(
409
405
  dsts.append(nid)
410
406
  ciphertexts.append(ciphertext)
411
407
 
412
- log(INFO, "Node %d: stage 1 completes. uploading key shares...", state.nid)
408
+ log(DEBUG, "Node %d: stage 1 completes. uploading key shares...", state.nid)
413
409
  return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
414
410
 
415
411
 
416
412
  # pylint: disable-next=too-many-locals
417
- def _collect_masked_input(
413
+ def _collect_masked_vectors(
418
414
  state: SecAggPlusState,
419
415
  configs: ConfigsRecord,
420
- fit: Callable[[], FitRes],
416
+ num_examples: int,
417
+ updated_parameters: Parameters,
421
418
  ) -> Dict[str, ConfigsRecordValues]:
422
- log(INFO, "Node %d: starting stage 2...", state.nid)
419
+ log(DEBUG, "Node %d: starting stage 2...", state.nid)
423
420
  available_clients: List[int] = []
424
421
  ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
425
422
  srcs = cast(List[int], configs[Key.SOURCE_LIST])
@@ -447,26 +444,20 @@ def _collect_masked_input(
447
444
  state.rd_seed_share_dict[src] = rd_seed_share
448
445
  state.sk1_share_dict[src] = sk1_share
449
446
 
450
- # Fit client
451
- fit_res = fit()
452
- if len(fit_res.metrics) > 0:
453
- log(
454
- WARNING,
455
- "The metrics in FitRes will not be preserved or sent to the server.",
456
- )
457
- ratio = fit_res.num_examples / state.max_weight
447
+ # Fit
448
+ ratio = num_examples / state.max_weight
458
449
  if ratio > 1:
459
450
  log(
460
451
  WARNING,
461
452
  "Potential overflow warning: the provided weight (%s) exceeds the specified"
462
453
  " max_weight (%s). This may lead to overflow issues.",
463
- fit_res.num_examples,
454
+ num_examples,
464
455
  state.max_weight,
465
456
  )
466
457
  q_ratio = round(ratio * state.target_range)
467
458
  dq_ratio = q_ratio / state.target_range
468
459
 
469
- parameters = parameters_to_ndarrays(fit_res.parameters)
460
+ parameters = parameters_to_ndarrays(updated_parameters)
470
461
  parameters = parameters_multiply(parameters, dq_ratio)
471
462
 
472
463
  # Quantize parameter update (vector)
@@ -500,7 +491,7 @@ def _collect_masked_input(
500
491
 
501
492
  # Take mod of final weight update vector and return to server
502
493
  quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
503
- log(INFO, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
494
+ log(DEBUG, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
504
495
  return {
505
496
  Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
506
497
  }
@@ -509,7 +500,7 @@ def _collect_masked_input(
509
500
  def _unmask(
510
501
  state: SecAggPlusState, configs: ConfigsRecord
511
502
  ) -> Dict[str, ConfigsRecordValues]:
512
- log(INFO, "Node %d: starting stage 3...", state.nid)
503
+ log(DEBUG, "Node %d: starting stage 3...", state.nid)
513
504
 
514
505
  active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST])
515
506
  dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST])
@@ -523,5 +514,5 @@ def _unmask(
523
514
  shares += [state.rd_seed_share_dict[nid] for nid in active_nids]
524
515
  shares += [state.sk1_share_dict[nid] for nid in dead_nids]
525
516
 
526
- log(INFO, "Node %d: stage 3 completes. uploading key shares...", state.nid)
517
+ log(DEBUG, "Node %d: stage 3 completes. uploading key shares...", state.nid)
527
518
  return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}
flwr/common/logger.py CHANGED
@@ -168,11 +168,10 @@ def warn_experimental_feature(name: str) -> None:
168
168
  """Warn the user when they use an experimental feature."""
169
169
  log(
170
170
  WARN,
171
- """
172
- EXPERIMENTAL FEATURE: %s
171
+ """EXPERIMENTAL FEATURE: %s
173
172
 
174
- This is an experimental feature. It could change significantly or be removed
175
- entirely in future versions of Flower.
173
+ This is an experimental feature. It could change significantly or be removed
174
+ entirely in future versions of Flower.
176
175
  """,
177
176
  name,
178
177
  )
@@ -182,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None:
182
181
  """Warn the user when they use a deprecated feature."""
183
182
  log(
184
183
  WARN,
185
- """
186
- DEPRECATED FEATURE: %s
184
+ """DEPRECATED FEATURE: %s
187
185
 
188
- This is a deprecated feature. It will be removed
189
- entirely in future versions of Flower.
186
+ This is a deprecated feature. It will be removed
187
+ entirely in future versions of Flower.
190
188
  """,
191
189
  name,
192
190
  )
@@ -0,0 +1,41 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Validates the project's name property."""
16
+
17
+ import re
18
+
19
+
20
+ def validate_project_name(name: str) -> bool:
21
+ """Validate the project name against PEP 621 and PEP 503 specifications.
22
+
23
+ Conventions at a glance:
24
+ - Must be lowercase
25
+ - Must not contain special characters
26
+ - Must use hyphens(recommended) or underscores. No spaces.
27
+ - Recommended to be no more than 40 characters long (But it can be)
28
+
29
+ Parameters
30
+ ----------
31
+ name : str
32
+ The project name to validate.
33
+
34
+ Returns
35
+ -------
36
+ bool
37
+ True if the name is valid, False otherwise.
38
+ """
39
+ if not name or len(name) > 40 or not re.match(r"^[a-z0-9-_]+$", name):
40
+ return False
41
+ return True
@@ -27,9 +27,9 @@ class Stage:
27
27
 
28
28
  SETUP = "setup"
29
29
  SHARE_KEYS = "share_keys"
30
- COLLECT_MASKED_INPUT = "collect_masked_input"
30
+ COLLECT_MASKED_VECTORS = "collect_masked_vectors"
31
31
  UNMASK = "unmask"
32
- _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)
32
+ _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK)
33
33
 
34
34
  @classmethod
35
35
  def all(cls) -> tuple[str, str, str, str]:
@@ -122,7 +122,8 @@ class InMemoryState(State):
122
122
  task_res.task_id = str(task_id)
123
123
  task_res.task.created_at = created_at.isoformat()
124
124
  task_res.task.ttl = ttl.isoformat()
125
- self.task_res_store[task_id] = task_res
125
+ with self.lock:
126
+ self.task_res_store[task_id] = task_res
126
127
 
127
128
  # Return the new task_id
128
129
  return task_id
@@ -132,46 +133,47 @@ class InMemoryState(State):
132
133
  if limit is not None and limit < 1:
133
134
  raise AssertionError("`limit` must be >= 1")
134
135
 
135
- # Find TaskRes that were not delivered yet
136
- task_res_list: List[TaskRes] = []
137
- for _, task_res in self.task_res_store.items():
138
- if (
139
- UUID(task_res.task.ancestry[0]) in task_ids
140
- and task_res.task.delivered_at == ""
141
- ):
142
- task_res_list.append(task_res)
143
- if limit and len(task_res_list) == limit:
144
- break
136
+ with self.lock:
137
+ # Find TaskRes that were not delivered yet
138
+ task_res_list: List[TaskRes] = []
139
+ for _, task_res in self.task_res_store.items():
140
+ if (
141
+ UUID(task_res.task.ancestry[0]) in task_ids
142
+ and task_res.task.delivered_at == ""
143
+ ):
144
+ task_res_list.append(task_res)
145
+ if limit and len(task_res_list) == limit:
146
+ break
145
147
 
146
- # Mark all of them as delivered
147
- delivered_at = now().isoformat()
148
- for task_res in task_res_list:
149
- task_res.task.delivered_at = delivered_at
148
+ # Mark all of them as delivered
149
+ delivered_at = now().isoformat()
150
+ for task_res in task_res_list:
151
+ task_res.task.delivered_at = delivered_at
150
152
 
151
- # Return TaskRes
152
- return task_res_list
153
+ # Return TaskRes
154
+ return task_res_list
153
155
 
154
156
  def delete_tasks(self, task_ids: Set[UUID]) -> None:
155
157
  """Delete all delivered TaskIns/TaskRes pairs."""
156
158
  task_ins_to_be_deleted: Set[UUID] = set()
157
159
  task_res_to_be_deleted: Set[UUID] = set()
158
160
 
159
- for task_ins_id in task_ids:
160
- # Find the task_id of the matching task_res
161
- for task_res_id, task_res in self.task_res_store.items():
162
- if UUID(task_res.task.ancestry[0]) != task_ins_id:
163
- continue
164
- if task_res.task.delivered_at == "":
165
- continue
166
-
167
- task_ins_to_be_deleted.add(task_ins_id)
168
- task_res_to_be_deleted.add(task_res_id)
169
-
170
- for task_id in task_ins_to_be_deleted:
171
- with self.lock:
161
+ with self.lock:
162
+ for task_ins_id in task_ids:
163
+ # Find the task_id of the matching task_res
164
+ for task_res_id, task_res in self.task_res_store.items():
165
+ if UUID(task_res.task.ancestry[0]) != task_ins_id:
166
+ continue
167
+ if task_res.task.delivered_at == "":
168
+ continue
169
+
170
+ task_ins_to_be_deleted.add(task_ins_id)
171
+ task_res_to_be_deleted.add(task_res_id)
172
+
173
+ for task_id in task_ins_to_be_deleted:
172
174
  del self.task_ins_store[task_id]
173
- for task_id in task_res_to_be_deleted:
174
- del self.task_res_store[task_id]
175
+ for task_id in task_res_to_be_deleted:
176
+ del self.task_res_store[task_id]
175
177
 
176
178
  def num_task_ins(self) -> int:
177
179
  """Calculate the number of task_ins in store.
@@ -16,9 +16,10 @@
16
16
 
17
17
 
18
18
  from .default_workflows import DefaultWorkflow
19
- from .secure_aggregation import SecAggPlusWorkflow
19
+ from .secure_aggregation import SecAggPlusWorkflow, SecAggWorkflow
20
20
 
21
21
  __all__ = [
22
22
  "DefaultWorkflow",
23
23
  "SecAggPlusWorkflow",
24
+ "SecAggWorkflow",
24
25
  ]