flwr-nightly 1.19.0.dev20250529__py3-none-any.whl → 1.19.0.dev20250530__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
flwr/cli/utils.py CHANGED
@@ -291,9 +291,9 @@ def init_channel(
291
291
  def flwr_cli_grpc_exc_handler() -> Iterator[None]:
292
292
  """Context manager to handle specific gRPC errors.
293
293
 
294
- It catches grpc.RpcError exceptions with UNAUTHENTICATED and UNIMPLEMENTED statuses,
295
- informs the user, and exits the application. All other exceptions will be allowed to
296
- escape.
294
+ It catches grpc.RpcError exceptions with UNAUTHENTICATED, UNIMPLEMENTED, and
295
+ PERMISSION_DENIED statuses, informs the user, and exits the application. All other
296
+ exceptions will be allowed to escape.
297
297
  """
298
298
  try:
299
299
  yield
@@ -313,4 +313,12 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
313
313
  bold=True,
314
314
  )
315
315
  raise typer.Exit(code=1) from None
316
+ if e.code() == grpc.StatusCode.PERMISSION_DENIED:
317
+ typer.secho(
318
+ "❌ Authorization failed. Please contact your administrator"
319
+ " to check your permissions.",
320
+ fg=typer.colors.RED,
321
+ bold=True,
322
+ )
323
+ raise typer.Exit(code=1) from None
316
324
  raise
@@ -63,7 +63,7 @@ class ExecAuthPlugin(ABC):
63
63
  @abstractmethod
64
64
  def refresh_tokens(
65
65
  self, metadata: Sequence[tuple[str, Union[str, bytes]]]
66
- ) -> Optional[Sequence[tuple[str, Union[str, bytes]]]]:
66
+ ) -> tuple[Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[UserInfo]]:
67
67
  """Refresh authentication tokens in the provided metadata."""
68
68
 
69
69
 
@@ -30,6 +30,7 @@ SIGNAL_TO_EXIT_CODE: dict[int, int] = {
30
30
  signal.SIGINT: ExitCode.GRACEFUL_EXIT_SIGINT,
31
31
  signal.SIGTERM: ExitCode.GRACEFUL_EXIT_SIGTERM,
32
32
  }
33
+ registered_exit_handlers: list[Callable[[], None]] = []
33
34
 
34
35
  # SIGQUIT is not available on Windows
35
36
  if hasattr(signal, "SIGQUIT"):
@@ -41,6 +42,7 @@ def register_exit_handlers(
41
42
  exit_message: Optional[str] = None,
42
43
  grpc_servers: Optional[list[Server]] = None,
43
44
  bckg_threads: Optional[list[Thread]] = None,
45
+ exit_handlers: Optional[list[Callable[[], None]]] = None,
44
46
  ) -> None:
45
47
  """Register exit handlers for `SIGINT`, `SIGTERM` and `SIGQUIT` signals.
46
48
 
@@ -56,8 +58,12 @@ def register_exit_handlers(
56
58
  bckg_threads: Optional[List[Thread]] (default: None)
57
59
  An optional list of threads that need to be gracefully
58
60
  terminated before exiting.
61
+ exit_handlers: Optional[List[Callable[[], None]]] (default: None)
62
+ An optional list of exit handlers to be called before exiting.
63
+ Additional exit handlers can be added using `add_exit_handler`.
59
64
  """
60
65
  default_handlers: dict[int, Callable[[int, FrameType], None]] = {}
66
+ registered_exit_handlers.extend(exit_handlers or [])
61
67
 
62
68
  def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
63
69
  """Exit handler to be registered with `signal.signal`.
@@ -68,6 +74,9 @@ def register_exit_handlers(
68
74
  # Reset to default handler
69
75
  signal.signal(signalnum, default_handlers[signalnum]) # type: ignore
70
76
 
77
+ for handler in registered_exit_handlers:
78
+ handler()
79
+
71
80
  if grpc_servers is not None:
72
81
  for grpc_server in grpc_servers:
73
82
  grpc_server.stop(grace=1)
@@ -87,3 +96,24 @@ def register_exit_handlers(
87
96
  for sig in SIGNAL_TO_EXIT_CODE:
88
97
  default_handler = signal.signal(sig, graceful_exit_handler) # type: ignore
89
98
  default_handlers[sig] = default_handler # type: ignore
99
+
100
+
101
+ def add_exit_handler(exit_handler: Callable[[], None]) -> None:
102
+ """Add an exit handler to be called on graceful exit.
103
+
104
+ This function allows you to register additional exit handlers
105
+ that will be executed when the application exits gracefully,
106
+ if `register_exit_handlers` was called.
107
+
108
+ Parameters
109
+ ----------
110
+ exit_handler : Callable[[], None]
111
+ A callable that takes no arguments and performs cleanup or
112
+ other actions before the application exits.
113
+
114
+ Notes
115
+ -----
116
+ This method is not thread-safe, and it allows you to add the
117
+ same exit handler multiple times.
118
+ """
119
+ registered_exit_handlers.append(exit_handler)
@@ -15,7 +15,7 @@
15
15
  """InflatableObject utils."""
16
16
 
17
17
 
18
- from typing import Union
18
+ from typing import Optional, Union
19
19
 
20
20
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
21
21
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
@@ -24,6 +24,7 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
24
24
  PushObjectRequest,
25
25
  PushObjectResponse,
26
26
  )
27
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
27
28
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
28
29
 
29
30
  from .inflatable import (
@@ -46,40 +47,51 @@ inflatable_class_registry: dict[str, type[InflatableObject]] = {
46
47
 
47
48
 
48
49
  def push_object_to_servicer(
49
- obj: InflatableObject, stub: Union[FleetStub, ServerAppIoStub]
50
+ obj: InflatableObject,
51
+ stub: Union[FleetStub, ServerAppIoStub],
52
+ node: Node,
53
+ object_ids_to_push: Optional[set[str]] = None,
50
54
  ) -> set[str]:
51
55
  """Recursively deflate an object and push it to the servicer.
52
56
 
53
- Objects with the same ID are not pushed twice. It returns the set of pushed object
57
+ Objects with the same ID are not pushed twice. If `object_ids_to_push` is set,
58
+ only objects with those IDs are pushed. It returns the set of pushed object
54
59
  IDs.
55
60
  """
56
61
  pushed_object_ids: set[str] = set()
57
62
  # Push children if it has any
58
63
  if children := obj.children:
59
64
  for child in children.values():
60
- pushed_object_ids |= push_object_to_servicer(child, stub)
65
+ pushed_object_ids |= push_object_to_servicer(
66
+ child, stub, node, object_ids_to_push
67
+ )
61
68
 
62
69
  # Deflate object and push
63
70
  object_content = obj.deflate()
64
71
  object_id = get_object_id(object_content)
65
- _: PushObjectResponse = stub.PushObject(
66
- PushObjectRequest(
67
- object_id=object_id,
68
- object_content=object_content,
72
+ # Push always if no object set is specified, or if the object is in the set
73
+ if object_ids_to_push is None or object_id in object_ids_to_push:
74
+ _: PushObjectResponse = stub.PushObject(
75
+ PushObjectRequest(
76
+ node=node,
77
+ object_id=object_id,
78
+ object_content=object_content,
79
+ )
69
80
  )
70
- )
71
- pushed_object_ids.add(object_id)
81
+ pushed_object_ids.add(object_id)
72
82
 
73
83
  return pushed_object_ids
74
84
 
75
85
 
76
86
  def pull_object_from_servicer(
77
- object_id: str, stub: Union[FleetStub, ServerAppIoStub]
87
+ object_id: str,
88
+ stub: Union[FleetStub, ServerAppIoStub],
89
+ node: Node,
78
90
  ) -> InflatableObject:
79
91
  """Recursively inflate an object by pulling it from the servicer."""
80
92
  # Pull object
81
93
  object_proto: PullObjectResponse = stub.PullObject(
82
- PullObjectRequest(object_id=object_id)
94
+ PullObjectRequest(node=node, object_id=object_id)
83
95
  )
84
96
  object_content = object_proto.object_content
85
97
 
@@ -93,7 +105,9 @@ def pull_object_from_servicer(
93
105
  # Pull all children objects
94
106
  children: dict[str, InflatableObject] = {}
95
107
  for child_object_id in children_obj_ids:
96
- children[child_object_id] = pull_object_from_servicer(child_object_id, stub)
108
+ children[child_object_id] = pull_object_from_servicer(
109
+ child_object_id, stub, node
110
+ )
97
111
 
98
112
  # Inflate object passing its children
99
113
  return cls_type.inflate(object_content, children=children)
@@ -62,8 +62,8 @@ class Array(InflatableObject):
62
62
  A string representing the data type of the serialized object (e.g. `"float32"`).
63
63
  Only required if you are not passing in a ndarray or a tensor.
64
64
 
65
- shape : Optional[list[int]] (default: None)
66
- A list representing the shape of the unserialized array-like object. Only
65
+ shape : Optional[tuple[int, ...]] (default: None)
66
+ A tuple representing the shape of the unserialized array-like object. Only
67
67
  required if you are not passing in a ndarray or a tensor.
68
68
 
69
69
  stype : Optional[str] (default: None)
@@ -107,24 +107,13 @@ class Array(InflatableObject):
107
107
  """
108
108
 
109
109
  dtype: str
110
+ shape: tuple[int, ...]
110
111
  stype: str
111
112
  data: bytes
112
113
 
113
- @property
114
- def shape(self) -> list[int]:
115
- """Get the shape of the array."""
116
- self.is_dirty = True # Mark as dirty when shape is accessed
117
- return cast(list[int], self.__dict__["_shape"])
118
-
119
- @shape.setter
120
- def shape(self, value: list[int]) -> None:
121
- """Set the shape of the array."""
122
- self.is_dirty = True # Mark as dirty when shape is set
123
- self.__dict__["_shape"] = value
124
-
125
114
  @overload
126
115
  def __init__( # noqa: E704
127
- self, dtype: str, shape: list[int], stype: str, data: bytes
116
+ self, dtype: str, shape: tuple[int, ...], stype: str, data: bytes
128
117
  ) -> None: ...
129
118
 
130
119
  @overload
@@ -137,7 +126,7 @@ class Array(InflatableObject):
137
126
  self,
138
127
  *args: Any,
139
128
  dtype: str | None = None,
140
- shape: list[int] | None = None,
129
+ shape: tuple[int, ...] | None = None,
141
130
  stype: str | None = None,
142
131
  data: bytes | None = None,
143
132
  ndarray: NDArray | None = None,
@@ -145,7 +134,7 @@ class Array(InflatableObject):
145
134
  ) -> None:
146
135
  # Determine the initialization method and validate input arguments.
147
136
  # Support three initialization formats:
148
- # 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
137
+ # 1. Array(dtype: str, shape: tuple[int, ...], stype: str, data: bytes)
149
138
  # 2. Array(ndarray: NDArray)
150
139
  # 3. Array(torch_tensor: torch.Tensor)
151
140
 
@@ -192,7 +181,7 @@ class Array(InflatableObject):
192
181
  if (
193
182
  len(all_args) == 4 # pylint: disable=too-many-boolean-expressions
194
183
  and isinstance(all_args[0], str)
195
- and isinstance(all_args[1], list)
184
+ and isinstance(all_args[1], tuple)
196
185
  and all(isinstance(i, int) for i in all_args[1])
197
186
  and isinstance(all_args[2], str)
198
187
  and isinstance(all_args[3], bytes)
@@ -232,7 +221,7 @@ class Array(InflatableObject):
232
221
  data = buffer.getvalue()
233
222
  return Array(
234
223
  dtype=str(ndarray.dtype),
235
- shape=list(ndarray.shape),
224
+ shape=tuple(ndarray.shape),
236
225
  stype=SType.NUMPY,
237
226
  data=data,
238
227
  )
@@ -302,7 +291,7 @@ class Array(InflatableObject):
302
291
  proto_array = ArrayProto.FromString(obj_body)
303
292
  return cls(
304
293
  dtype=proto_array.dtype,
305
- shape=list(proto_array.shape),
294
+ shape=tuple(proto_array.shape),
306
295
  stype=proto_array.stype,
307
296
  data=proto_array.data,
308
297
  )
@@ -328,7 +317,7 @@ class Array(InflatableObject):
328
317
 
329
318
  def __setattr__(self, name: str, value: Any) -> None:
330
319
  """Set attribute with special handling for dirty state."""
331
- if name in ("dtype", "stype", "data"):
320
+ if name in ("dtype", "shape", "stype", "data"):
332
321
  # Mark as dirty if any of the main attributes are set
333
322
  self.is_dirty = True
334
323
  super().__setattr__(name, value)
@@ -252,7 +252,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
252
252
  record = ArrayRecord()
253
253
  for k, v in array_dict.items():
254
254
  record[k] = Array(
255
- dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
255
+ dtype=v.dtype, shape=tuple(v.shape), stype=v.stype, data=v.data
256
256
  )
257
257
  if not keep_input:
258
258
  array_dict.clear()
@@ -111,12 +111,12 @@ def parameters_to_arrayrecord(parameters: Parameters, keep_input: bool) -> Array
111
111
  else:
112
112
  tensor = parameters.tensors.pop(0)
113
113
  ordered_dict[str(idx)] = Array(
114
- data=tensor, dtype="", stype=tensor_type, shape=[]
114
+ data=tensor, dtype="", stype=tensor_type, shape=()
115
115
  )
116
116
 
117
117
  if num_arrays == 0:
118
118
  ordered_dict[EMPTY_TENSOR_KEY] = Array(
119
- data=b"", dtype="", stype=tensor_type, shape=[]
119
+ data=b"", dtype="", stype=tensor_type, shape=()
120
120
  )
121
121
  return ArrayRecord(ordered_dict, keep_input=keep_input)
122
122
 
flwr/common/serde.py CHANGED
@@ -390,7 +390,7 @@ def array_from_proto(array_proto: ProtoArray) -> Array:
390
390
  """Deserialize Array from ProtoBuf."""
391
391
  return Array(
392
392
  dtype=array_proto.dtype,
393
- shape=list(array_proto.shape),
393
+ shape=tuple(array_proto.shape),
394
394
  stype=array_proto.stype,
395
395
  data=array_proto.data,
396
396
  )
flwr/server/app.py CHANGED
@@ -27,7 +27,7 @@ from collections.abc import Sequence
27
27
  from logging import DEBUG, INFO, WARN
28
28
  from pathlib import Path
29
29
  from time import sleep
30
- from typing import Any, Callable, Optional, Union, cast
30
+ from typing import Any, Callable, Optional, TypeVar
31
31
 
32
32
  import grpc
33
33
  import yaml
@@ -85,6 +85,7 @@ from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
85
85
 
86
86
  DATABASE = ":flwr-in-memory-state:"
87
87
  BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
88
+ P = TypeVar("P", ExecAuthPlugin, ExecAuthzPlugin)
88
89
 
89
90
 
90
91
  try:
@@ -151,15 +152,13 @@ def run_superlink() -> None:
151
152
  verify_tls_cert = not getattr(args, "disable_oidc_tls_cert_verification", None)
152
153
 
153
154
  auth_plugin: Optional[ExecAuthPlugin] = None
154
- authz_plugin: Optional[ExecAuthzPlugin] = None # pylint: disable=unused-variable
155
+ authz_plugin: Optional[ExecAuthzPlugin] = None
155
156
  event_log_plugin: Optional[EventLogWriterPlugin] = None
156
157
  # Load the auth plugin if the args.user_auth_config is provided
157
158
  if cfg_path := getattr(args, "user_auth_config", None):
158
- # pylint: disable=unused-variable
159
- auth_plugin, authz_plugin = _try_obtain_exec_auth_plugins( # noqa: F841
159
+ auth_plugin, authz_plugin = _try_obtain_exec_auth_plugins(
160
160
  Path(cfg_path), verify_tls_cert
161
161
  )
162
- # pylint: enable=unused-variable
163
162
  # Enable event logging if the args.enable_event_log is True
164
163
  if args.enable_event_log:
165
164
  event_log_plugin = _try_obtain_exec_event_log_writer_plugin()
@@ -185,6 +184,7 @@ def run_superlink() -> None:
185
184
  [args.executor_config] if args.executor_config else args.executor_config
186
185
  ),
187
186
  auth_plugin=auth_plugin,
187
+ authz_plugin=authz_plugin,
188
188
  event_log_plugin=event_log_plugin,
189
189
  )
190
190
  grpc_servers = [exec_server]
@@ -490,15 +490,13 @@ def _try_obtain_exec_auth_plugins(
490
490
  config: dict[str, Any] = yaml.safe_load(file)
491
491
 
492
492
  def _load_plugin(
493
- section: str,
494
- yaml_key: str,
495
- loader: Callable[[], dict[str, type[Union[ExecAuthPlugin, ExecAuthzPlugin]]]],
496
- ) -> Union[ExecAuthPlugin, ExecAuthzPlugin]:
493
+ section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
494
+ ) -> P:
497
495
  section_cfg = config.get(section, {})
498
496
  auth_plugin_name = section_cfg.get(yaml_key, "")
499
497
  try:
500
- plugins = loader()
501
- plugin_cls = plugins[auth_plugin_name]
498
+ plugins: dict[str, type[P]] = loader()
499
+ plugin_cls: type[P] = plugins[auth_plugin_name]
502
500
  return plugin_cls(
503
501
  user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
504
502
  )
@@ -513,23 +511,17 @@ def _try_obtain_exec_auth_plugins(
513
511
  sys.exit(f"No {section} plugins are currently supported.")
514
512
 
515
513
  # Load authentication plugin
516
- auth_plugin = cast(
517
- ExecAuthPlugin,
518
- _load_plugin(
519
- section="authentication",
520
- yaml_key=AUTH_TYPE_YAML_KEY,
521
- loader=get_exec_auth_plugins,
522
- ),
514
+ auth_plugin = _load_plugin(
515
+ section="authentication",
516
+ yaml_key=AUTH_TYPE_YAML_KEY,
517
+ loader=get_exec_auth_plugins,
523
518
  )
524
519
 
525
520
  # Load authorization plugin
526
- authz_plugin = cast(
527
- ExecAuthzPlugin,
528
- _load_plugin(
529
- section="authorization",
530
- yaml_key=AUTHZ_TYPE_YAML_KEY,
531
- loader=get_exec_authz_plugins,
532
- ),
521
+ authz_plugin = _load_plugin(
522
+ section="authorization",
523
+ yaml_key=AUTHZ_TYPE_YAML_KEY,
524
+ loader=get_exec_authz_plugins,
533
525
  )
534
526
 
535
527
  return auth_plugin, authz_plugin
@@ -114,6 +114,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
114
114
  return message_handler.pull_messages(
115
115
  request=request,
116
116
  state=self.state_factory.state(),
117
+ store=self.objectstore_factory.store(),
117
118
  )
118
119
 
119
120
  def PushMessages(
@@ -14,10 +14,10 @@
14
14
  # ==============================================================================
15
15
  """Fleet API message handlers."""
16
16
 
17
-
17
+ from logging import ERROR
18
18
  from typing import Optional
19
19
 
20
- from flwr.common import Message
20
+ from flwr.common import Message, log
21
21
  from flwr.common.constant import Status
22
22
  from flwr.common.serde import (
23
23
  fab_to_proto,
@@ -42,6 +42,7 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
42
42
  SendNodeHeartbeatRequest,
43
43
  SendNodeHeartbeatResponse,
44
44
  )
45
+ from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
45
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
46
47
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
47
48
  GetRunRequest,
@@ -51,7 +52,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
51
52
  from flwr.server.superlink.ffs.ffs import Ffs
52
53
  from flwr.server.superlink.linkstate import LinkState
53
54
  from flwr.server.superlink.utils import check_abort
54
- from flwr.supercore.object_store import ObjectStore
55
+ from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
55
56
 
56
57
  from ...utils import store_mapping_and_register_objects
57
58
 
@@ -89,7 +90,9 @@ def send_node_heartbeat(
89
90
 
90
91
 
91
92
  def pull_messages(
92
- request: PullMessagesRequest, state: LinkState
93
+ request: PullMessagesRequest,
94
+ state: LinkState,
95
+ store: ObjectStore,
93
96
  ) -> PullMessagesResponse:
94
97
  """Pull Messages handler."""
95
98
  # Get node_id if client node is not anonymous
@@ -101,10 +104,25 @@ def pull_messages(
101
104
 
102
105
  # Convert to Messages
103
106
  msg_proto = []
107
+ objects_to_pull: dict[str, ObjectIDs] = {}
104
108
  for msg in message_list:
105
- msg_proto.append(message_to_proto(msg))
106
-
107
- return PullMessagesResponse(messages_list=msg_proto)
109
+ try:
110
+ msg_proto.append(message_to_proto(msg))
111
+
112
+ msg_object_id = msg.metadata.message_id
113
+ descendants = store.get_message_descendant_ids(msg_object_id)
114
+ # Include the object_id of the message itself
115
+ objects_to_pull[msg_object_id] = ObjectIDs(
116
+ object_ids=descendants + [msg_object_id]
117
+ )
118
+ except NoObjectInStoreError as e:
119
+ log(ERROR, e.message)
120
+ # Delete message ins from state
121
+ state.delete_messages(message_ins_ids={msg_object_id})
122
+
123
+ return PullMessagesResponse(
124
+ messages_list=msg_proto, objects_to_pull=objects_to_pull
125
+ )
108
126
 
109
127
 
110
128
  def push_messages(
@@ -114,9 +114,10 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
114
114
  """Pull PullMessages."""
115
115
  # Get state from app
116
116
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
117
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
117
118
 
118
119
  # Handle message
119
- return message_handler.pull_messages(request=request, state=state)
120
+ return message_handler.pull_messages(request=request, state=state, store=store)
120
121
 
121
122
 
122
123
  @rest_request_response(PushMessagesRequest)
@@ -17,7 +17,6 @@
17
17
 
18
18
  from os import urandom
19
19
  from typing import Optional
20
- from uuid import uuid4
21
20
 
22
21
  from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
23
22
  from flwr.common.constant import (
@@ -246,7 +245,7 @@ def create_message_error_unavailable_res_message(
246
245
  ttl = max(ins_metadata.ttl - (current_time - ins_metadata.created_at), 0)
247
246
  metadata = Metadata(
248
247
  run_id=ins_metadata.run_id,
249
- message_id=str(uuid4()),
248
+ message_id="",
250
249
  src_node_id=SUPERLINK_NODE_ID,
251
250
  dst_node_id=SUPERLINK_NODE_ID,
252
251
  reply_to_message_id=ins_metadata.message_id,
@@ -256,7 +255,7 @@ def create_message_error_unavailable_res_message(
256
255
  ttl=ttl,
257
256
  )
258
257
 
259
- return make_message(
258
+ msg = make_message(
260
259
  metadata=metadata,
261
260
  error=Error(
262
261
  code=(
@@ -271,6 +270,8 @@ def create_message_error_unavailable_res_message(
271
270
  ),
272
271
  ),
273
272
  )
273
+ msg.metadata.__dict__["_message_id"] = msg.object_id
274
+ return msg
274
275
 
275
276
 
276
277
  def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Message:
@@ -278,7 +279,7 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Me
278
279
  that it isn't found."""
279
280
  metadata = Metadata(
280
281
  run_id=0, # Unknown
281
- message_id=str(uuid4()),
282
+ message_id="",
282
283
  src_node_id=SUPERLINK_NODE_ID,
283
284
  dst_node_id=SUPERLINK_NODE_ID,
284
285
  reply_to_message_id=reply_to_message_id,
@@ -288,13 +289,15 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Me
288
289
  ttl=0,
289
290
  )
290
291
 
291
- return make_message(
292
+ msg = make_message(
292
293
  metadata=metadata,
293
294
  error=Error(
294
295
  code=ErrorCode.MESSAGE_UNAVAILABLE,
295
296
  reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
296
297
  ),
297
298
  )
299
+ msg.metadata.__dict__["_message_id"] = msg.object_id
300
+ return msg
298
301
 
299
302
 
300
303
  def message_ttl_has_expired(message_metadata: Metadata, current_time: float) -> bool:
@@ -16,14 +16,14 @@
16
16
 
17
17
 
18
18
  import threading
19
- from logging import DEBUG, INFO
19
+ from logging import DEBUG, ERROR, INFO
20
20
  from typing import Optional
21
21
 
22
22
  import grpc
23
23
 
24
24
  from flwr.common import Message
25
25
  from flwr.common.constant import SUPERLINK_NODE_ID, Status
26
- from flwr.common.inflatable import check_body_len_consistency
26
+ from flwr.common.inflatable import check_body_len_consistency, get_desdendant_object_ids
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.serde import (
29
29
  context_from_proto,
@@ -47,6 +47,7 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
47
47
  PushLogsResponse,
48
48
  )
49
49
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
50
+ ObjectIDs,
50
51
  PullObjectRequest,
51
52
  PullObjectResponse,
52
53
  PushObjectRequest,
@@ -78,7 +79,7 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
78
79
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
79
80
  from flwr.server.superlink.utils import abort_if
80
81
  from flwr.server.utils.validator import validate_message
81
- from flwr.supercore.object_store import ObjectStoreFactory
82
+ from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
82
83
 
83
84
  from ..utils import store_mapping_and_register_objects
84
85
 
@@ -182,6 +183,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
182
183
  # Init state
183
184
  state: LinkState = self.state_factory.state()
184
185
 
186
+ # Init store
187
+ store = self.objectstore_factory.store()
188
+
185
189
  # Abort if the run is not running
186
190
  abort_if(
187
191
  request.run_id,
@@ -195,6 +199,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
195
199
  message_ids=set(request.message_ids)
196
200
  )
197
201
 
202
+ # Register messages generated by LinkState in the Store for consistency
203
+ for msg_res in messages_res:
204
+ if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
205
+ descendants = list(get_desdendant_object_ids(msg_res))
206
+ message_obj_id = msg_res.metadata.message_id
207
+ # Store mapping
208
+ store.set_message_descendant_ids(
209
+ msg_object_id=message_obj_id, descendant_ids=descendants
210
+ )
211
+ # Preregister
212
+ store.preregister(descendants + [message_obj_id])
213
+
198
214
  # Delete the instruction Messages and their replies if found
199
215
  message_ins_ids_to_delete = {
200
216
  msg_res.metadata.reply_to_message_id for msg_res in messages_res
@@ -204,6 +220,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
204
220
 
205
221
  # Convert Messages to proto
206
222
  messages_list = []
223
+ objects_to_pull: dict[str, ObjectIDs] = {}
207
224
  while messages_res:
208
225
  msg = messages_res.pop(0)
209
226
 
@@ -216,7 +233,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
216
233
  )
217
234
  messages_list.append(message_to_proto(msg))
218
235
 
219
- return PullResMessagesResponse(messages_list=messages_list)
236
+ try:
237
+ msg_object_id = msg.metadata.message_id
238
+ descendants = store.get_message_descendant_ids(msg_object_id)
239
+ # Include the object_id of the message itself
240
+ objects_to_pull[msg_object_id] = ObjectIDs(
241
+ object_ids=descendants + [msg_object_id]
242
+ )
243
+ except NoObjectInStoreError as e:
244
+ log(ERROR, e.message)
245
+ # Delete message ins from state
246
+ state.delete_messages(message_ins_ids={msg_object_id})
247
+
248
+ return PullResMessagesResponse(
249
+ messages_list=messages_list, objects_to_pull=objects_to_pull
250
+ )
220
251
 
221
252
  def GetRun(
222
253
  self, request: GetRunRequest, context: grpc.ServicerContext
@@ -14,10 +14,11 @@
14
14
  # ==============================================================================
15
15
  """Flower ObjectStore."""
16
16
 
17
- from .object_store import ObjectStore
17
+ from .object_store import NoObjectInStoreError, ObjectStore
18
18
  from .object_store_factory import ObjectStoreFactory
19
19
 
20
20
  __all__ = [
21
+ "NoObjectInStoreError",
21
22
  "ObjectStore",
22
23
  "ObjectStoreFactory",
23
24
  ]