flwr-1.15.1-py3-none-any.whl → flwr-1.16.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/new.py +1 -1
  4. flwr/cli/new/templates/app/README.baseline.md.tpl +4 -4
  5. flwr/cli/new/templates/app/README.md.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  13. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  14. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  15. flwr/client/client_app.py +147 -36
  16. flwr/client/clientapp/app.py +4 -0
  17. flwr/client/message_handler/message_handler.py +1 -1
  18. flwr/client/rest_client/connection.py +4 -6
  19. flwr/client/supernode/__init__.py +0 -2
  20. flwr/client/supernode/app.py +1 -11
  21. flwr/common/address.py +35 -0
  22. flwr/common/args.py +8 -2
  23. flwr/common/auth_plugin/auth_plugin.py +2 -1
  24. flwr/common/constant.py +16 -0
  25. flwr/common/event_log_plugin/__init__.py +22 -0
  26. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  27. flwr/common/grpc.py +1 -1
  28. flwr/common/message.py +18 -7
  29. flwr/common/object_ref.py +0 -10
  30. flwr/common/record/conversion_utils.py +8 -17
  31. flwr/common/record/parametersrecord.py +151 -16
  32. flwr/common/record/recordset.py +95 -88
  33. flwr/common/secure_aggregation/quantization.py +5 -1
  34. flwr/common/serde.py +8 -126
  35. flwr/common/telemetry.py +0 -10
  36. flwr/common/typing.py +36 -0
  37. flwr/server/app.py +18 -2
  38. flwr/server/compat/app.py +4 -1
  39. flwr/server/compat/app_utils.py +10 -2
  40. flwr/server/compat/driver_client_proxy.py +2 -2
  41. flwr/server/driver/driver.py +1 -1
  42. flwr/server/driver/grpc_driver.py +10 -1
  43. flwr/server/driver/inmemory_driver.py +17 -21
  44. flwr/server/run_serverapp.py +2 -13
  45. flwr/server/server_app.py +93 -20
  46. flwr/server/superlink/driver/serverappio_servicer.py +27 -33
  47. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  48. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -16
  49. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  50. flwr/server/superlink/fleet/vce/vce_api.py +32 -36
  51. flwr/server/superlink/linkstate/in_memory_linkstate.py +140 -126
  52. flwr/server/superlink/linkstate/linkstate.py +47 -60
  53. flwr/server/superlink/linkstate/sqlite_linkstate.py +210 -282
  54. flwr/server/superlink/linkstate/utils.py +91 -119
  55. flwr/server/utils/__init__.py +2 -2
  56. flwr/server/utils/validator.py +53 -71
  57. flwr/server/workflow/default_workflows.py +4 -1
  58. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
  59. flwr/superexec/app.py +0 -14
  60. flwr/superexec/exec_servicer.py +4 -4
  61. flwr/superexec/exec_user_auth_interceptor.py +5 -3
  62. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/METADATA +5 -5
  63. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/RECORD +66 -69
  64. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/entry_points.txt +0 -3
  65. flwr/client/message_handler/task_handler.py +0 -37
  66. flwr/proto/task_pb2.py +0 -33
  67. flwr/proto/task_pb2.pyi +0 -103
  68. flwr/proto/task_pb2_grpc.py +0 -4
  69. flwr/proto/task_pb2_grpc.pyi +0 -4
  70. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/LICENSE +0 -0
  71. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/WHEEL +0 -0
@@ -23,7 +23,6 @@ from uuid import UUID
23
23
 
24
24
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
25
  from flwr.common.constant import SUPERLINK_NODE_ID
26
- from flwr.common.serde import message_from_taskres, message_to_taskins
27
26
  from flwr.common.typing import Run
28
27
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
28
  from flwr.server.superlink.linkstate import LinkStateFactory
@@ -60,6 +59,7 @@ class InMemoryDriver(Driver):
60
59
  and message.metadata.message_id == ""
61
60
  and message.metadata.reply_to_message == ""
62
61
  and message.metadata.ttl > 0
62
+ and message.metadata.delivered_at == ""
63
63
  ):
64
64
  raise ValueError(f"Invalid message: {message}")
65
65
 
@@ -109,9 +109,9 @@ class InMemoryDriver(Driver):
109
109
  )
110
110
  return Message(metadata=metadata, content=content)
111
111
 
112
- def get_node_ids(self) -> list[int]:
112
+ def get_node_ids(self) -> Iterable[int]:
113
113
  """Get node IDs."""
114
- return list(self.state.get_nodes(cast(Run, self._run).run_id))
114
+ return self.state.get_nodes(cast(Run, self._run).run_id)
115
115
 
116
116
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
117
117
  """Push messages to specified node IDs.
@@ -119,19 +119,16 @@ class InMemoryDriver(Driver):
119
119
  This method takes an iterable of messages and sends each message
120
120
  to the node specified in `dst_node_id`.
121
121
  """
122
- task_ids: list[str] = []
122
+ msg_ids: list[str] = []
123
123
  for msg in messages:
124
124
  # Check message
125
125
  self._check_message(msg)
126
- # Convert Message to TaskIns
127
- taskins = message_to_taskins(msg)
128
126
  # Store in state
129
- taskins.task.pushed_at = time.time()
130
- task_id = self.state.store_task_ins(taskins)
131
- if task_id:
132
- task_ids.append(str(task_id))
127
+ msg_id = self.state.store_message_ins(msg)
128
+ if msg_id:
129
+ msg_ids.append(str(msg_id))
133
130
 
134
- return task_ids
131
+ return msg_ids
135
132
 
136
133
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
137
134
  """Pull messages based on message IDs.
@@ -140,17 +137,16 @@ class InMemoryDriver(Driver):
140
137
  set of given message IDs.
141
138
  """
142
139
  msg_ids = {UUID(msg_id) for msg_id in message_ids}
143
- # Pull TaskRes
144
- task_res_list = self.state.get_task_res(task_ids=msg_ids)
145
- # Delete tasks in state
146
- # Delete the TaskIns/TaskRes pairs if TaskRes is found
147
- task_ins_ids_to_delete = {
148
- UUID(task_res.task.ancestry[0]) for task_res in task_res_list
140
+ # Pull Messages
141
+ message_res_list = self.state.get_message_res(message_ids=msg_ids)
142
+ # Get IDs of Messages these replies are for
143
+ message_ins_ids_to_delete = {
144
+ UUID(msg_res.metadata.reply_to_message) for msg_res in message_res_list
149
145
  }
150
- self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
151
- # Convert TaskRes to Message
152
- msgs = [message_from_taskres(taskres) for taskres in task_res_list]
153
- return msgs
146
+ # Delete
147
+ self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
148
+
149
+ return message_res_list
154
150
 
155
151
  def send_and_receive(
156
152
  self,
@@ -15,11 +15,10 @@
15
15
  """Run ServerApp."""
16
16
 
17
17
 
18
- from logging import DEBUG, ERROR
18
+ from logging import DEBUG
19
19
  from typing import Optional
20
20
 
21
- from flwr.common import Context, EventType, event
22
- from flwr.common.exit_handlers import register_exit_handlers
21
+ from flwr.common import Context
23
22
  from flwr.common.logger import log
24
23
  from flwr.common.object_ref import load_app
25
24
 
@@ -64,13 +63,3 @@ def run(
64
63
 
65
64
  log(DEBUG, "ServerApp finished running.")
66
65
  return context
67
-
68
-
69
- def run_server_app() -> None:
70
- """Run Flower server app."""
71
- event(EventType.RUN_SERVER_APP_ENTER)
72
- log(
73
- ERROR,
74
- "The command `flower-server-app` has been replaced by `flwr run`.",
75
- )
76
- register_exit_handlers(event_type=EventType.RUN_SERVER_APP_LEAVE)
flwr/server/server_app.py CHANGED
@@ -15,6 +15,8 @@
15
15
  """Flower ServerApp."""
16
16
 
17
17
 
18
+ from collections.abc import Iterator
19
+ from contextlib import contextmanager
18
20
  from typing import Callable, Optional
19
21
 
20
22
  from flwr.common import Context
@@ -45,7 +47,12 @@ SERVER_FN_USAGE_EXAMPLE = """
45
47
  """
46
48
 
47
49
 
48
- class ServerApp:
50
+ @contextmanager
51
+ def _empty_lifespan(_: Context) -> Iterator[None]:
52
+ yield
53
+
54
+
55
+ class ServerApp: # pylint: disable=too-many-instance-attributes
49
56
  """Flower ServerApp.
50
57
 
51
58
  Examples
@@ -105,29 +112,31 @@ class ServerApp:
105
112
  self._client_manager = client_manager
106
113
  self._server_fn = server_fn
107
114
  self._main: Optional[ServerAppCallable] = None
115
+ self._lifespan = _empty_lifespan
108
116
 
109
117
  def __call__(self, driver: Driver, context: Context) -> None:
110
118
  """Execute `ServerApp`."""
111
- # Compatibility mode
112
- if not self._main:
113
- if self._server_fn:
114
- # Execute server_fn()
115
- components = self._server_fn(context)
116
- self._server = components.server
117
- self._config = components.config
118
- self._strategy = components.strategy
119
- self._client_manager = components.client_manager
120
- start_driver(
121
- server=self._server,
122
- config=self._config,
123
- strategy=self._strategy,
124
- client_manager=self._client_manager,
125
- driver=driver,
126
- )
127
- return
119
+ with self._lifespan(context):
120
+ # Compatibility mode
121
+ if not self._main:
122
+ if self._server_fn:
123
+ # Execute server_fn()
124
+ components = self._server_fn(context)
125
+ self._server = components.server
126
+ self._config = components.config
127
+ self._strategy = components.strategy
128
+ self._client_manager = components.client_manager
129
+ start_driver(
130
+ server=self._server,
131
+ config=self._config,
132
+ strategy=self._strategy,
133
+ client_manager=self._client_manager,
134
+ driver=driver,
135
+ )
136
+ return
128
137
 
129
- # New execution mode
130
- self._main(driver, context)
138
+ # New execution mode
139
+ self._main(driver, context)
131
140
 
132
141
  def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
133
142
  """Return a decorator that registers the main fn with the server app.
@@ -177,6 +186,70 @@ class ServerApp:
177
186
 
178
187
  return main_decorator
179
188
 
189
+ def lifespan(
190
+ self,
191
+ ) -> Callable[
192
+ [Callable[[Context], Iterator[None]]], Callable[[Context], Iterator[None]]
193
+ ]:
194
+ """Return a decorator that registers the lifespan fn with the server app.
195
+
196
+ The decorated function should accept a `Context` object and use `yield`
197
+ to define enter and exit behavior.
198
+
199
+ Examples
200
+ --------
201
+ >>> app = ServerApp()
202
+ >>>
203
+ >>> @app.lifespan()
204
+ >>> def lifespan(context: Context) -> None:
205
+ >>> # Perform initialization tasks before the app starts
206
+ >>> print("Initializing ServerApp")
207
+ >>>
208
+ >>> yield # ServerApp is running
209
+ >>>
210
+ >>> # Perform cleanup tasks after the app stops
211
+ >>> print("Cleaning up ServerApp")
212
+ """
213
+
214
+ def lifespan_decorator(
215
+ lifespan_fn: Callable[[Context], Iterator[None]]
216
+ ) -> Callable[[Context], Iterator[None]]:
217
+ """Register the lifespan fn with the ServerApp object."""
218
+ warn_preview_feature("ServerApp-register-lifespan-function")
219
+
220
+ @contextmanager
221
+ def decorated_lifespan(context: Context) -> Iterator[None]:
222
+ # Execute the code before `yield` in lifespan_fn
223
+ try:
224
+ if not isinstance(it := lifespan_fn(context), Iterator):
225
+ raise StopIteration
226
+ next(it)
227
+ except StopIteration:
228
+ raise RuntimeError(
229
+ "lifespan function should yield at least once."
230
+ ) from None
231
+
232
+ try:
233
+ # Enter the context
234
+ yield
235
+ finally:
236
+ try:
237
+ # Execute the code after `yield` in lifespan_fn
238
+ next(it)
239
+ except StopIteration:
240
+ pass
241
+ else:
242
+ raise RuntimeError("lifespan function should only yield once.")
243
+
244
+ # Register provided function with the ServerApp object
245
+ # Ignore mypy error because of different argument names (`_` vs `context`)
246
+ self._lifespan = decorated_lifespan # type: ignore
247
+
248
+ # Return provided function unmodified
249
+ return lifespan_fn
250
+
251
+ return lifespan_decorator
252
+
180
253
 
181
254
  class LoadServerAppError(Exception):
182
255
  """Error when trying to load `ServerApp`."""
@@ -22,8 +22,8 @@ from uuid import UUID
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import ConfigsRecord, now
26
- from flwr.common.constant import Status
25
+ from flwr.common import ConfigsRecord, Message
26
+ from flwr.common.constant import SUPERLINK_NODE_ID, Status
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.serde import (
29
29
  context_from_proto,
@@ -31,9 +31,7 @@ from flwr.common.serde import (
31
31
  fab_from_proto,
32
32
  fab_to_proto,
33
33
  message_from_proto,
34
- message_from_taskres,
35
34
  message_to_proto,
36
- message_to_taskins,
37
35
  run_status_from_proto,
38
36
  run_status_to_proto,
39
37
  run_to_proto,
@@ -69,12 +67,11 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
69
67
  PushServerAppOutputsRequest,
70
68
  PushServerAppOutputsResponse,
71
69
  )
72
- from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
73
70
  from flwr.server.superlink.ffs.ffs import Ffs
74
71
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
75
72
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
76
73
  from flwr.server.superlink.utils import abort_if
77
- from flwr.server.utils.validator import validate_task_ins_or_res
74
+ from flwr.server.utils.validator import validate_message
78
75
 
79
76
 
80
77
  class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
@@ -151,9 +148,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
151
148
  context,
152
149
  )
153
150
 
154
- # Set pushed_at (timestamp in seconds)
155
- pushed_at = now().timestamp()
156
-
157
151
  # Validate request and insert in State
158
152
  _raise_if(
159
153
  validation_error=len(request.messages_list) == 0,
@@ -164,21 +158,19 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
164
158
  while request.messages_list:
165
159
  message_proto = request.messages_list.pop(0)
166
160
  message = message_from_proto(message_proto=message_proto)
167
- task_ins = message_to_taskins(message=message)
168
- task_ins.task.pushed_at = pushed_at
169
- validation_errors = validate_task_ins_or_res(task_ins)
161
+ validation_errors = validate_message(message, is_reply_message=False)
170
162
  _raise_if(
171
163
  validation_error=bool(validation_errors),
172
164
  request_name="PushMessages",
173
165
  detail=", ".join(validation_errors),
174
166
  )
175
167
  _raise_if(
176
- validation_error=request.run_id != task_ins.run_id,
168
+ validation_error=request.run_id != message.metadata.run_id,
177
169
  request_name="PushMessages",
178
- detail="`task_ins` has mismatched `run_id`",
170
+ detail="`Message.metadata` has mismatched `run_id`",
179
171
  )
180
172
  # Store
181
- message_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
173
+ message_id: Optional[UUID] = state.store_message_ins(message=message)
182
174
  message_ids.append(message_id)
183
175
 
184
176
  return PushInsMessagesResponse(
@@ -204,32 +196,34 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
204
196
  context,
205
197
  )
206
198
 
207
- # Convert each task_id str to UUID
199
+ # Convert each message_id str to UUID
208
200
  message_ids: set[UUID] = {
209
201
  UUID(message_id) for message_id in request.message_ids
210
202
  }
211
203
 
212
204
  # Read from state
213
- task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids)
205
+ messages_res: list[Message] = state.get_message_res(message_ids=message_ids)
214
206
 
215
- # Convert to Messages
216
- messages_list = []
217
- while task_res_list:
218
- task_res = task_res_list.pop(0)
219
- _raise_if(
220
- validation_error=request.run_id != task_res.run_id,
221
- request_name="PullMessages",
222
- detail="`task_res` has mismatched `run_id`",
223
- )
224
- message = message_from_taskres(taskres=task_res)
225
- messages_list.append(message_to_proto(message))
226
-
227
- # Delete the TaskIns/TaskRes pairs if TaskRes is found
228
- task_ins_ids_to_delete = {
229
- UUID(task_res.task.ancestry[0]) for task_res in task_res_list
207
+ # Delete the instruction Messages and their replies if found
208
+ message_ins_ids_to_delete = {
209
+ UUID(msg_res.metadata.reply_to_message) for msg_res in messages_res
230
210
  }
231
211
 
232
- state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
212
+ state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
213
+
214
+ # Convert Messages to proto
215
+ messages_list = []
216
+ while messages_res:
217
+ msg = messages_res.pop(0)
218
+
219
+ # Skip `run_id` check for SuperLink generated replies
220
+ if msg.metadata.src_node_id != SUPERLINK_NODE_ID:
221
+ _raise_if(
222
+ validation_error=request.run_id != msg.metadata.run_id,
223
+ request_name="PullMessages",
224
+ detail="`message.metadata` has mismatched `run_id`",
225
+ )
226
+ messages_list.append(message_to_proto(msg))
233
227
 
234
228
  return PullResMessagesResponse(messages_list=messages_list)
235
229
 
@@ -103,11 +103,11 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
103
103
  if request.messages_list:
104
104
  log(
105
105
  INFO,
106
- "[Fleet.PushMessages] Push results from node_id=%s",
106
+ "[Fleet.PushMessages] Push replies from node_id=%s",
107
107
  request.messages_list[0].metadata.src_node_id,
108
108
  )
109
109
  else:
110
- log(INFO, "[Fleet.PushMessages] No task results to push")
110
+ log(INFO, "[Fleet.PushMessages] No replies to push")
111
111
 
112
112
  try:
113
113
  res = message_handler.push_messages(
@@ -15,17 +15,15 @@
15
15
  """Fleet API message handlers."""
16
16
 
17
17
 
18
- import time
19
18
  from typing import Optional
20
19
  from uuid import UUID
21
20
 
21
+ from flwr.common import Message
22
22
  from flwr.common.constant import Status
23
23
  from flwr.common.serde import (
24
24
  fab_to_proto,
25
25
  message_from_proto,
26
- message_from_taskins,
27
26
  message_to_proto,
28
- message_to_taskres,
29
27
  user_config_to_proto,
30
28
  )
31
29
  from flwr.common.typing import Fab, InvalidRunStatusException
@@ -49,7 +47,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
49
47
  GetRunResponse,
50
48
  Run,
51
49
  )
52
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
53
50
  from flwr.server.superlink.ffs.ffs import Ffs
54
51
  from flwr.server.superlink.linkstate import LinkState
55
52
  from flwr.server.superlink.utils import check_abort
@@ -93,13 +90,12 @@ def pull_messages(
93
90
  node = request.node # pylint: disable=no-member
94
91
  node_id: int = node.node_id
95
92
 
96
- # Retrieve TaskIns from State
97
- task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
93
+ # Retrieve Message from State
94
+ message_list: list[Message] = state.get_message_ins(node_id=node_id, limit=1)
98
95
 
99
96
  # Convert to Messages
100
97
  msg_proto = []
101
- for task_ins in task_ins_list:
102
- msg = message_from_taskins(task_ins)
98
+ for msg in message_list:
103
99
  msg_proto.append(message_to_proto(msg))
104
100
 
105
101
  return PullMessagesResponse(messages_list=msg_proto)
@@ -109,24 +105,20 @@ def push_messages(
109
105
  request: PushMessagesRequest, state: LinkState
110
106
  ) -> PushMessagesResponse:
111
107
  """Push Messages handler."""
112
- # Convert Message to TaskRes
108
+ # Convert Message from proto
113
109
  msg = message_from_proto(message_proto=request.messages_list[0])
114
- task_res = message_to_taskres(msg)
115
110
 
116
111
  # Abort if the run is not running
117
112
  abort_msg = check_abort(
118
- task_res.run_id,
113
+ msg.metadata.run_id,
119
114
  [Status.PENDING, Status.STARTING, Status.FINISHED],
120
115
  state,
121
116
  )
122
117
  if abort_msg:
123
118
  raise InvalidRunStatusException(abort_msg)
124
119
 
125
- # Set pushed_at (timestamp in seconds)
126
- task_res.task.pushed_at = time.time()
127
-
128
- # Store TaskRes in State
129
- message_id: Optional[UUID] = state.store_task_res(task_res=task_res)
120
+ # Store Message in State
121
+ message_id: Optional[UUID] = state.store_message_res(message=msg)
130
122
 
131
123
  # Build response
132
124
  response = PushMessagesResponse(
@@ -45,7 +45,7 @@ class Backend(ABC):
45
45
  def num_workers(self) -> int:
46
46
  """Return number of workers in the backend.
47
47
 
48
- This is the number of TaskIns that can be processed concurrently.
48
+ This is the number of Messages that can be processed concurrently.
49
49
  """
50
50
  return 0
51
51
 
@@ -29,6 +29,7 @@ from typing import Callable, Optional
29
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
30
30
  from flwr.client.clientapp.utils import get_load_client_app_fn
31
31
  from flwr.client.run_info_store import DeprecatedRunInfoStore
32
+ from flwr.common import Message
32
33
  from flwr.common.constant import (
33
34
  NUM_PARTITIONS_KEY,
34
35
  PARTITION_ID_KEY,
@@ -37,9 +38,7 @@ from flwr.common.constant import (
37
38
  )
38
39
  from flwr.common.logger import log
39
40
  from flwr.common.message import Error
40
- from flwr.common.serde import message_from_taskins, message_to_taskres
41
41
  from flwr.common.typing import Run
42
- from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
42
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
44
43
 
45
44
  from .backend import Backend, error_messages_backends, supported_backends
@@ -87,33 +86,33 @@ def _register_node_info_stores(
87
86
 
88
87
  # pylint: disable=too-many-arguments,too-many-locals
89
88
  def worker(
90
- taskins_queue: "Queue[TaskIns]",
91
- taskres_queue: "Queue[TaskRes]",
89
+ messageins_queue: Queue[Message],
90
+ messageres_queue: Queue[Message],
92
91
  node_info_store: dict[int, DeprecatedRunInfoStore],
93
92
  backend: Backend,
94
93
  f_stop: threading.Event,
95
94
  ) -> None:
96
- """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
95
+ """Process messages from the queue, execute them, update context, and enqueue
96
+ replies."""
97
97
  while not f_stop.is_set():
98
98
  out_mssg = None
99
99
  try:
100
100
  # Fetch from queue with timeout. We use a timeout so
101
101
  # the stopping event can be evaluated even when the queue is empty.
102
- task_ins: TaskIns = taskins_queue.get(timeout=1.0)
103
- node_id = task_ins.task.consumer.node_id
102
+ message: Message = messageins_queue.get(timeout=1.0)
103
+ node_id = message.metadata.dst_node_id
104
104
 
105
105
  # Retrieve context
106
- context = node_info_store[node_id].retrieve_context(run_id=task_ins.run_id)
107
-
108
- # Convert TaskIns to Message
109
- message = message_from_taskins(task_ins)
106
+ context = node_info_store[node_id].retrieve_context(
107
+ run_id=message.metadata.run_id
108
+ )
110
109
 
111
110
  # Let backend process message
112
111
  out_mssg, updated_context = backend.process_message(message, context)
113
112
 
114
113
  # Update Context
115
114
  node_info_store[node_id].update_context(
116
- task_ins.run_id, context=updated_context
115
+ message.metadata.run_id, context=updated_context
117
116
  )
118
117
  except Empty:
119
118
  # An exception raised if queue.get times out
@@ -137,36 +136,33 @@ def worker(
137
136
 
138
137
  finally:
139
138
  if out_mssg:
140
- # Convert to TaskRes
141
- task_res = message_to_taskres(out_mssg)
142
- # Store TaskRes in state
143
- task_res.task.pushed_at = time.time()
144
- taskres_queue.put(task_res)
139
+ # Store reply Messages in state
140
+ messageres_queue.put(out_mssg)
145
141
 
146
142
 
147
- def add_taskins_to_queue(
143
+ def add_messages_to_queue(
148
144
  state: LinkState,
149
- queue: "Queue[TaskIns]",
145
+ queue: Queue[Message],
150
146
  nodes_mapping: NodeToPartitionMapping,
151
147
  f_stop: threading.Event,
152
148
  ) -> None:
153
- """Put TaskIns in a queue from State."""
149
+ """Put Messages in the queue from the LinkState."""
154
150
  while not f_stop.is_set():
155
151
  for node_id in nodes_mapping.keys():
156
- task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
157
- for task_ins in task_ins_list:
158
- queue.put(task_ins)
152
+ message_ins_list = state.get_message_ins(node_id=node_id, limit=1)
153
+ for msg in message_ins_list:
154
+ queue.put(msg)
159
155
  sleep(0.1)
160
156
 
161
157
 
162
- def put_taskres_into_state(
163
- state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event
158
+ def put_message_into_state(
159
+ state: LinkState, queue: Queue[Message], f_stop: threading.Event
164
160
  ) -> None:
165
- """Put TaskRes into State from a queue."""
161
+ """Store reply Messages into the LinkState from the queue."""
166
162
  while not f_stop.is_set():
167
163
  try:
168
- taskres = queue.get(timeout=1.0)
169
- state.store_task_res(taskres)
164
+ message_reply = queue.get(timeout=1.0)
165
+ state.store_message_res(message_reply)
170
166
  except Empty:
171
167
  # queue is empty when timeout was triggered
172
168
  pass
@@ -182,8 +178,8 @@ def run_api(
182
178
  f_stop: threading.Event,
183
179
  ) -> None:
184
180
  """Run the VCE."""
185
- taskins_queue: Queue[TaskIns] = Queue()
186
- taskres_queue: Queue[TaskRes] = Queue()
181
+ messageins_queue: Queue[Message] = Queue()
182
+ messageres_queue: Queue[Message] = Queue()
187
183
 
188
184
  try:
189
185
 
@@ -197,10 +193,10 @@ def run_api(
197
193
  state = state_factory.state()
198
194
 
199
195
  extractor_th = threading.Thread(
200
- target=add_taskins_to_queue,
196
+ target=add_messages_to_queue,
201
197
  args=(
202
198
  state,
203
- taskins_queue,
199
+ messageins_queue,
204
200
  nodes_mapping,
205
201
  f_stop,
206
202
  ),
@@ -208,10 +204,10 @@ def run_api(
208
204
  extractor_th.start()
209
205
 
210
206
  injector_th = threading.Thread(
211
- target=put_taskres_into_state,
207
+ target=put_message_into_state,
212
208
  args=(
213
209
  state,
214
- taskres_queue,
210
+ messageres_queue,
215
211
  f_stop,
216
212
  ),
217
213
  )
@@ -221,8 +217,8 @@ def run_api(
221
217
  _ = [
222
218
  executor.submit(
223
219
  worker,
224
- taskins_queue,
225
- taskres_queue,
220
+ messageins_queue,
221
+ messageres_queue,
226
222
  node_info_stores,
227
223
  backend,
228
224
  f_stop,