flwr 1.13.1__py3-none-any.whl → 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -0
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +1 -0
  7. flwr/cli/log.py +18 -36
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +81 -0
  10. flwr/cli/ls.py +205 -106
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +2 -1
  13. flwr/cli/new/templates/app/.gitignore.tpl +3 -0
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -3
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  23. flwr/cli/run/__init__.py +1 -0
  24. flwr/cli/run/run.py +89 -39
  25. flwr/cli/stop.py +130 -0
  26. flwr/cli/utils.py +172 -8
  27. flwr/client/app.py +14 -3
  28. flwr/client/client.py +1 -32
  29. flwr/client/clientapp/app.py +4 -1
  30. flwr/client/clientapp/utils.py +1 -0
  31. flwr/client/grpc_adapter_client/connection.py +1 -1
  32. flwr/client/grpc_client/connection.py +1 -1
  33. flwr/client/grpc_rere_client/connection.py +13 -7
  34. flwr/client/message_handler/message_handler.py +1 -2
  35. flwr/client/mod/comms_mods.py +1 -0
  36. flwr/client/mod/localdp_mod.py +1 -1
  37. flwr/client/nodestate/__init__.py +1 -0
  38. flwr/client/nodestate/nodestate.py +1 -0
  39. flwr/client/nodestate/nodestate_factory.py +1 -0
  40. flwr/client/numpy_client.py +0 -44
  41. flwr/client/rest_client/connection.py +3 -3
  42. flwr/client/supernode/app.py +2 -2
  43. flwr/common/address.py +1 -0
  44. flwr/common/args.py +1 -0
  45. flwr/common/auth_plugin/__init__.py +24 -0
  46. flwr/common/auth_plugin/auth_plugin.py +111 -0
  47. flwr/common/config.py +3 -1
  48. flwr/common/constant.py +17 -1
  49. flwr/common/logger.py +40 -0
  50. flwr/common/message.py +1 -0
  51. flwr/common/object_ref.py +57 -54
  52. flwr/common/pyproject.py +1 -0
  53. flwr/common/record/__init__.py +1 -0
  54. flwr/common/record/parametersrecord.py +1 -0
  55. flwr/common/retry_invoker.py +77 -0
  56. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  57. flwr/common/telemetry.py +15 -4
  58. flwr/common/typing.py +12 -0
  59. flwr/common/version.py +1 -0
  60. flwr/proto/exec_pb2.py +38 -14
  61. flwr/proto/exec_pb2.pyi +107 -2
  62. flwr/proto/exec_pb2_grpc.py +102 -0
  63. flwr/proto/exec_pb2_grpc.pyi +39 -0
  64. flwr/proto/fab_pb2.py +4 -4
  65. flwr/proto/fab_pb2.pyi +4 -1
  66. flwr/proto/serverappio_pb2.py +18 -18
  67. flwr/proto/serverappio_pb2.pyi +8 -2
  68. flwr/proto/serverappio_pb2_grpc.py +34 -0
  69. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  70. flwr/proto/simulationio_pb2.py +2 -2
  71. flwr/proto/simulationio_pb2_grpc.py +34 -0
  72. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  73. flwr/server/app.py +54 -2
  74. flwr/server/compat/app_utils.py +7 -1
  75. flwr/server/driver/grpc_driver.py +11 -63
  76. flwr/server/driver/inmemory_driver.py +5 -1
  77. flwr/server/run_serverapp.py +8 -9
  78. flwr/server/serverapp/app.py +25 -3
  79. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  80. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  81. flwr/server/superlink/driver/serverappio_servicer.py +82 -23
  82. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  83. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  84. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  85. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  86. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +12 -11
  87. flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
  88. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  89. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  90. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  91. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  92. flwr/server/superlink/linkstate/in_memory_linkstate.py +21 -30
  93. flwr/server/superlink/linkstate/linkstate.py +17 -2
  94. flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -49
  95. flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
  96. flwr/server/superlink/utils.py +65 -0
  97. flwr/simulation/app.py +16 -4
  98. flwr/simulation/ray_transport/ray_actor.py +1 -0
  99. flwr/simulation/ray_transport/utils.py +1 -0
  100. flwr/simulation/run_simulation.py +36 -22
  101. flwr/simulation/simulationio_connection.py +3 -0
  102. flwr/superexec/app.py +1 -0
  103. flwr/superexec/deployment.py +1 -0
  104. flwr/superexec/exec_grpc.py +19 -1
  105. flwr/superexec/exec_servicer.py +76 -2
  106. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  107. flwr/superexec/executor.py +1 -0
  108. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/METADATA +8 -7
  109. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/RECORD +112 -112
  110. flwr/proto/common_pb2.py +0 -36
  111. flwr/proto/common_pb2.pyi +0 -121
  112. flwr/proto/common_pb2_grpc.py +0 -4
  113. flwr/proto/common_pb2_grpc.pyi +0 -4
  114. flwr/proto/control_pb2.py +0 -27
  115. flwr/proto/control_pb2.pyi +0 -7
  116. flwr/proto/control_pb2_grpc.py +0 -135
  117. flwr/proto/control_pb2_grpc.pyi +0 -53
  118. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/LICENSE +0 -0
  119. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/WHEEL +0 -0
  120. {flwr-1.13.1.dist-info → flwr-1.14.0.dist-info}/entry_points.txt +0 -0
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower ServerApp process."""
16
16
 
17
+
17
18
  import argparse
18
19
  import sys
19
20
  from logging import DEBUG, ERROR, INFO
@@ -24,6 +25,7 @@ from typing import Optional
24
25
 
25
26
  from flwr.cli.config_utils import get_fab_metadata
26
27
  from flwr.cli.install import install_from_fab
28
+ from flwr.cli.utils import get_sha256_hash
27
29
  from flwr.common.args import add_args_flwr_app_common
28
30
  from flwr.common.config import (
29
31
  get_flwr_dir,
@@ -50,7 +52,8 @@ from flwr.common.serde import (
50
52
  run_from_proto,
51
53
  run_status_to_proto,
52
54
  )
53
- from flwr.common.typing import RunStatus
55
+ from flwr.common.telemetry import EventType, event
56
+ from flwr.common.typing import RunNotRunningException, RunStatus
54
57
  from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
55
58
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
56
59
  PullServerAppInputsRequest,
@@ -96,7 +99,7 @@ def flwr_serverapp() -> None:
96
99
  restore_output()
97
100
 
98
101
 
99
- def run_serverapp( # pylint: disable=R0914, disable=W0212
102
+ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
100
103
  serverappio_api_address: str,
101
104
  log_queue: Queue[Optional[str]],
102
105
  run_once: bool,
@@ -112,7 +115,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
112
115
  # Resolve directory where FABs are installed
113
116
  flwr_dir_ = get_flwr_dir(flwr_dir)
114
117
  log_uploader = None
115
-
118
+ success = True
119
+ hash_run_id = None
116
120
  while True:
117
121
 
118
122
  try:
@@ -128,6 +132,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
128
132
  run = run_from_proto(res.run)
129
133
  fab = fab_from_proto(res.fab)
130
134
 
135
+ hash_run_id = get_sha256_hash(run.run_id)
136
+
131
137
  driver.set_run(run.run_id)
132
138
 
133
139
  # Start log uploader for this run
@@ -170,6 +176,11 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
170
176
  UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
171
177
  )
172
178
 
179
+ event(
180
+ EventType.FLWR_SERVERAPP_RUN_ENTER,
181
+ event_details={"run-id-hash": hash_run_id},
182
+ )
183
+
173
184
  # Load and run the ServerApp with the Driver
174
185
  updated_context = run_(
175
186
  driver=driver,
@@ -186,11 +197,18 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
186
197
  _ = driver._stub.PushServerAppOutputs(out_req)
187
198
 
188
199
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
200
+ except RunNotRunningException:
201
+ log(INFO, "")
202
+ log(INFO, "Run ID %s stopped.", run.run_id)
203
+ log(INFO, "")
204
+ run_status = None
205
+ success = False
189
206
 
190
207
  except Exception as ex: # pylint: disable=broad-exception-caught
191
208
  exc_entity = "ServerApp"
192
209
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
193
210
  run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
211
+ success = False
194
212
 
195
213
  finally:
196
214
  # Stop log uploader for this run and upload final logs
@@ -206,6 +224,10 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
206
224
  run_id=run.run_id, run_status=run_status_proto
207
225
  )
208
226
  )
227
+ event(
228
+ EventType.FLWR_SERVERAPP_RUN_LEAVE,
229
+ event_details={"run-id-hash": hash_run_id, "success": success},
230
+ )
209
231
 
210
232
  # Stop the loop if `flwr-serverapp` is expected to process a single run
211
233
  if run_once:
@@ -17,6 +17,7 @@
17
17
  Paper: arxiv.org/pdf/1710.06963.pdf
18
18
  """
19
19
 
20
+
20
21
  from typing import Optional, Union
21
22
 
22
23
  from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """ServerAppIo gRPC API."""
16
16
 
17
+
17
18
  from logging import INFO
18
19
  from typing import Optional
19
20
 
@@ -32,6 +32,7 @@ from flwr.common.serde import (
32
32
  fab_from_proto,
33
33
  fab_to_proto,
34
34
  run_status_from_proto,
35
+ run_status_to_proto,
35
36
  run_to_proto,
36
37
  user_config_from_proto,
37
38
  )
@@ -48,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
48
49
  CreateRunResponse,
49
50
  GetRunRequest,
50
51
  GetRunResponse,
52
+ GetRunStatusRequest,
53
+ GetRunStatusResponse,
51
54
  UpdateRunStatusRequest,
52
55
  UpdateRunStatusResponse,
53
56
  )
@@ -67,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
67
70
  from flwr.server.superlink.ffs.ffs import Ffs
68
71
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
69
72
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
73
+ from flwr.server.superlink.utils import abort_if
70
74
  from flwr.server.utils.validator import validate_task_ins_or_res
71
75
 
72
76
 
@@ -85,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
85
89
  ) -> GetNodesResponse:
86
90
  """Get available nodes."""
87
91
  log(DEBUG, "ServerAppIoServicer.GetNodes")
92
+
93
+ # Init state
88
94
  state: LinkState = self.state_factory.state()
95
+
96
+ # Abort if the run is not running
97
+ abort_if(
98
+ request.run_id,
99
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
100
+ state,
101
+ context,
102
+ )
103
+
89
104
  all_ids: set[int] = state.get_nodes(request.run_id)
90
105
  nodes: list[Node] = [
91
106
  Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -123,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
123
138
  """Push a set of TaskIns."""
124
139
  log(DEBUG, "ServerAppIoServicer.PushTaskIns")
125
140
 
141
+ # Init state
142
+ state: LinkState = self.state_factory.state()
143
+
144
+ # Abort if the run is not running
145
+ abort_if(
146
+ request.run_id,
147
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
148
+ state,
149
+ context,
150
+ )
151
+
126
152
  # Set pushed_at (timestamp in seconds)
127
153
  pushed_at = time.time()
128
154
  for task_ins in request.task_ins_list:
@@ -133,9 +159,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
133
159
  for task_ins in request.task_ins_list:
134
160
  validation_errors = validate_task_ins_or_res(task_ins)
135
161
  _raise_if(bool(validation_errors), ", ".join(validation_errors))
136
-
137
- # Init state
138
- state: LinkState = self.state_factory.state()
162
+ _raise_if(
163
+ request.run_id != task_ins.run_id, "`task_ins` has mismatched `run_id`"
164
+ )
139
165
 
140
166
  # Store each TaskIns
141
167
  task_ids: list[Optional[UUID]] = []
@@ -153,33 +179,35 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
153
179
  """Pull a set of TaskRes."""
154
180
  log(DEBUG, "ServerAppIoServicer.PullTaskRes")
155
181
 
156
- # Convert each task_id str to UUID
157
- task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
158
-
159
182
  # Init state
160
183
  state: LinkState = self.state_factory.state()
161
184
 
162
- # Register callback
163
- def on_rpc_done() -> None:
164
- log(
165
- DEBUG,
166
- "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
167
- )
168
-
169
- if context.is_active():
170
- return
171
- if context.code() != grpc.StatusCode.OK:
172
- return
173
-
174
- # Delete delivered TaskIns and TaskRes
175
- state.delete_tasks(task_ids=task_ids)
185
+ # Abort if the run is not running
186
+ abort_if(
187
+ request.run_id,
188
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
189
+ state,
190
+ context,
191
+ )
176
192
 
177
- context.add_callback(on_rpc_done)
193
+ # Convert each task_id str to UUID
194
+ task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
178
195
 
179
196
  # Read from state
180
197
  task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
181
198
 
182
- context.set_code(grpc.StatusCode.OK)
199
+ # Validate request
200
+ for task_res in task_res_list:
201
+ _raise_if(
202
+ request.run_id != task_res.run_id, "`task_res` has mismatched `run_id`"
203
+ )
204
+
205
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
206
+ task_ins_ids_to_delete = {
207
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
208
+ }
209
+ state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
210
+
183
211
  return PullTaskResResponse(task_res_list=task_res_list)
184
212
 
185
213
  def GetRun(
@@ -255,7 +283,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
255
283
  ) -> PushServerAppOutputsResponse:
256
284
  """Push ServerApp process outputs."""
257
285
  log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
286
+
287
+ # Init state
258
288
  state = self.state_factory.state()
289
+
290
+ # Abort if the run is not running
291
+ abort_if(
292
+ request.run_id,
293
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
294
+ state,
295
+ context,
296
+ )
297
+
259
298
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
260
299
  return PushServerAppOutputsResponse()
261
300
 
@@ -263,9 +302,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
263
302
  self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
264
303
  ) -> UpdateRunStatusResponse:
265
304
  """Update the status of a run."""
266
- log(DEBUG, "ControlServicer.UpdateRunStatus")
305
+ log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
306
+
307
+ # Init state
267
308
  state = self.state_factory.state()
268
309
 
310
+ # Abort if the run is finished
311
+ abort_if(request.run_id, [Status.FINISHED], state, context)
312
+
269
313
  # Update the run status
270
314
  state.update_run_status(
271
315
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
@@ -284,6 +328,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
284
328
  state.add_serverapp_log(request.run_id, merged_logs)
285
329
  return PushLogsResponse()
286
330
 
331
+ def GetRunStatus(
332
+ self, request: GetRunStatusRequest, context: grpc.ServicerContext
333
+ ) -> GetRunStatusResponse:
334
+ """Get the status of a run."""
335
+ log(DEBUG, "ServerAppIoServicer.GetRunStatus")
336
+ state = self.state_factory.state()
337
+
338
+ # Get run status from LinkState
339
+ run_statuses = state.get_run_status(set(request.run_ids))
340
+ run_status_dict = {
341
+ run_id: run_status_to_proto(run_status)
342
+ for run_id, run_status in run_statuses.items()
343
+ }
344
+ return GetRunStatusResponse(run_status_dict=run_status_dict)
345
+
287
346
 
288
347
  def _raise_if(validation_error: bool, detail: str) -> None:
289
348
  if validation_error:
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Disk based Flower File Storage."""
16
16
 
17
+
17
18
  import hashlib
18
19
  import json
19
20
  from pathlib import Path
@@ -158,4 +158,5 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
158
158
  return message_handler.get_fab(
159
159
  request=request,
160
160
  ffs=self.ffs_factory.ffs(),
161
+ state=self.state_factory.state(),
161
162
  )
@@ -18,6 +18,7 @@ Relevant knowledge for reading this modules code:
18
18
  - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
19
19
  """
20
20
 
21
+
21
22
  import uuid
22
23
  from collections.abc import Iterator
23
24
  from typing import Callable
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
20
20
  import grpc
21
21
 
22
22
  from flwr.common.logger import log
23
+ from flwr.common.typing import InvalidRunStatusException
23
24
  from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
24
25
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
25
26
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -38,6 +39,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
38
39
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
39
40
  from flwr.server.superlink.fleet.message_handler import message_handler
40
41
  from flwr.server.superlink.linkstate import LinkStateFactory
42
+ from flwr.server.superlink.utils import abort_grpc_context
41
43
 
42
44
 
43
45
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
@@ -105,27 +107,45 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
105
107
  )
106
108
  else:
107
109
  log(INFO, "[Fleet.PushTaskRes] No task results to push")
108
- return message_handler.push_task_res(
109
- request=request,
110
- state=self.state_factory.state(),
111
- )
110
+
111
+ try:
112
+ res = message_handler.push_task_res(
113
+ request=request,
114
+ state=self.state_factory.state(),
115
+ )
116
+ except InvalidRunStatusException as e:
117
+ abort_grpc_context(e.message, context)
118
+
119
+ return res
112
120
 
113
121
  def GetRun(
114
122
  self, request: GetRunRequest, context: grpc.ServicerContext
115
123
  ) -> GetRunResponse:
116
124
  """Get run information."""
117
125
  log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
118
- return message_handler.get_run(
119
- request=request,
120
- state=self.state_factory.state(),
121
- )
126
+
127
+ try:
128
+ res = message_handler.get_run(
129
+ request=request,
130
+ state=self.state_factory.state(),
131
+ )
132
+ except InvalidRunStatusException as e:
133
+ abort_grpc_context(e.message, context)
134
+
135
+ return res
122
136
 
123
137
  def GetFab(
124
138
  self, request: GetFabRequest, context: grpc.ServicerContext
125
139
  ) -> GetFabResponse:
126
140
  """Get FAB."""
127
141
  log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
128
- return message_handler.get_fab(
129
- request=request,
130
- ffs=self.ffs_factory.ffs(),
131
- )
142
+ try:
143
+ res = message_handler.get_fab(
144
+ request=request,
145
+ ffs=self.ffs_factory.ffs(),
146
+ state=self.state_factory.state(),
147
+ )
148
+ except InvalidRunStatusException as e:
149
+ abort_grpc_context(e.message, context)
150
+
151
+ return res
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
45
  )
46
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
47
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
- from flwr.server.superlink.linkstate import LinkState
48
+ from flwr.server.superlink.linkstate import LinkStateFactory
49
49
 
50
50
  _PUBLIC_KEY_HEADER = "public-key"
51
51
  _AUTH_TOKEN_HEADER = "auth-token"
@@ -84,15 +84,16 @@ def _get_value_from_tuples(
84
84
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
85
  """Server interceptor for node authentication."""
86
86
 
87
- def __init__(self, state: LinkState):
88
- self.state = state
87
+ def __init__(self, state_factory: LinkStateFactory):
88
+ self.state_factory = state_factory
89
+ state = self.state_factory.state()
89
90
 
90
91
  self.node_public_keys = state.get_node_public_keys()
91
92
  if len(self.node_public_keys) == 0:
92
93
  log(WARNING, "Authentication enabled, but no known public keys configured")
93
94
 
94
- private_key = self.state.get_server_private_key()
95
- public_key = self.state.get_server_public_key()
95
+ private_key = state.get_server_private_key()
96
+ public_key = state.get_server_public_key()
96
97
 
97
98
  if private_key is None or public_key is None:
98
99
  raise ValueError("Error loading authentication keys")
@@ -154,7 +155,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
154
155
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
155
156
 
156
157
  # Verify node_id
157
- node_id = self.state.get_node_id(node_public_key_bytes)
158
+ node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
158
159
 
159
160
  if not self._verify_node_id(node_id, request):
160
161
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
@@ -186,7 +187,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
186
187
  return False
187
188
  return request.task_res_list[0].task.producer.node_id == node_id
188
189
  if isinstance(request, GetRunRequest):
189
- return node_id in self.state.get_nodes(request.run_id)
190
+ return node_id in self.state_factory.state().get_nodes(request.run_id)
190
191
  return request.node.node_id == node_id
191
192
 
192
193
  def _verify_hmac(
@@ -210,17 +211,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
210
211
  ),
211
212
  )
212
213
  )
213
-
214
- node_id = self.state.get_node_id(public_key_bytes)
214
+ state = self.state_factory.state()
215
+ node_id = state.get_node_id(public_key_bytes)
215
216
 
216
217
  # Handle `CreateNode` here instead of calling the default method handler
217
218
  # Return previously assigned `node_id` for the provided `public_key`
218
219
  if node_id is not None:
219
- self.state.acknowledge_ping(node_id, request.ping_interval)
220
+ state.acknowledge_ping(node_id, request.ping_interval)
220
221
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
221
222
 
222
223
  # No `node_id` exists for the provided `public_key`
223
224
  # Handle `CreateNode` here instead of calling the default method handler
224
225
  # Note: the innermost `CreateNode` method will never be called
225
- node_id = self.state.create_node(request.ping_interval, public_key_bytes)
226
+ node_id = state.create_node(request.ping_interval, public_key_bytes)
226
227
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
@@ -19,8 +19,9 @@ import time
19
19
  from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common.constant import Status
22
23
  from flwr.common.serde import fab_to_proto, user_config_to_proto
23
- from flwr.common.typing import Fab
24
+ from flwr.common.typing import Fab, InvalidRunStatusException
24
25
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
25
26
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
26
27
  CreateNodeRequest,
@@ -44,6 +45,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
44
45
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
46
  from flwr.server.superlink.ffs.ffs import Ffs
46
47
  from flwr.server.superlink.linkstate import LinkState
48
+ from flwr.server.superlink.utils import check_abort
47
49
 
48
50
 
49
51
  def create_node(
@@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
98
100
  task_res: TaskRes = request.task_res_list[0]
99
101
  # pylint: enable=no-member
100
102
 
103
+ # Abort if the run is not running
104
+ abort_msg = check_abort(
105
+ task_res.run_id,
106
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
107
+ state,
108
+ )
109
+ if abort_msg:
110
+ raise InvalidRunStatusException(abort_msg)
111
+
101
112
  # Set pushed_at (timestamp in seconds)
102
113
  task_res.task.pushed_at = time.time()
103
114
 
@@ -112,15 +123,22 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
112
123
  return response
113
124
 
114
125
 
115
- def get_run(
116
- request: GetRunRequest, state: LinkState # pylint: disable=W0613
117
- ) -> GetRunResponse:
126
+ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
118
127
  """Get run information."""
119
128
  run = state.get_run(request.run_id)
120
129
 
121
130
  if run is None:
122
131
  return GetRunResponse()
123
132
 
133
+ # Abort if the run is not running
134
+ abort_msg = check_abort(
135
+ request.run_id,
136
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
137
+ state,
138
+ )
139
+ if abort_msg:
140
+ raise InvalidRunStatusException(abort_msg)
141
+
124
142
  return GetRunResponse(
125
143
  run=Run(
126
144
  run_id=run.run_id,
@@ -133,9 +151,18 @@ def get_run(
133
151
 
134
152
 
135
153
  def get_fab(
136
- request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
154
+ request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
137
155
  ) -> GetFabResponse:
138
156
  """Get FAB."""
157
+ # Abort if the run is not running
158
+ abort_msg = check_abort(
159
+ request.run_id,
160
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
161
+ state,
162
+ )
163
+ if abort_msg:
164
+ raise InvalidRunStatusException(abort_msg)
165
+
139
166
  if result := ffs.get(request.hash_str):
140
167
  fab = Fab(request.hash_str, result[0])
141
168
  return GetFabResponse(fab=fab_to_proto(fab))
@@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
154
154
  # Get ffs from app
155
155
  ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
156
156
 
157
+ # Get state from app
158
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
159
+
157
160
  # Handle message
158
- return message_handler.get_fab(request=request, ffs=ffs)
161
+ return message_handler.get_fab(request=request, ffs=ffs, state=state)
159
162
 
160
163
 
161
164
  routes = [
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine side."""
16
16
 
17
+
17
18
  from .vce_api import start_vce
18
19
 
19
20
  __all__ = [
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Simulation Engine Backends."""
16
16
 
17
+
17
18
  import importlib
18
19
 
19
20
  from .backend import Backend, BackendConfig
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
+
17
18
  import sys
18
19
  from logging import DEBUG, ERROR
19
20
  from typing import Callable, Optional, Union
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
265
265
  for task_res in task_res_found:
266
266
  task_res.task.delivered_at = delivered_at
267
267
 
268
- # Cleanup
269
- self._force_delete_tasks_by_ids(set(ret.keys()))
270
-
271
268
  return list(ret.values())
272
269
 
273
- def delete_tasks(self, task_ids: set[UUID]) -> None:
274
- """Delete all delivered TaskIns/TaskRes pairs."""
275
- task_ins_to_be_deleted: set[UUID] = set()
276
- task_res_to_be_deleted: set[UUID] = set()
277
-
278
- with self.lock:
279
- for task_ins_id in task_ids:
280
- # Find the task_id of the matching task_res
281
- for task_res_id, task_res in self.task_res_store.items():
282
- if UUID(task_res.task.ancestry[0]) != task_ins_id:
283
- continue
284
- if task_res.task.delivered_at == "":
285
- continue
286
-
287
- task_ins_to_be_deleted.add(task_ins_id)
288
- task_res_to_be_deleted.add(task_res_id)
289
-
290
- for task_id in task_ins_to_be_deleted:
291
- del self.task_ins_store[task_id]
292
- del self.task_ins_id_to_task_res_id[task_id]
293
- for task_id in task_res_to_be_deleted:
294
- del self.task_res_store[task_id]
295
-
296
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
297
- """Delete tasks based on a set of TaskIns IDs."""
298
- if not task_ids:
270
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
271
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
272
+ if not task_ins_ids:
299
273
  return
300
274
 
301
275
  with self.lock:
302
- for task_id in task_ids:
276
+ for task_id in task_ins_ids:
303
277
  # Delete TaskIns
304
278
  if task_id in self.task_ins_store:
305
279
  del self.task_ins_store[task_id]
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
308
282
  task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
309
283
  del self.task_res_store[task_res_id]
310
284
 
285
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
286
+ """Get all TaskIns IDs for the given run_id."""
287
+ task_id_list: set[UUID] = set()
288
+ with self.lock:
289
+ for task_id, task_ins in self.task_ins_store.items():
290
+ if task_ins.run_id == run_id:
291
+ task_id_list.add(task_id)
292
+
293
+ return task_id_list
294
+
311
295
  def num_task_ins(self) -> int:
312
296
  """Calculate the number of task_ins in store.
313
297
 
@@ -446,6 +430,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
446
430
  """Retrieve `server_public_key` in urlsafe bytes."""
447
431
  return self.server_public_key
448
432
 
433
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
434
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
435
+ with self.lock:
436
+ self.server_private_key = None
437
+ self.server_public_key = None
438
+ self.node_public_keys.clear()
439
+
449
440
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
450
441
  """Store a set of `node_public_keys` in the link state."""
451
442
  with self.lock: