flwr-nightly 1.12.0.dev20241008__py3-none-any.whl → 1.13.0.dev20241015__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (52):
  1. flwr/cli/build.py +60 -29
  2. flwr/cli/config_utils.py +10 -0
  3. flwr/cli/install.py +60 -20
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -5
  14. flwr/client/app.py +13 -3
  15. flwr/client/clientapp/app.py +3 -1
  16. flwr/client/clientapp/utils.py +11 -5
  17. flwr/client/grpc_adapter_client/connection.py +1 -1
  18. flwr/client/grpc_client/connection.py +1 -1
  19. flwr/client/grpc_rere_client/connection.py +4 -1
  20. flwr/client/node_state.py +1 -1
  21. flwr/client/rest_client/connection.py +1 -1
  22. flwr/client/supernode/app.py +12 -5
  23. flwr/common/config.py +19 -5
  24. flwr/common/logger.py +1 -1
  25. flwr/common/message.py +6 -1
  26. flwr/common/record/configsrecord.py +6 -0
  27. flwr/common/recordset_compat.py +10 -0
  28. flwr/common/retry_invoker.py +15 -0
  29. flwr/server/app.py +3 -2
  30. flwr/server/client_manager.py +2 -0
  31. flwr/server/driver/driver.py +1 -1
  32. flwr/server/driver/grpc_driver.py +1 -1
  33. flwr/server/driver/inmemory_driver.py +2 -2
  34. flwr/server/run_serverapp.py +11 -13
  35. flwr/server/server_app.py +1 -1
  36. flwr/server/strategy/dp_adaptive_clipping.py +2 -2
  37. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  38. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  39. flwr/server/superlink/driver/driver_servicer.py +1 -1
  40. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -3
  41. flwr/server/superlink/fleet/vce/vce_api.py +9 -6
  42. flwr/server/superlink/state/in_memory_state.py +26 -8
  43. flwr/server/superlink/state/sqlite_state.py +46 -11
  44. flwr/server/superlink/state/state.py +1 -7
  45. flwr/server/superlink/state/utils.py +0 -10
  46. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  47. flwr/simulation/run_simulation.py +4 -4
  48. {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/METADATA +1 -1
  49. {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/RECORD +52 -52
  50. {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/LICENSE +0 -0
  51. {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/WHEEL +0 -0
  52. {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/entry_points.txt +0 -0
flwr/common/config.py CHANGED
@@ -22,7 +22,12 @@ from typing import Any, Optional, Union, cast, get_args
22
22
  import tomli
23
23
 
24
24
  from flwr.cli.config_utils import get_fab_config, validate_fields
25
- from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
25
+ from flwr.common.constant import (
26
+ APP_DIR,
27
+ FAB_CONFIG_FILE,
28
+ FAB_HASH_TRUNCATION,
29
+ FLWR_HOME,
30
+ )
26
31
  from flwr.common.typing import Run, UserConfig, UserConfigValue
27
32
 
28
33
 
@@ -39,7 +44,10 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
39
44
 
40
45
 
41
46
  def get_project_dir(
42
- fab_id: str, fab_version: str, flwr_dir: Optional[Union[str, Path]] = None
47
+ fab_id: str,
48
+ fab_version: str,
49
+ fab_hash: str,
50
+ flwr_dir: Optional[Union[str, Path]] = None,
43
51
  ) -> Path:
44
52
  """Return the project directory based on the given fab_id and fab_version."""
45
53
  # Check the fab_id
@@ -50,7 +58,11 @@ def get_project_dir(
50
58
  publisher, project_name = fab_id.split("/")
51
59
  if flwr_dir is None:
52
60
  flwr_dir = get_flwr_dir()
53
- return Path(flwr_dir) / APP_DIR / publisher / project_name / fab_version
61
+ return (
62
+ Path(flwr_dir)
63
+ / APP_DIR
64
+ / f"{publisher}.{project_name}.{fab_version}.{fab_hash[:FAB_HASH_TRUNCATION]}"
65
+ )
54
66
 
55
67
 
56
68
  def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]:
@@ -127,7 +139,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
127
139
  if not run.fab_id or not run.fab_version:
128
140
  return {}
129
141
 
130
- project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
142
+ project_dir = get_project_dir(run.fab_id, run.fab_version, run.fab_hash, flwr_dir)
131
143
 
132
144
  # Return empty dict if project directory does not exist
133
145
  if not project_dir.is_dir():
@@ -194,6 +206,7 @@ def parse_config_args(
194
206
  # Regular expression to capture key-value pairs with possible quoted values
195
207
  pattern = re.compile(r"(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)")
196
208
 
209
+ flat_overrides = {}
197
210
  for config_line in config:
198
211
  if config_line:
199
212
  # .toml files aren't allowed alongside other configs
@@ -205,8 +218,9 @@ def parse_config_args(
205
218
  matches = pattern.findall(config_line)
206
219
  toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
207
220
  overrides.update(tomli.loads(toml_str))
221
+ flat_overrides = flatten_dict(overrides)
208
222
 
209
- return overrides
223
+ return flat_overrides
210
224
 
211
225
 
212
226
  def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
flwr/common/logger.py CHANGED
@@ -111,7 +111,7 @@ FLOWER_LOGGER.addHandler(console_handler)
111
111
  class CustomHTTPHandler(HTTPHandler):
112
112
  """Custom HTTPHandler which overrides the mapLogRecords method."""
113
113
 
114
- # pylint: disable=too-many-arguments,bad-option-value,R1725
114
+ # pylint: disable=too-many-arguments,bad-option-value,R1725,R0917
115
115
  def __init__(
116
116
  self,
117
117
  identifier: str,
flwr/common/message.py CHANGED
@@ -52,7 +52,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
52
52
  the receiving end.
53
53
  """
54
54
 
55
- def __init__( # pylint: disable=too-many-arguments
55
+ def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
56
56
  self,
57
57
  run_id: int,
58
58
  message_id: str,
@@ -290,6 +290,11 @@ class Message:
290
290
  follows the equation:
291
291
 
292
292
  ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
293
+
294
+ Returns
295
+ -------
296
+ message : Message
297
+ A Message containing only the relevant error and metadata.
293
298
  """
294
299
  # If no TTL passed, use default for message creation (will update after
295
300
  # message creation)
flwr/common/record/configsrecord.py CHANGED
@@ -128,6 +128,7 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
128
128
 
129
129
  def get_var_bytes(value: ConfigsScalar) -> int:
130
130
  """Return Bytes of value passed."""
131
+ var_bytes = 0
131
132
  if isinstance(value, bool):
132
133
  var_bytes = 1
133
134
  elif isinstance(value, (int, float)):
@@ -136,6 +137,11 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
136
137
  )
137
138
  if isinstance(value, (str, bytes)):
138
139
  var_bytes = len(value)
140
+ if var_bytes == 0:
141
+ raise ValueError(
142
+ "Config values must be either `bool`, `int`, `float`, "
143
+ "`str`, or `bytes`"
144
+ )
139
145
  return var_bytes
140
146
 
141
147
  num_bytes = 0
flwr/common/recordset_compat.py CHANGED
@@ -59,6 +59,11 @@ def parametersrecord_to_parameters(
59
59
  keep_input : bool
60
60
  A boolean indicating whether entries in the record should be deleted from the
61
61
  input dictionary immediately after adding them to the record.
62
+
63
+ Returns
64
+ -------
65
+ parameters : Parameters
66
+ The parameters in the legacy format Parameters.
62
67
  """
63
68
  parameters = Parameters(tensors=[], tensor_type="")
64
69
 
@@ -94,6 +99,11 @@ def parameters_to_parametersrecord(
94
99
  A boolean indicating whether parameters should be deleted from the input
95
100
  Parameters object (i.e. a list of serialized NumPy arrays) immediately after
96
101
  adding them to the record.
102
+
103
+ Returns
104
+ -------
105
+ ParametersRecord
106
+ The ParametersRecord containing the provided parameters.
97
107
  """
98
108
  tensor_type = parameters.tensor_type
99
109
 
flwr/common/retry_invoker.py CHANGED
@@ -38,6 +38,11 @@ def exponential(
38
38
  Factor by which the delay is multiplied after each retry.
39
39
  max_delay: Optional[float] (default: None)
40
40
  The maximum delay duration between two consecutive retries.
41
+
42
+ Returns
43
+ -------
44
+ Generator[float, None, None]
45
+ A generator for the delay between 2 retries.
41
46
  """
42
47
  delay = base_delay if max_delay is None else min(base_delay, max_delay)
43
48
  while True:
@@ -56,6 +61,11 @@ def constant(
56
61
  ----------
57
62
  interval: Union[float, Iterable[float]] (default: 1)
58
63
  A constant value to yield or an iterable of such values.
64
+
65
+ Returns
66
+ -------
67
+ Generator[float, None, None]
68
+ A generator for the delay between 2 retries.
59
69
  """
60
70
  if not isinstance(interval, Iterable):
61
71
  interval = itertools.repeat(interval)
@@ -73,6 +83,11 @@ def full_jitter(max_value: float) -> float:
73
83
  ----------
74
84
  max_value : float
75
85
  The upper limit for the randomized value.
86
+
87
+ Returns
88
+ -------
89
+ float
90
+ A random float that is less than max_value.
76
91
  """
77
92
  return random.uniform(0, max_value)
78
93
 
flwr/server/app.py CHANGED
@@ -199,12 +199,12 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
199
199
  # pylint: disable=too-many-branches, too-many-locals, too-many-statements
200
200
  def run_superlink() -> None:
201
201
  """Run Flower SuperLink (Driver API and Fleet API)."""
202
+ args = _parse_args_run_superlink().parse_args()
203
+
202
204
  log(INFO, "Starting Flower SuperLink")
203
205
 
204
206
  event(EventType.RUN_SUPERLINK_ENTER)
205
207
 
206
- args = _parse_args_run_superlink().parse_args()
207
-
208
208
  # Parse IP address
209
209
  driver_address, _, _ = _format_address(args.driver_api_address)
210
210
 
@@ -542,6 +542,7 @@ def _run_fleet_api_grpc_adapter(
542
542
 
543
543
 
544
544
  # pylint: disable=import-outside-toplevel,too-many-arguments
545
+ # pylint: disable=too-many-positional-arguments
545
546
  def _run_fleet_api_rest(
546
547
  host: str,
547
548
  port: int,
flwr/server/client_manager.py CHANGED
@@ -47,6 +47,7 @@ class ClientManager(ABC):
47
47
  Parameters
48
48
  ----------
49
49
  client : flwr.server.client_proxy.ClientProxy
50
+ The ClientProxy of the Client to register.
50
51
 
51
52
  Returns
52
53
  -------
@@ -64,6 +65,7 @@ class ClientManager(ABC):
64
65
  Parameters
65
66
  ----------
66
67
  client : flwr.server.client_proxy.ClientProxy
68
+ The ClientProxy of the Client to unregister.
67
69
  """
68
70
 
69
71
  @abstractmethod
flwr/server/driver/driver.py CHANGED
@@ -32,7 +32,7 @@ class Driver(ABC):
32
32
  """Run information."""
33
33
 
34
34
  @abstractmethod
35
- def create_message( # pylint: disable=too-many-arguments
35
+ def create_message( # pylint: disable=too-many-arguments,R0917
36
36
  self,
37
37
  content: RecordSet,
38
38
  message_type: str,
flwr/server/driver/grpc_driver.py CHANGED
@@ -158,7 +158,7 @@ class GrpcDriver(Driver):
158
158
  ):
159
159
  raise ValueError(f"Invalid message: {message}")
160
160
 
161
- def create_message( # pylint: disable=too-many-arguments
161
+ def create_message( # pylint: disable=too-many-arguments,R0917
162
162
  self,
163
163
  content: RecordSet,
164
164
  message_type: str,
flwr/server/driver/inmemory_driver.py CHANGED
@@ -82,7 +82,7 @@ class InMemoryDriver(Driver):
82
82
  self._init_run()
83
83
  return Run(**vars(cast(Run, self._run)))
84
84
 
85
- def create_message( # pylint: disable=too-many-arguments
85
+ def create_message( # pylint: disable=too-many-arguments,R0917
86
86
  self,
87
87
  content: RecordSet,
88
88
  message_type: str,
@@ -150,7 +150,7 @@ class InMemoryDriver(Driver):
150
150
  """
151
151
  msg_ids = {UUID(msg_id) for msg_id in message_ids}
152
152
  # Pull TaskRes
153
- task_res_list = self.state.get_task_res(task_ids=msg_ids, limit=len(msg_ids))
153
+ task_res_list = self.state.get_task_res(task_ids=msg_ids)
154
154
  # Delete tasks in state
155
155
  self.state.delete_tasks(msg_ids)
156
156
  # Convert TaskRes to Message
flwr/server/run_serverapp.py CHANGED
@@ -181,19 +181,17 @@ def run_server_app() -> None:
181
181
  )
182
182
  flwr_dir = get_flwr_dir(args.flwr_dir)
183
183
  run_ = driver.run
184
- if run_.fab_hash:
185
- fab_req = GetFabRequest(hash_str=run_.fab_hash)
186
- # pylint: disable-next=W0212
187
- fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
188
- if fab_res.fab.hash_str != run_.fab_hash:
189
- raise ValueError("FAB hashes don't match.")
190
-
191
- install_from_fab(fab_res.fab.content, flwr_dir, True)
192
- fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
193
- else:
194
- fab_id, fab_version = run_.fab_id, run_.fab_version
195
-
196
- app_path = str(get_project_dir(fab_id, fab_version, flwr_dir))
184
+ if not run_.fab_hash:
185
+ raise ValueError("FAB hash not provided.")
186
+ fab_req = GetFabRequest(hash_str=run_.fab_hash)
187
+ # pylint: disable-next=W0212
188
+ fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
189
+ if fab_res.fab.hash_str != run_.fab_hash:
190
+ raise ValueError("FAB hashes don't match.")
191
+ install_from_fab(fab_res.fab.content, flwr_dir, True)
192
+ fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
193
+
194
+ app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
197
195
  config = get_project_config(app_path)
198
196
  else:
199
197
  # User provided `app_dir`, but not `--run-id`
flwr/server/server_app.py CHANGED
@@ -71,7 +71,7 @@ class ServerApp:
71
71
  >>> print("ServerApp running")
72
72
  """
73
73
 
74
- # pylint: disable=too-many-arguments
74
+ # pylint: disable=too-many-arguments,too-many-positional-arguments
75
75
  def __init__(
76
76
  self,
77
77
  server: Optional[Server] = None,
flwr/server/strategy/dp_adaptive_clipping.py CHANGED
@@ -88,7 +88,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
88
88
  >>> )
89
89
  """
90
90
 
91
- # pylint: disable=too-many-arguments,too-many-instance-attributes
91
+ # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
92
92
  def __init__(
93
93
  self,
94
94
  strategy: Strategy,
@@ -307,7 +307,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
307
307
  >>> )
308
308
  """
309
309
 
310
- # pylint: disable=too-many-arguments,too-many-instance-attributes
310
+ # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
311
311
  def __init__(
312
312
  self,
313
313
  strategy: Strategy,
flwr/server/strategy/dpfedavg_adaptive.py CHANGED
@@ -39,7 +39,7 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
39
39
  This class is deprecated and will be removed in a future release.
40
40
  """
41
41
 
42
- # pylint: disable=too-many-arguments,too-many-instance-attributes
42
+ # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
43
43
  def __init__(
44
44
  self,
45
45
  strategy: Strategy,
flwr/server/strategy/dpfedavg_fixed.py CHANGED
@@ -36,7 +36,7 @@ class DPFedAvgFixed(Strategy):
36
36
  This class is deprecated and will be removed in a future release.
37
37
  """
38
38
 
39
- # pylint: disable=too-many-arguments,too-many-instance-attributes
39
+ # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
40
40
  def __init__(
41
41
  self,
42
42
  strategy: Strategy,
flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -155,7 +155,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
155
155
  context.add_callback(on_rpc_done)
156
156
 
157
157
  # Read from state
158
- task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids, limit=None)
158
+ task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
159
159
 
160
160
  context.set_code(grpc.StatusCode.OK)
161
161
  return PullTaskResResponse(task_res_list=task_res_list)
flwr/server/superlink/fleet/grpc_bidi/grpc_server.py CHANGED
@@ -60,7 +60,7 @@ def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool:
60
60
  return is_valid
61
61
 
62
62
 
63
- def start_grpc_server( # pylint: disable=too-many-arguments
63
+ def start_grpc_server( # pylint: disable=too-many-arguments,R0917
64
64
  client_manager: ClientManager,
65
65
  server_address: str,
66
66
  max_concurrent_workers: int = 1000,
@@ -156,7 +156,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments
156
156
  return server
157
157
 
158
158
 
159
- def generic_create_grpc_server( # pylint: disable=too-many-arguments
159
+ def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
160
160
  servicer_and_add_fn: Union[
161
161
  tuple[FleetServicer, AddServicerToServerFn],
162
162
  tuple[GrpcAdapterServicer, AddServicerToServerFn],
@@ -174,7 +174,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
174
174
 
175
175
  Parameters
176
176
  ----------
177
- servicer_and_add_fn : Tuple
177
+ servicer_and_add_fn : tuple
178
178
  A tuple holding a servicer implementation and a matching
179
179
  add_Servicer_to_server function.
180
180
  server_address : str
@@ -214,6 +214,8 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
214
214
  * CA certificate.
215
215
  * server certificate.
216
216
  * server private key.
217
+ interceptors : Optional[Sequence[grpc.ServerInterceptor]] (default: None)
218
+ A list of gRPC interceptors.
217
219
 
218
220
  Returns
219
221
  -------
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -172,6 +172,7 @@ def put_taskres_into_state(
172
172
  pass
173
173
 
174
174
 
175
+ # pylint: disable=too-many-positional-arguments
175
176
  def run_api(
176
177
  app_fn: Callable[[], ClientApp],
177
178
  backend_fn: Callable[[], Backend],
@@ -251,7 +252,7 @@ def run_api(
251
252
 
252
253
 
253
254
  # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
254
- # pylint: disable=too-many-statements
255
+ # pylint: disable=too-many-statements,too-many-positional-arguments
255
256
  def start_vce(
256
257
  backend_name: str,
257
258
  backend_config_json_stream: str,
@@ -267,6 +268,8 @@ def start_vce(
267
268
  existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
268
269
  ) -> None:
269
270
  """Start Fleet API with the Simulation Engine."""
271
+ nodes_mapping = {}
272
+
270
273
  if client_app_attr is not None and client_app is not None:
271
274
  raise ValueError(
272
275
  "Both `client_app_attr` and `client_app` are provided, "
@@ -340,17 +343,17 @@ def start_vce(
340
343
  # Load ClientApp if needed
341
344
  def _load() -> ClientApp:
342
345
 
346
+ if client_app:
347
+ return client_app
343
348
  if client_app_attr:
344
- app = get_load_client_app_fn(
349
+ return get_load_client_app_fn(
345
350
  default_app_ref=client_app_attr,
346
351
  app_path=app_dir,
347
352
  flwr_dir=flwr_dir,
348
353
  multi_app=False,
349
- )(run.fab_id, run.fab_version)
354
+ )(run.fab_id, run.fab_version, run.fab_hash)
350
355
 
351
- if client_app:
352
- app = client_app
353
- return app
356
+ raise ValueError("Either `client_app_attr` or `client_app` must be provided")
354
357
 
355
358
  app_fn = _load
356
359
 
flwr/server/superlink/state/in_memory_state.py CHANGED
@@ -116,6 +116,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
116
116
  # Return TaskIns
117
117
  return task_ins_list
118
118
 
119
+ # pylint: disable=R0911
119
120
  def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
120
121
  """Store one TaskRes."""
121
122
  # Validate task
@@ -129,6 +130,17 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
129
130
  task_ins_id = task_res.task.ancestry[0]
130
131
  task_ins = self.task_ins_store.get(UUID(task_ins_id))
131
132
 
133
+ # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
134
+ if (
135
+ task_ins
136
+ and task_res
137
+ and not (
138
+ task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
139
+ )
140
+ and task_ins.task.consumer.node_id != task_res.task.producer.node_id
141
+ ):
142
+ return None
143
+
132
144
  if task_ins is None:
133
145
  log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id)
134
146
  return None
@@ -178,27 +190,33 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
178
190
  # Return the new task_id
179
191
  return task_id
180
192
 
181
- def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
193
+ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
182
194
  """Get all TaskRes that have not been delivered yet."""
183
- if limit is not None and limit < 1:
184
- raise AssertionError("`limit` must be >= 1")
185
-
186
195
  with self.lock:
187
196
  # Find TaskRes that were not delivered yet
188
197
  task_res_list: list[TaskRes] = []
189
198
  replied_task_ids: set[UUID] = set()
190
199
  for _, task_res in self.task_res_store.items():
191
200
  reply_to = UUID(task_res.task.ancestry[0])
201
+
202
+ # Check if corresponding TaskIns exists and is not expired
203
+ task_ins = self.task_ins_store.get(reply_to)
204
+ if task_ins is None:
205
+ log(WARNING, "TaskIns with task_id %s does not exist.", reply_to)
206
+ task_ids.remove(reply_to)
207
+ continue
208
+
209
+ if task_ins.task.created_at + task_ins.task.ttl <= time.time():
210
+ log(WARNING, "TaskIns with task_id %s is expired.", reply_to)
211
+ task_ids.remove(reply_to)
212
+ continue
213
+
192
214
  if reply_to in task_ids and task_res.task.delivered_at == "":
193
215
  task_res_list.append(task_res)
194
216
  replied_task_ids.add(reply_to)
195
- if limit and len(task_res_list) == limit:
196
- break
197
217
 
198
218
  # Check if the node is offline
199
219
  for task_id in task_ids - replied_task_ids:
200
- if limit and len(task_res_list) == limit:
201
- break
202
220
  task_ins = self.task_ins_store.get(task_id)
203
221
  if task_ins is None:
204
222
  continue
flwr/server/superlink/state/sqlite_state.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """SQLite based implemenation of server state."""
16
16
 
17
+ # pylint: disable=too-many-lines
17
18
 
18
19
  import json
19
20
  import re
@@ -150,6 +151,11 @@ class SqliteState(State): # pylint: disable=R0904
150
151
  ----------
151
152
  log_queries : bool
152
153
  Log each query which is executed.
154
+
155
+ Returns
156
+ -------
157
+ list[tuple[str]]
158
+ The list of all tables in the DB.
153
159
  """
154
160
  self.conn = sqlite3.connect(self.database_path)
155
161
  self.conn.execute("PRAGMA foreign_keys = ON;")
@@ -389,6 +395,16 @@ class SqliteState(State): # pylint: disable=R0904
389
395
  )
390
396
  return None
391
397
 
398
+ # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
399
+ if (
400
+ task_ins
401
+ and task_res
402
+ and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
403
+ and convert_sint64_to_uint64(task_ins["consumer_node_id"])
404
+ != task_res.task.producer.node_id
405
+ ):
406
+ return None
407
+
392
408
  # Fail if the TaskRes TTL exceeds the
393
409
  # expiration time of the TaskIns it replies to.
394
410
  # Condition: TaskIns.created_at + TaskIns.ttl ≥
@@ -432,8 +448,8 @@ class SqliteState(State): # pylint: disable=R0904
432
448
 
433
449
  return task_id
434
450
 
435
- # pylint: disable-next=R0914
436
- def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
451
+ # pylint: disable-next=R0912,R0915,R0914
452
+ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
437
453
  """Get TaskRes for task_ids.
438
454
 
439
455
  Usually, the Driver API calls this method to get results for instructions it has
@@ -448,8 +464,34 @@ class SqliteState(State): # pylint: disable=R0904
448
464
  will only take effect if enough task_ids are in the set AND are currently
449
465
  available. If `limit` is set, it has to be greater than zero.
450
466
  """
451
- if limit is not None and limit < 1:
452
- raise AssertionError("`limit` must be >= 1")
467
+ # Check if corresponding TaskIns exists and is not expired
468
+ task_ids_placeholders = ",".join([f":id_{i}" for i in range(len(task_ids))])
469
+ query = f"""
470
+ SELECT *
471
+ FROM task_ins
472
+ WHERE task_id IN ({task_ids_placeholders})
473
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
474
+ """
475
+ query += ";"
476
+
477
+ task_ins_data = {}
478
+ for index, task_id in enumerate(task_ids):
479
+ task_ins_data[f"id_{index}"] = str(task_id)
480
+
481
+ task_ins_rows = self.query(query, task_ins_data)
482
+
483
+ if not task_ins_rows:
484
+ return []
485
+
486
+ for row in task_ins_rows:
487
+ # Convert values from sint64 to uint64
488
+ convert_sint64_values_in_dict_to_uint64(
489
+ row, ["run_id", "producer_node_id", "consumer_node_id"]
490
+ )
491
+ task_ins = dict_to_task_ins(row)
492
+ if task_ins.task.created_at + task_ins.task.ttl <= time.time():
493
+ log(WARNING, "TaskIns with task_id %s is expired.", task_ins.task_id)
494
+ task_ids.remove(UUID(task_ins.task_id))
453
495
 
454
496
  # Retrieve all anonymous Tasks
455
497
  if len(task_ids) == 0:
@@ -465,10 +507,6 @@ class SqliteState(State): # pylint: disable=R0904
465
507
 
466
508
  data: dict[str, Union[str, float, int]] = {}
467
509
 
468
- if limit is not None:
469
- query += " LIMIT :limit"
470
- data["limit"] = limit
471
-
472
510
  query += ";"
473
511
 
474
512
  for index, task_id in enumerate(task_ids):
@@ -543,9 +581,6 @@ class SqliteState(State): # pylint: disable=R0904
543
581
 
544
582
  # Make TaskRes containing node unavailabe error
545
583
  for row in task_ins_rows:
546
- if limit and len(result) == limit:
547
- break
548
-
549
584
  for row in rows:
550
585
  # Convert values from sint64 to uint64
551
586
  convert_sint64_values_in_dict_to_uint64(
flwr/server/superlink/state/state.py CHANGED
@@ -98,7 +98,7 @@ class State(abc.ABC): # pylint: disable=R0904
98
98
  """
99
99
 
100
100
  @abc.abstractmethod
101
- def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
101
+ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
102
102
  """Get TaskRes for task_ids.
103
103
 
104
104
  Usually, the Driver API calls this method to get results for instructions it has
@@ -106,12 +106,6 @@ class State(abc.ABC): # pylint: disable=R0904
106
106
 
107
107
  Retrieves all TaskRes for the given `task_ids` and returns and empty list of
108
108
  none could be found.
109
-
110
- Constraints
111
- -----------
112
- If `limit` is not `None`, return, at most, `limit` number of TaskRes. The limit
113
- will only take effect if enough task_ids are in the set AND are currently
114
- available. If `limit` is set, it has to be greater zero.
115
109
  """
116
110
 
117
111
  @abc.abstractmethod
flwr/server/superlink/state/utils.py CHANGED
@@ -100,11 +100,6 @@ def convert_uint64_values_in_dict_to_sint64(
100
100
  A dictionary where the values are integers to be converted.
101
101
  keys : list[str]
102
102
  A list of keys in the dictionary whose values need to be converted.
103
-
104
- Returns
105
- -------
106
- None
107
- This function does not return a value. It modifies `data_dict` in place.
108
103
  """
109
104
  for key in keys:
110
105
  if key in data_dict:
@@ -122,11 +117,6 @@ def convert_sint64_values_in_dict_to_uint64(
122
117
  A dictionary where the values are integers to be converted.
123
118
  keys : list[str]
124
119
  A list of keys in the dictionary whose values need to be converted.
125
-
126
- Returns
127
- -------
128
- None
129
- This function does not return a value. It modifies `data_dict` in place.
130
120
  """
131
121
  for key in keys:
132
122
  if key in data_dict:
flwr/simulation/ray_transport/ray_client_proxy.py CHANGED
@@ -48,7 +48,7 @@ from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
48
48
  class RayActorClientProxy(ClientProxy):
49
49
  """Flower client proxy which delegates work using Ray."""
50
50
 
51
- def __init__( # pylint: disable=too-many-arguments
51
+ def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
52
52
  self,
53
53
  client_fn: ClientFnExt,
54
54
  node_id: int,