flwr-1.15.2-py3-none-any.whl → flwr-1.17.0-py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (120)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
flwr/server/server_app.py CHANGED
@@ -15,18 +15,18 @@
15
15
  """Flower ServerApp."""
16
16
 
17
17
 
18
+ import inspect
19
+ from collections.abc import Iterator
20
+ from contextlib import contextmanager
18
21
  from typing import Callable, Optional
19
22
 
20
23
  from flwr.common import Context
21
- from flwr.common.logger import (
22
- warn_deprecated_feature_with_example,
23
- warn_preview_feature,
24
- )
24
+ from flwr.common.logger import warn_deprecated_feature_with_example
25
25
  from flwr.server.strategy import Strategy
26
26
 
27
27
  from .client_manager import ClientManager
28
- from .compat import start_driver
29
- from .driver import Driver
28
+ from .compat import start_grid
29
+ from .grid import Driver, Grid
30
30
  from .server import Server
31
31
  from .server_config import ServerConfig
32
32
  from .typing import ServerAppCallable, ServerFn
@@ -44,13 +44,33 @@ SERVER_FN_USAGE_EXAMPLE = """
44
44
  app = ServerApp(server_fn=server_fn)
45
45
  """
46
46
 
47
+ GRID_USAGE_EXAMPLE = """
48
+ app = ServerApp()
47
49
 
48
- class ServerApp:
50
+ @app.main()
51
+ def main(grid: Grid, context: Context) -> None:
52
+ # Your existing ServerApp code ...
53
+ """
54
+
55
+ DRIVER_DEPRECATION_MSG = """
56
+ The `Driver` class is deprecated, it will be removed in a future release.
57
+ """
58
+ DRIVER_EXAMPLE_MSG = """
59
+ Instead, use `Grid` in the signature of your `ServerApp`. For example:
60
+ """
61
+
62
+
63
+ @contextmanager
64
+ def _empty_lifespan(_: Context) -> Iterator[None]:
65
+ yield
66
+
67
+
68
+ class ServerApp: # pylint: disable=too-many-instance-attributes
49
69
  """Flower ServerApp.
50
70
 
51
71
  Examples
52
72
  --------
53
- Use the `ServerApp` with an existing `Strategy`:
73
+ Use the ``ServerApp`` with an existing ``Strategy``:
54
74
 
55
75
  >>> def server_fn(context: Context):
56
76
  >>> server_config = ServerConfig(num_rounds=3)
@@ -62,12 +82,12 @@ class ServerApp:
62
82
  >>>
63
83
  >>> app = ServerApp(server_fn=server_fn)
64
84
 
65
- Use the `ServerApp` with a custom main function:
85
+ Use the ``ServerApp`` with a custom main function:
66
86
 
67
87
  >>> app = ServerApp()
68
88
  >>>
69
89
  >>> @app.main()
70
- >>> def main(driver: Driver, context: Context) -> None:
90
+ >>> def main(grid: Grid, context: Context) -> None:
71
91
  >>> print("ServerApp running")
72
92
  """
73
93
 
@@ -105,29 +125,31 @@ class ServerApp:
105
125
  self._client_manager = client_manager
106
126
  self._server_fn = server_fn
107
127
  self._main: Optional[ServerAppCallable] = None
128
+ self._lifespan = _empty_lifespan
108
129
 
109
- def __call__(self, driver: Driver, context: Context) -> None:
130
+ def __call__(self, grid: Grid, context: Context) -> None:
110
131
  """Execute `ServerApp`."""
111
- # Compatibility mode
112
- if not self._main:
113
- if self._server_fn:
114
- # Execute server_fn()
115
- components = self._server_fn(context)
116
- self._server = components.server
117
- self._config = components.config
118
- self._strategy = components.strategy
119
- self._client_manager = components.client_manager
120
- start_driver(
121
- server=self._server,
122
- config=self._config,
123
- strategy=self._strategy,
124
- client_manager=self._client_manager,
125
- driver=driver,
126
- )
127
- return
132
+ with self._lifespan(context):
133
+ # Compatibility mode
134
+ if not self._main:
135
+ if self._server_fn:
136
+ # Execute server_fn()
137
+ components = self._server_fn(context)
138
+ self._server = components.server
139
+ self._config = components.config
140
+ self._strategy = components.strategy
141
+ self._client_manager = components.client_manager
142
+ start_grid(
143
+ server=self._server,
144
+ config=self._config,
145
+ strategy=self._strategy,
146
+ client_manager=self._client_manager,
147
+ grid=grid,
148
+ )
149
+ return
128
150
 
129
- # New execution mode
130
- self._main(driver, context)
151
+ # New execution mode
152
+ self._main(grid, context)
131
153
 
132
154
  def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
133
155
  """Return a decorator that registers the main fn with the server app.
@@ -137,7 +159,7 @@ class ServerApp:
137
159
  >>> app = ServerApp()
138
160
  >>>
139
161
  >>> @app.main()
140
- >>> def main(driver: Driver, context: Context) -> None:
162
+ >>> def main(grid: Grid, context: Context) -> None:
141
163
  >>> print("ServerApp running")
142
164
  """
143
165
 
@@ -162,12 +184,20 @@ class ServerApp:
162
184
  >>> app = ServerApp()
163
185
  >>>
164
186
  >>> @app.main()
165
- >>> def main(driver: Driver, context: Context) -> None:
187
+ >>> def main(grid: Grid, context: Context) -> None:
166
188
  >>> print("ServerApp running")
167
189
  """,
168
190
  )
169
191
 
170
- warn_preview_feature("ServerApp-register-main-function")
192
+ sig = inspect.signature(main_fn)
193
+ param = list(sig.parameters.values())[0]
194
+ # Check if parameter name or the annotation should be updated
195
+ if param.name == "driver" or param.annotation is Driver:
196
+ warn_deprecated_feature_with_example(
197
+ deprecation_message=DRIVER_DEPRECATION_MSG,
198
+ example_message=DRIVER_EXAMPLE_MSG,
199
+ code_example=GRID_USAGE_EXAMPLE,
200
+ )
171
201
 
172
202
  # Register provided function with the ServerApp object
173
203
  self._main = main_fn
@@ -177,6 +207,69 @@ class ServerApp:
177
207
 
178
208
  return main_decorator
179
209
 
210
+ def lifespan(
211
+ self,
212
+ ) -> Callable[
213
+ [Callable[[Context], Iterator[None]]], Callable[[Context], Iterator[None]]
214
+ ]:
215
+ """Return a decorator that registers the lifespan fn with the server app.
216
+
217
+ The decorated function should accept a `Context` object and use `yield`
218
+ to define enter and exit behavior.
219
+
220
+ Examples
221
+ --------
222
+ >>> app = ServerApp()
223
+ >>>
224
+ >>> @app.lifespan()
225
+ >>> def lifespan(context: Context) -> None:
226
+ >>> # Perform initialization tasks before the app starts
227
+ >>> print("Initializing ServerApp")
228
+ >>>
229
+ >>> yield # ServerApp is running
230
+ >>>
231
+ >>> # Perform cleanup tasks after the app stops
232
+ >>> print("Cleaning up ServerApp")
233
+ """
234
+
235
+ def lifespan_decorator(
236
+ lifespan_fn: Callable[[Context], Iterator[None]],
237
+ ) -> Callable[[Context], Iterator[None]]:
238
+ """Register the lifespan fn with the ServerApp object."""
239
+
240
+ @contextmanager
241
+ def decorated_lifespan(context: Context) -> Iterator[None]:
242
+ # Execute the code before `yield` in lifespan_fn
243
+ try:
244
+ if not isinstance(it := lifespan_fn(context), Iterator):
245
+ raise StopIteration
246
+ next(it)
247
+ except StopIteration:
248
+ raise RuntimeError(
249
+ "lifespan function should yield at least once."
250
+ ) from None
251
+
252
+ try:
253
+ # Enter the context
254
+ yield
255
+ finally:
256
+ try:
257
+ # Execute the code after `yield` in lifespan_fn
258
+ next(it)
259
+ except StopIteration:
260
+ pass
261
+ else:
262
+ raise RuntimeError("lifespan function should only yield once.")
263
+
264
+ # Register provided function with the ServerApp object
265
+ # Ignore mypy error because of different argument names (`_` vs `context`)
266
+ self._lifespan = decorated_lifespan # type: ignore
267
+
268
+ # Return provided function unmodified
269
+ return lifespan_fn
270
+
271
+ return lifespan_decorator
272
+
180
273
 
181
274
  class LoadServerAppError(Exception):
182
275
  """Error when trying to load `ServerApp`."""
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
60
60
  PullServerAppInputsResponse,
61
61
  PushServerAppOutputsRequest,
62
62
  )
63
- from flwr.server.driver.grpc_driver import GrpcDriver
63
+ from flwr.server.grid.grpc_grid import GrpcGrid
64
64
  from flwr.server.run_serverapp import run as run_
65
65
 
66
66
 
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
106
106
  certificates: Optional[bytes] = None,
107
107
  ) -> None:
108
108
  """Run Flower ServerApp process."""
109
- driver = GrpcDriver(
109
+ grid = GrpcGrid(
110
110
  serverappio_service_address=serverappio_api_address,
111
111
  root_certificates=certificates,
112
112
  )
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
123
123
  # Pull ServerAppInputs from LinkState
124
124
  req = PullServerAppInputsRequest()
125
125
  log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
126
- res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
126
+ res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
127
127
  if not res.HasField("run"):
128
128
  sleep(3)
129
129
  run_status = None
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
135
135
 
136
136
  hash_run_id = get_sha256_hash(run.run_id)
137
137
 
138
- driver.set_run(run.run_id)
138
+ grid.set_run(run.run_id)
139
139
 
140
140
  # Start log uploader for this run
141
141
  log_uploader = start_log_uploader(
142
142
  log_queue=log_queue,
143
143
  node_id=0,
144
144
  run_id=run.run_id,
145
- stub=driver._stub,
145
+ stub=grid._stub,
146
146
  )
147
147
 
148
148
  log(DEBUG, "[flwr-serverapp] Start FAB installation.")
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
173
173
 
174
174
  # Change status to Running
175
175
  run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
176
- driver._stub.UpdateRunStatus(
176
+ grid._stub.UpdateRunStatus(
177
177
  UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
178
178
  )
179
179
 
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
182
182
  event_details={"run-id-hash": hash_run_id},
183
183
  )
184
184
 
185
- # Load and run the ServerApp with the Driver
185
+ # Load and run the ServerApp with the Grid
186
186
  updated_context = run_(
187
- driver=driver,
187
+ grid=grid,
188
188
  server_app_dir=app_path,
189
189
  server_app_attr=server_app_attr,
190
190
  context=context,
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
196
196
  out_req = PushServerAppOutputsRequest(
197
197
  run_id=run.run_id, context=context_proto
198
198
  )
199
- _ = driver._stub.PushServerAppOutputs(out_req)
199
+ _ = grid._stub.PushServerAppOutputs(out_req)
200
200
 
201
201
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
202
202
  except RunNotRunningException:
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
221
221
  # Update run status
222
222
  if run_status:
223
223
  run_status_proto = run_status_to_proto(run_status)
224
- driver._stub.UpdateRunStatus(
224
+ grid._stub.UpdateRunStatus(
225
225
  UpdateRunStatusRequest(
226
226
  run_id=run.run_id, run_status=run_status_proto
227
227
  )
@@ -103,11 +103,11 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
103
103
  if request.messages_list:
104
104
  log(
105
105
  INFO,
106
- "[Fleet.PushMessages] Push results from node_id=%s",
106
+ "[Fleet.PushMessages] Push replies from node_id=%s",
107
107
  request.messages_list[0].metadata.src_node_id,
108
108
  )
109
109
  else:
110
- log(INFO, "[Fleet.PushMessages] No task results to push")
110
+ log(INFO, "[Fleet.PushMessages] No replies to push")
111
111
 
112
112
  try:
113
113
  res = message_handler.push_messages(
@@ -18,13 +18,12 @@
18
18
  from typing import Optional
19
19
  from uuid import UUID
20
20
 
21
+ from flwr.common import Message
21
22
  from flwr.common.constant import Status
22
23
  from flwr.common.serde import (
23
24
  fab_to_proto,
24
25
  message_from_proto,
25
- message_from_taskins,
26
26
  message_to_proto,
27
- message_to_taskres,
28
27
  user_config_to_proto,
29
28
  )
30
29
  from flwr.common.typing import Fab, InvalidRunStatusException
@@ -48,7 +47,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
48
47
  GetRunResponse,
49
48
  Run,
50
49
  )
51
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
52
50
  from flwr.server.superlink.ffs.ffs import Ffs
53
51
  from flwr.server.superlink.linkstate import LinkState
54
52
  from flwr.server.superlink.utils import check_abort
@@ -92,13 +90,12 @@ def pull_messages(
92
90
  node = request.node # pylint: disable=no-member
93
91
  node_id: int = node.node_id
94
92
 
95
- # Retrieve TaskIns from State
96
- task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
93
+ # Retrieve Message from State
94
+ message_list: list[Message] = state.get_message_ins(node_id=node_id, limit=1)
97
95
 
98
96
  # Convert to Messages
99
97
  msg_proto = []
100
- for task_ins in task_ins_list:
101
- msg = message_from_taskins(task_ins)
98
+ for msg in message_list:
102
99
  msg_proto.append(message_to_proto(msg))
103
100
 
104
101
  return PullMessagesResponse(messages_list=msg_proto)
@@ -108,21 +105,20 @@ def push_messages(
108
105
  request: PushMessagesRequest, state: LinkState
109
106
  ) -> PushMessagesResponse:
110
107
  """Push Messages handler."""
111
- # Convert Message to TaskRes
108
+ # Convert Message from proto
112
109
  msg = message_from_proto(message_proto=request.messages_list[0])
113
- task_res = message_to_taskres(msg)
114
110
 
115
111
  # Abort if the run is not running
116
112
  abort_msg = check_abort(
117
- task_res.run_id,
113
+ msg.metadata.run_id,
118
114
  [Status.PENDING, Status.STARTING, Status.FINISHED],
119
115
  state,
120
116
  )
121
117
  if abort_msg:
122
118
  raise InvalidRunStatusException(abort_msg)
123
119
 
124
- # Store TaskRes in State
125
- message_id: Optional[UUID] = state.store_task_res(task_res=task_res)
120
+ # Store Message in State
121
+ message_id: Optional[UUID] = state.store_message_res(message=msg)
126
122
 
127
123
  # Build response
128
124
  response = PushMessagesResponse(
@@ -21,9 +21,9 @@ from typing import Callable
21
21
  from flwr.client.client_app import ClientApp
22
22
  from flwr.common.context import Context
23
23
  from flwr.common.message import Message
24
- from flwr.common.typing import ConfigsRecordValues
24
+ from flwr.common.typing import ConfigRecordValues
25
25
 
26
- BackendConfig = dict[str, dict[str, ConfigsRecordValues]]
26
+ BackendConfig = dict[str, dict[str, ConfigRecordValues]]
27
27
 
28
28
 
29
29
  class Backend(ABC):
@@ -45,7 +45,7 @@ class Backend(ABC):
45
45
  def num_workers(self) -> int:
46
46
  """Return number of workers in the backend.
47
47
 
48
- This is the number of TaskIns that can be processed concurrently.
48
+ This is the number of Messages that can be processed concurrently.
49
49
  """
50
50
  return 0
51
51
 
@@ -26,7 +26,7 @@ from flwr.common.constant import PARTITION_ID_KEY
26
26
  from flwr.common.context import Context
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.message import Message
29
- from flwr.common.typing import ConfigsRecordValues
29
+ from flwr.common.typing import ConfigRecordValues
30
30
  from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
31
31
  from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
32
32
 
@@ -104,7 +104,7 @@ class RayBackend(Backend):
104
104
  if not ray.is_initialized():
105
105
  ray_init_args: dict[
106
106
  str,
107
- ConfigsRecordValues,
107
+ ConfigRecordValues,
108
108
  ] = {}
109
109
 
110
110
  if backend_config.get(self.init_args_key):
@@ -29,6 +29,7 @@ from typing import Callable, Optional
29
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
30
30
  from flwr.client.clientapp.utils import get_load_client_app_fn
31
31
  from flwr.client.run_info_store import DeprecatedRunInfoStore
32
+ from flwr.common import Message
32
33
  from flwr.common.constant import (
33
34
  NUM_PARTITIONS_KEY,
34
35
  PARTITION_ID_KEY,
@@ -37,9 +38,7 @@ from flwr.common.constant import (
37
38
  )
38
39
  from flwr.common.logger import log
39
40
  from flwr.common.message import Error
40
- from flwr.common.serde import message_from_taskins, message_to_taskres
41
41
  from flwr.common.typing import Run
42
- from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
42
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
44
43
 
45
44
  from .backend import Backend, error_messages_backends, supported_backends
@@ -87,33 +86,33 @@ def _register_node_info_stores(
87
86
 
88
87
  # pylint: disable=too-many-arguments,too-many-locals
89
88
  def worker(
90
- taskins_queue: "Queue[TaskIns]",
91
- taskres_queue: "Queue[TaskRes]",
89
+ messageins_queue: Queue[Message],
90
+ messageres_queue: Queue[Message],
92
91
  node_info_store: dict[int, DeprecatedRunInfoStore],
93
92
  backend: Backend,
94
93
  f_stop: threading.Event,
95
94
  ) -> None:
96
- """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
95
+ """Process messages from the queue, execute them, update context, and enqueue
96
+ replies."""
97
97
  while not f_stop.is_set():
98
98
  out_mssg = None
99
99
  try:
100
100
  # Fetch from queue with timeout. We use a timeout so
101
101
  # the stopping event can be evaluated even when the queue is empty.
102
- task_ins: TaskIns = taskins_queue.get(timeout=1.0)
103
- node_id = task_ins.task.consumer.node_id
102
+ message: Message = messageins_queue.get(timeout=1.0)
103
+ node_id = message.metadata.dst_node_id
104
104
 
105
105
  # Retrieve context
106
- context = node_info_store[node_id].retrieve_context(run_id=task_ins.run_id)
107
-
108
- # Convert TaskIns to Message
109
- message = message_from_taskins(task_ins)
106
+ context = node_info_store[node_id].retrieve_context(
107
+ run_id=message.metadata.run_id
108
+ )
110
109
 
111
110
  # Let backend process message
112
111
  out_mssg, updated_context = backend.process_message(message, context)
113
112
 
114
113
  # Update Context
115
114
  node_info_store[node_id].update_context(
116
- task_ins.run_id, context=updated_context
115
+ message.metadata.run_id, context=updated_context
117
116
  )
118
117
  except Empty:
119
118
  # An exception raised if queue.get times out
@@ -131,41 +130,37 @@ def worker(
131
130
  e_code = ErrorCode.UNKNOWN
132
131
 
133
132
  reason = str(type(ex)) + ":<'" + str(ex) + "'>"
134
- out_mssg = message.create_error_reply(
135
- error=Error(code=e_code, reason=reason)
136
- )
133
+ out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
137
134
 
138
135
  finally:
139
136
  if out_mssg:
140
- # Convert to TaskRes
141
- task_res = message_to_taskres(out_mssg)
142
- # Store TaskRes in state
143
- taskres_queue.put(task_res)
137
+ # Store reply Messages in state
138
+ messageres_queue.put(out_mssg)
144
139
 
145
140
 
146
- def add_taskins_to_queue(
141
+ def add_messages_to_queue(
147
142
  state: LinkState,
148
- queue: "Queue[TaskIns]",
143
+ queue: Queue[Message],
149
144
  nodes_mapping: NodeToPartitionMapping,
150
145
  f_stop: threading.Event,
151
146
  ) -> None:
152
- """Put TaskIns in a queue from State."""
147
+ """Put Messages in the queue from the LinkState."""
153
148
  while not f_stop.is_set():
154
149
  for node_id in nodes_mapping.keys():
155
- task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
156
- for task_ins in task_ins_list:
157
- queue.put(task_ins)
150
+ message_ins_list = state.get_message_ins(node_id=node_id, limit=1)
151
+ for msg in message_ins_list:
152
+ queue.put(msg)
158
153
  sleep(0.1)
159
154
 
160
155
 
161
- def put_taskres_into_state(
162
- state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event
156
+ def put_message_into_state(
157
+ state: LinkState, queue: Queue[Message], f_stop: threading.Event
163
158
  ) -> None:
164
- """Put TaskRes into State from a queue."""
159
+ """Store reply Messages into the LinkState from the queue."""
165
160
  while not f_stop.is_set():
166
161
  try:
167
- taskres = queue.get(timeout=1.0)
168
- state.store_task_res(taskres)
162
+ message_reply = queue.get(timeout=1.0)
163
+ state.store_message_res(message_reply)
169
164
  except Empty:
170
165
  # queue is empty when timeout was triggered
171
166
  pass
@@ -181,8 +176,8 @@ def run_api(
181
176
  f_stop: threading.Event,
182
177
  ) -> None:
183
178
  """Run the VCE."""
184
- taskins_queue: Queue[TaskIns] = Queue()
185
- taskres_queue: Queue[TaskRes] = Queue()
179
+ messageins_queue: Queue[Message] = Queue()
180
+ messageres_queue: Queue[Message] = Queue()
186
181
 
187
182
  try:
188
183
 
@@ -196,10 +191,10 @@ def run_api(
196
191
  state = state_factory.state()
197
192
 
198
193
  extractor_th = threading.Thread(
199
- target=add_taskins_to_queue,
194
+ target=add_messages_to_queue,
200
195
  args=(
201
196
  state,
202
- taskins_queue,
197
+ messageins_queue,
203
198
  nodes_mapping,
204
199
  f_stop,
205
200
  ),
@@ -207,10 +202,10 @@ def run_api(
207
202
  extractor_th.start()
208
203
 
209
204
  injector_th = threading.Thread(
210
- target=put_taskres_into_state,
205
+ target=put_message_into_state,
211
206
  args=(
212
207
  state,
213
- taskres_queue,
208
+ messageres_queue,
214
209
  f_stop,
215
210
  ),
216
211
  )
@@ -220,8 +215,8 @@ def run_api(
220
215
  _ = [
221
216
  executor.submit(
222
217
  worker,
223
- taskins_queue,
224
- taskres_queue,
218
+ messageins_queue,
219
+ messageres_queue,
225
220
  node_info_stores,
226
221
  backend,
227
222
  f_stop,