flwr-nightly 1.19.0.dev20250529__py3-none-any.whl → 1.19.0.dev20250531__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/utils.py +11 -3
- flwr/common/auth_plugin/auth_plugin.py +1 -1
- flwr/common/exit_handlers.py +30 -0
- flwr/common/inflatable_grpc_utils.py +27 -13
- flwr/common/record/array.py +10 -21
- flwr/common/record/arrayrecord.py +1 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/serde.py +1 -1
- flwr/server/app.py +17 -25
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +25 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
- flwr/server/superlink/linkstate/utils.py +8 -5
- flwr/server/superlink/serverappio/serverappio_servicer.py +35 -4
- flwr/supercore/object_store/__init__.py +2 -1
- flwr/supercore/object_store/in_memory_object_store.py +9 -2
- flwr/supercore/object_store/object_store.py +12 -0
- flwr/superexec/exec_grpc.py +4 -3
- flwr/superexec/exec_user_auth_interceptor.py +33 -4
- flwr/supernode/start_client_internal.py +144 -170
- {flwr_nightly-1.19.0.dev20250529.dist-info → flwr_nightly-1.19.0.dev20250531.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250529.dist-info → flwr_nightly-1.19.0.dev20250531.dist-info}/RECORD +24 -24
- {flwr_nightly-1.19.0.dev20250529.dist-info → flwr_nightly-1.19.0.dev20250531.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250529.dist-info → flwr_nightly-1.19.0.dev20250531.dist-info}/entry_points.txt +0 -0
flwr/cli/utils.py
CHANGED
@@ -291,9 +291,9 @@ def init_channel(
|
|
291
291
|
def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
292
292
|
"""Context manager to handle specific gRPC errors.
|
293
293
|
|
294
|
-
It catches grpc.RpcError exceptions with UNAUTHENTICATED
|
295
|
-
informs the user, and exits the application. All other
|
296
|
-
escape.
|
294
|
+
It catches grpc.RpcError exceptions with UNAUTHENTICATED, UNIMPLEMENTED, and
|
295
|
+
PERMISSION_DENIED statuses, informs the user, and exits the application. All other
|
296
|
+
exceptions will be allowed to escape.
|
297
297
|
"""
|
298
298
|
try:
|
299
299
|
yield
|
@@ -313,4 +313,12 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
313
313
|
bold=True,
|
314
314
|
)
|
315
315
|
raise typer.Exit(code=1) from None
|
316
|
+
if e.code() == grpc.StatusCode.PERMISSION_DENIED:
|
317
|
+
typer.secho(
|
318
|
+
"❌ Authorization failed. Please contact your administrator"
|
319
|
+
" to check your permissions.",
|
320
|
+
fg=typer.colors.RED,
|
321
|
+
bold=True,
|
322
|
+
)
|
323
|
+
raise typer.Exit(code=1) from None
|
316
324
|
raise
|
flwr/common/auth_plugin/auth_plugin.py
CHANGED
@@ -63,7 +63,7 @@ class ExecAuthPlugin(ABC):
|
|
63
63
|
@abstractmethod
|
64
64
|
def refresh_tokens(
|
65
65
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
66
|
-
) -> Optional[Sequence[tuple[str, Union[str, bytes]]]]:
|
66
|
+
) -> tuple[Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[UserInfo]]:
|
67
67
|
"""Refresh authentication tokens in the provided metadata."""
|
68
68
|
|
69
69
|
|
flwr/common/exit_handlers.py
CHANGED
@@ -30,6 +30,7 @@ SIGNAL_TO_EXIT_CODE: dict[int, int] = {
|
|
30
30
|
signal.SIGINT: ExitCode.GRACEFUL_EXIT_SIGINT,
|
31
31
|
signal.SIGTERM: ExitCode.GRACEFUL_EXIT_SIGTERM,
|
32
32
|
}
|
33
|
+
registered_exit_handlers: list[Callable[[], None]] = []
|
33
34
|
|
34
35
|
# SIGQUIT is not available on Windows
|
35
36
|
if hasattr(signal, "SIGQUIT"):
|
@@ -41,6 +42,7 @@ def register_exit_handlers(
|
|
41
42
|
exit_message: Optional[str] = None,
|
42
43
|
grpc_servers: Optional[list[Server]] = None,
|
43
44
|
bckg_threads: Optional[list[Thread]] = None,
|
45
|
+
exit_handlers: Optional[list[Callable[[], None]]] = None,
|
44
46
|
) -> None:
|
45
47
|
"""Register exit handlers for `SIGINT`, `SIGTERM` and `SIGQUIT` signals.
|
46
48
|
|
@@ -56,8 +58,12 @@ def register_exit_handlers(
|
|
56
58
|
bckg_threads: Optional[List[Thread]] (default: None)
|
57
59
|
An optional list of threads that need to be gracefully
|
58
60
|
terminated before exiting.
|
61
|
+
exit_handlers: Optional[List[Callable[[], None]]] (default: None)
|
62
|
+
An optional list of exit handlers to be called before exiting.
|
63
|
+
Additional exit handlers can be added using `add_exit_handler`.
|
59
64
|
"""
|
60
65
|
default_handlers: dict[int, Callable[[int, FrameType], None]] = {}
|
66
|
+
registered_exit_handlers.extend(exit_handlers or [])
|
61
67
|
|
62
68
|
def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
|
63
69
|
"""Exit handler to be registered with `signal.signal`.
|
@@ -68,6 +74,9 @@ def register_exit_handlers(
|
|
68
74
|
# Reset to default handler
|
69
75
|
signal.signal(signalnum, default_handlers[signalnum]) # type: ignore
|
70
76
|
|
77
|
+
for handler in registered_exit_handlers:
|
78
|
+
handler()
|
79
|
+
|
71
80
|
if grpc_servers is not None:
|
72
81
|
for grpc_server in grpc_servers:
|
73
82
|
grpc_server.stop(grace=1)
|
@@ -87,3 +96,24 @@ def register_exit_handlers(
|
|
87
96
|
for sig in SIGNAL_TO_EXIT_CODE:
|
88
97
|
default_handler = signal.signal(sig, graceful_exit_handler) # type: ignore
|
89
98
|
default_handlers[sig] = default_handler # type: ignore
|
99
|
+
|
100
|
+
|
101
|
+
def add_exit_handler(exit_handler: Callable[[], None]) -> None:
|
102
|
+
"""Add an exit handler to be called on graceful exit.
|
103
|
+
|
104
|
+
This function allows you to register additional exit handlers
|
105
|
+
that will be executed when the application exits gracefully,
|
106
|
+
if `register_exit_handlers` was called.
|
107
|
+
|
108
|
+
Parameters
|
109
|
+
----------
|
110
|
+
exit_handler : Callable[[], None]
|
111
|
+
A callable that takes no arguments and performs cleanup or
|
112
|
+
other actions before the application exits.
|
113
|
+
|
114
|
+
Notes
|
115
|
+
-----
|
116
|
+
This method is not thread-safe, and it allows you to add the
|
117
|
+
same exit handler multiple times.
|
118
|
+
"""
|
119
|
+
registered_exit_handlers.append(exit_handler)
|
flwr/common/inflatable_grpc_utils.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15
15
|
"""InflatableObject utils."""
|
16
16
|
|
17
17
|
|
18
|
-
from typing import Union
|
18
|
+
from typing import Optional, Union
|
19
19
|
|
20
20
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
21
21
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
@@ -24,6 +24,7 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
24
24
|
PushObjectRequest,
|
25
25
|
PushObjectResponse,
|
26
26
|
)
|
27
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
27
28
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
28
29
|
|
29
30
|
from .inflatable import (
|
@@ -46,40 +47,51 @@ inflatable_class_registry: dict[str, type[InflatableObject]] = {
|
|
46
47
|
|
47
48
|
|
48
49
|
def push_object_to_servicer(
|
49
|
-
obj: InflatableObject,
|
50
|
+
obj: InflatableObject,
|
51
|
+
stub: Union[FleetStub, ServerAppIoStub],
|
52
|
+
node: Node,
|
53
|
+
object_ids_to_push: Optional[set[str]] = None,
|
50
54
|
) -> set[str]:
|
51
55
|
"""Recursively deflate an object and push it to the servicer.
|
52
56
|
|
53
|
-
Objects with the same ID are not pushed twice.
|
57
|
+
Objects with the same ID are not pushed twice. If `object_ids_to_push` is set,
|
58
|
+
only objects with those IDs are pushed. It returns the set of pushed object
|
54
59
|
IDs.
|
55
60
|
"""
|
56
61
|
pushed_object_ids: set[str] = set()
|
57
62
|
# Push children if it has any
|
58
63
|
if children := obj.children:
|
59
64
|
for child in children.values():
|
60
|
-
pushed_object_ids |= push_object_to_servicer(
|
65
|
+
pushed_object_ids |= push_object_to_servicer(
|
66
|
+
child, stub, node, object_ids_to_push
|
67
|
+
)
|
61
68
|
|
62
69
|
# Deflate object and push
|
63
70
|
object_content = obj.deflate()
|
64
71
|
object_id = get_object_id(object_content)
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
72
|
+
# Push always if no object set is specified, or if the object is in the set
|
73
|
+
if object_ids_to_push is None or object_id in object_ids_to_push:
|
74
|
+
_: PushObjectResponse = stub.PushObject(
|
75
|
+
PushObjectRequest(
|
76
|
+
node=node,
|
77
|
+
object_id=object_id,
|
78
|
+
object_content=object_content,
|
79
|
+
)
|
69
80
|
)
|
70
|
-
|
71
|
-
pushed_object_ids.add(object_id)
|
81
|
+
pushed_object_ids.add(object_id)
|
72
82
|
|
73
83
|
return pushed_object_ids
|
74
84
|
|
75
85
|
|
76
86
|
def pull_object_from_servicer(
|
77
|
-
object_id: str,
|
87
|
+
object_id: str,
|
88
|
+
stub: Union[FleetStub, ServerAppIoStub],
|
89
|
+
node: Node,
|
78
90
|
) -> InflatableObject:
|
79
91
|
"""Recursively inflate an object by pulling it from the servicer."""
|
80
92
|
# Pull object
|
81
93
|
object_proto: PullObjectResponse = stub.PullObject(
|
82
|
-
PullObjectRequest(object_id=object_id)
|
94
|
+
PullObjectRequest(node=node, object_id=object_id)
|
83
95
|
)
|
84
96
|
object_content = object_proto.object_content
|
85
97
|
|
@@ -93,7 +105,9 @@ def pull_object_from_servicer(
|
|
93
105
|
# Pull all children objects
|
94
106
|
children: dict[str, InflatableObject] = {}
|
95
107
|
for child_object_id in children_obj_ids:
|
96
|
-
children[child_object_id] = pull_object_from_servicer(
|
108
|
+
children[child_object_id] = pull_object_from_servicer(
|
109
|
+
child_object_id, stub, node
|
110
|
+
)
|
97
111
|
|
98
112
|
# Inflate object passing its children
|
99
113
|
return cls_type.inflate(object_content, children=children)
|
flwr/common/record/array.py
CHANGED
@@ -62,8 +62,8 @@ class Array(InflatableObject):
|
|
62
62
|
A string representing the data type of the serialized object (e.g. `"float32"`).
|
63
63
|
Only required if you are not passing in a ndarray or a tensor.
|
64
64
|
|
65
|
-
shape : Optional[
|
66
|
-
A
|
65
|
+
shape : Optional[tuple[int, ...]] (default: None)
|
66
|
+
A tuple representing the shape of the unserialized array-like object. Only
|
67
67
|
required if you are not passing in a ndarray or a tensor.
|
68
68
|
|
69
69
|
stype : Optional[str] (default: None)
|
@@ -107,24 +107,13 @@ class Array(InflatableObject):
|
|
107
107
|
"""
|
108
108
|
|
109
109
|
dtype: str
|
110
|
+
shape: tuple[int, ...]
|
110
111
|
stype: str
|
111
112
|
data: bytes
|
112
113
|
|
113
|
-
@property
|
114
|
-
def shape(self) -> list[int]:
|
115
|
-
"""Get the shape of the array."""
|
116
|
-
self.is_dirty = True # Mark as dirty when shape is accessed
|
117
|
-
return cast(list[int], self.__dict__["_shape"])
|
118
|
-
|
119
|
-
@shape.setter
|
120
|
-
def shape(self, value: list[int]) -> None:
|
121
|
-
"""Set the shape of the array."""
|
122
|
-
self.is_dirty = True # Mark as dirty when shape is set
|
123
|
-
self.__dict__["_shape"] = value
|
124
|
-
|
125
114
|
@overload
|
126
115
|
def __init__( # noqa: E704
|
127
|
-
self, dtype: str, shape:
|
116
|
+
self, dtype: str, shape: tuple[int, ...], stype: str, data: bytes
|
128
117
|
) -> None: ...
|
129
118
|
|
130
119
|
@overload
|
@@ -137,7 +126,7 @@ class Array(InflatableObject):
|
|
137
126
|
self,
|
138
127
|
*args: Any,
|
139
128
|
dtype: str | None = None,
|
140
|
-
shape:
|
129
|
+
shape: tuple[int, ...] | None = None,
|
141
130
|
stype: str | None = None,
|
142
131
|
data: bytes | None = None,
|
143
132
|
ndarray: NDArray | None = None,
|
@@ -145,7 +134,7 @@ class Array(InflatableObject):
|
|
145
134
|
) -> None:
|
146
135
|
# Determine the initialization method and validate input arguments.
|
147
136
|
# Support three initialization formats:
|
148
|
-
# 1. Array(dtype: str, shape:
|
137
|
+
# 1. Array(dtype: str, shape: tuple[int, ...], stype: str, data: bytes)
|
149
138
|
# 2. Array(ndarray: NDArray)
|
150
139
|
# 3. Array(torch_tensor: torch.Tensor)
|
151
140
|
|
@@ -192,7 +181,7 @@ class Array(InflatableObject):
|
|
192
181
|
if (
|
193
182
|
len(all_args) == 4 # pylint: disable=too-many-boolean-expressions
|
194
183
|
and isinstance(all_args[0], str)
|
195
|
-
and isinstance(all_args[1],
|
184
|
+
and isinstance(all_args[1], tuple)
|
196
185
|
and all(isinstance(i, int) for i in all_args[1])
|
197
186
|
and isinstance(all_args[2], str)
|
198
187
|
and isinstance(all_args[3], bytes)
|
@@ -232,7 +221,7 @@ class Array(InflatableObject):
|
|
232
221
|
data = buffer.getvalue()
|
233
222
|
return Array(
|
234
223
|
dtype=str(ndarray.dtype),
|
235
|
-
shape=
|
224
|
+
shape=tuple(ndarray.shape),
|
236
225
|
stype=SType.NUMPY,
|
237
226
|
data=data,
|
238
227
|
)
|
@@ -302,7 +291,7 @@ class Array(InflatableObject):
|
|
302
291
|
proto_array = ArrayProto.FromString(obj_body)
|
303
292
|
return cls(
|
304
293
|
dtype=proto_array.dtype,
|
305
|
-
shape=
|
294
|
+
shape=tuple(proto_array.shape),
|
306
295
|
stype=proto_array.stype,
|
307
296
|
data=proto_array.data,
|
308
297
|
)
|
@@ -328,7 +317,7 @@ class Array(InflatableObject):
|
|
328
317
|
|
329
318
|
def __setattr__(self, name: str, value: Any) -> None:
|
330
319
|
"""Set attribute with special handling for dirty state."""
|
331
|
-
if name in ("dtype", "stype", "data"):
|
320
|
+
if name in ("dtype", "shape", "stype", "data"):
|
332
321
|
# Mark as dirty if any of the main attributes are set
|
333
322
|
self.is_dirty = True
|
334
323
|
super().__setattr__(name, value)
|
flwr/common/record/arrayrecord.py
CHANGED
@@ -252,7 +252,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
252
252
|
record = ArrayRecord()
|
253
253
|
for k, v in array_dict.items():
|
254
254
|
record[k] = Array(
|
255
|
-
dtype=v.dtype, shape=
|
255
|
+
dtype=v.dtype, shape=tuple(v.shape), stype=v.stype, data=v.data
|
256
256
|
)
|
257
257
|
if not keep_input:
|
258
258
|
array_dict.clear()
|
flwr/common/recorddict_compat.py
CHANGED
@@ -111,12 +111,12 @@ def parameters_to_arrayrecord(parameters: Parameters, keep_input: bool) -> Array
|
|
111
111
|
else:
|
112
112
|
tensor = parameters.tensors.pop(0)
|
113
113
|
ordered_dict[str(idx)] = Array(
|
114
|
-
data=tensor, dtype="", stype=tensor_type, shape=
|
114
|
+
data=tensor, dtype="", stype=tensor_type, shape=()
|
115
115
|
)
|
116
116
|
|
117
117
|
if num_arrays == 0:
|
118
118
|
ordered_dict[EMPTY_TENSOR_KEY] = Array(
|
119
|
-
data=b"", dtype="", stype=tensor_type, shape=
|
119
|
+
data=b"", dtype="", stype=tensor_type, shape=()
|
120
120
|
)
|
121
121
|
return ArrayRecord(ordered_dict, keep_input=keep_input)
|
122
122
|
|
flwr/common/serde.py
CHANGED
@@ -390,7 +390,7 @@ def array_from_proto(array_proto: ProtoArray) -> Array:
|
|
390
390
|
"""Deserialize Array from ProtoBuf."""
|
391
391
|
return Array(
|
392
392
|
dtype=array_proto.dtype,
|
393
|
-
shape=
|
393
|
+
shape=tuple(array_proto.shape),
|
394
394
|
stype=array_proto.stype,
|
395
395
|
data=array_proto.data,
|
396
396
|
)
|
flwr/server/app.py
CHANGED
@@ -27,7 +27,7 @@ from collections.abc import Sequence
|
|
27
27
|
from logging import DEBUG, INFO, WARN
|
28
28
|
from pathlib import Path
|
29
29
|
from time import sleep
|
30
|
-
from typing import Any, Callable, Optional,
|
30
|
+
from typing import Any, Callable, Optional, TypeVar
|
31
31
|
|
32
32
|
import grpc
|
33
33
|
import yaml
|
@@ -85,6 +85,7 @@ from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
|
|
85
85
|
|
86
86
|
DATABASE = ":flwr-in-memory-state:"
|
87
87
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
88
|
+
P = TypeVar("P", ExecAuthPlugin, ExecAuthzPlugin)
|
88
89
|
|
89
90
|
|
90
91
|
try:
|
@@ -151,15 +152,13 @@ def run_superlink() -> None:
|
|
151
152
|
verify_tls_cert = not getattr(args, "disable_oidc_tls_cert_verification", None)
|
152
153
|
|
153
154
|
auth_plugin: Optional[ExecAuthPlugin] = None
|
154
|
-
authz_plugin: Optional[ExecAuthzPlugin] = None
|
155
|
+
authz_plugin: Optional[ExecAuthzPlugin] = None
|
155
156
|
event_log_plugin: Optional[EventLogWriterPlugin] = None
|
156
157
|
# Load the auth plugin if the args.user_auth_config is provided
|
157
158
|
if cfg_path := getattr(args, "user_auth_config", None):
|
158
|
-
|
159
|
-
auth_plugin, authz_plugin = _try_obtain_exec_auth_plugins( # noqa: F841
|
159
|
+
auth_plugin, authz_plugin = _try_obtain_exec_auth_plugins(
|
160
160
|
Path(cfg_path), verify_tls_cert
|
161
161
|
)
|
162
|
-
# pylint: enable=unused-variable
|
163
162
|
# Enable event logging if the args.enable_event_log is True
|
164
163
|
if args.enable_event_log:
|
165
164
|
event_log_plugin = _try_obtain_exec_event_log_writer_plugin()
|
@@ -185,6 +184,7 @@ def run_superlink() -> None:
|
|
185
184
|
[args.executor_config] if args.executor_config else args.executor_config
|
186
185
|
),
|
187
186
|
auth_plugin=auth_plugin,
|
187
|
+
authz_plugin=authz_plugin,
|
188
188
|
event_log_plugin=event_log_plugin,
|
189
189
|
)
|
190
190
|
grpc_servers = [exec_server]
|
@@ -490,15 +490,13 @@ def _try_obtain_exec_auth_plugins(
|
|
490
490
|
config: dict[str, Any] = yaml.safe_load(file)
|
491
491
|
|
492
492
|
def _load_plugin(
|
493
|
-
section: str,
|
494
|
-
|
495
|
-
loader: Callable[[], dict[str, type[Union[ExecAuthPlugin, ExecAuthzPlugin]]]],
|
496
|
-
) -> Union[ExecAuthPlugin, ExecAuthzPlugin]:
|
493
|
+
section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
|
494
|
+
) -> P:
|
497
495
|
section_cfg = config.get(section, {})
|
498
496
|
auth_plugin_name = section_cfg.get(yaml_key, "")
|
499
497
|
try:
|
500
|
-
plugins = loader()
|
501
|
-
plugin_cls = plugins[auth_plugin_name]
|
498
|
+
plugins: dict[str, type[P]] = loader()
|
499
|
+
plugin_cls: type[P] = plugins[auth_plugin_name]
|
502
500
|
return plugin_cls(
|
503
501
|
user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
|
504
502
|
)
|
@@ -513,23 +511,17 @@ def _try_obtain_exec_auth_plugins(
|
|
513
511
|
sys.exit(f"No {section} plugins are currently supported.")
|
514
512
|
|
515
513
|
# Load authentication plugin
|
516
|
-
auth_plugin =
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
yaml_key=AUTH_TYPE_YAML_KEY,
|
521
|
-
loader=get_exec_auth_plugins,
|
522
|
-
),
|
514
|
+
auth_plugin = _load_plugin(
|
515
|
+
section="authentication",
|
516
|
+
yaml_key=AUTH_TYPE_YAML_KEY,
|
517
|
+
loader=get_exec_auth_plugins,
|
523
518
|
)
|
524
519
|
|
525
520
|
# Load authorization plugin
|
526
|
-
authz_plugin =
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
yaml_key=AUTHZ_TYPE_YAML_KEY,
|
531
|
-
loader=get_exec_authz_plugins,
|
532
|
-
),
|
521
|
+
authz_plugin = _load_plugin(
|
522
|
+
section="authorization",
|
523
|
+
yaml_key=AUTHZ_TYPE_YAML_KEY,
|
524
|
+
loader=get_exec_authz_plugins,
|
533
525
|
)
|
534
526
|
|
535
527
|
return auth_plugin, authz_plugin
|
flwr/server/superlink/fleet/message_handler/message_handler.py
CHANGED
@@ -14,10 +14,10 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Fleet API message handlers."""
|
16
16
|
|
17
|
-
|
17
|
+
from logging import ERROR
|
18
18
|
from typing import Optional
|
19
19
|
|
20
|
-
from flwr.common import Message
|
20
|
+
from flwr.common import Message, log
|
21
21
|
from flwr.common.constant import Status
|
22
22
|
from flwr.common.serde import (
|
23
23
|
fab_to_proto,
|
@@ -42,6 +42,7 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
42
42
|
SendNodeHeartbeatRequest,
|
43
43
|
SendNodeHeartbeatResponse,
|
44
44
|
)
|
45
|
+
from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
|
45
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
46
47
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
47
48
|
GetRunRequest,
|
@@ -51,7 +52,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
51
52
|
from flwr.server.superlink.ffs.ffs import Ffs
|
52
53
|
from flwr.server.superlink.linkstate import LinkState
|
53
54
|
from flwr.server.superlink.utils import check_abort
|
54
|
-
from flwr.supercore.object_store import ObjectStore
|
55
|
+
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
|
55
56
|
|
56
57
|
from ...utils import store_mapping_and_register_objects
|
57
58
|
|
@@ -89,7 +90,9 @@ def send_node_heartbeat(
|
|
89
90
|
|
90
91
|
|
91
92
|
def pull_messages(
|
92
|
-
request: PullMessagesRequest,
|
93
|
+
request: PullMessagesRequest,
|
94
|
+
state: LinkState,
|
95
|
+
store: ObjectStore,
|
93
96
|
) -> PullMessagesResponse:
|
94
97
|
"""Pull Messages handler."""
|
95
98
|
# Get node_id if client node is not anonymous
|
@@ -101,10 +104,25 @@ def pull_messages(
|
|
101
104
|
|
102
105
|
# Convert to Messages
|
103
106
|
msg_proto = []
|
107
|
+
objects_to_pull: dict[str, ObjectIDs] = {}
|
104
108
|
for msg in message_list:
|
105
|
-
|
106
|
-
|
107
|
-
|
109
|
+
try:
|
110
|
+
msg_proto.append(message_to_proto(msg))
|
111
|
+
|
112
|
+
msg_object_id = msg.metadata.message_id
|
113
|
+
descendants = store.get_message_descendant_ids(msg_object_id)
|
114
|
+
# Include the object_id of the message itself
|
115
|
+
objects_to_pull[msg_object_id] = ObjectIDs(
|
116
|
+
object_ids=descendants + [msg_object_id]
|
117
|
+
)
|
118
|
+
except NoObjectInStoreError as e:
|
119
|
+
log(ERROR, e.message)
|
120
|
+
# Delete message ins from state
|
121
|
+
state.delete_messages(message_ins_ids={msg_object_id})
|
122
|
+
|
123
|
+
return PullMessagesResponse(
|
124
|
+
messages_list=msg_proto, objects_to_pull=objects_to_pull
|
125
|
+
)
|
108
126
|
|
109
127
|
|
110
128
|
def push_messages(
|
flwr/server/superlink/fleet/rest_rere/rest_api.py
CHANGED
@@ -114,9 +114,10 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
|
114
114
|
"""Pull PullMessages."""
|
115
115
|
# Get state from app
|
116
116
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
117
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.STATE_FACTORY).store()
|
117
118
|
|
118
119
|
# Handle message
|
119
|
-
return message_handler.pull_messages(request=request, state=state)
|
120
|
+
return message_handler.pull_messages(request=request, state=state, store=store)
|
120
121
|
|
121
122
|
|
122
123
|
@rest_request_response(PushMessagesRequest)
|
flwr/server/superlink/linkstate/utils.py
CHANGED
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
from os import urandom
|
19
19
|
from typing import Optional
|
20
|
-
from uuid import uuid4
|
21
20
|
|
22
21
|
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
23
22
|
from flwr.common.constant import (
|
@@ -246,7 +245,7 @@ def create_message_error_unavailable_res_message(
|
|
246
245
|
ttl = max(ins_metadata.ttl - (current_time - ins_metadata.created_at), 0)
|
247
246
|
metadata = Metadata(
|
248
247
|
run_id=ins_metadata.run_id,
|
249
|
-
message_id=
|
248
|
+
message_id="",
|
250
249
|
src_node_id=SUPERLINK_NODE_ID,
|
251
250
|
dst_node_id=SUPERLINK_NODE_ID,
|
252
251
|
reply_to_message_id=ins_metadata.message_id,
|
@@ -256,7 +255,7 @@ def create_message_error_unavailable_res_message(
|
|
256
255
|
ttl=ttl,
|
257
256
|
)
|
258
257
|
|
259
|
-
|
258
|
+
msg = make_message(
|
260
259
|
metadata=metadata,
|
261
260
|
error=Error(
|
262
261
|
code=(
|
@@ -271,6 +270,8 @@ def create_message_error_unavailable_res_message(
|
|
271
270
|
),
|
272
271
|
),
|
273
272
|
)
|
273
|
+
msg.metadata.__dict__["_message_id"] = msg.object_id
|
274
|
+
return msg
|
274
275
|
|
275
276
|
|
276
277
|
def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Message:
|
@@ -278,7 +279,7 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Me
|
|
278
279
|
that it isn't found."""
|
279
280
|
metadata = Metadata(
|
280
281
|
run_id=0, # Unknown
|
281
|
-
message_id=
|
282
|
+
message_id="",
|
282
283
|
src_node_id=SUPERLINK_NODE_ID,
|
283
284
|
dst_node_id=SUPERLINK_NODE_ID,
|
284
285
|
reply_to_message_id=reply_to_message_id,
|
@@ -288,13 +289,15 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Me
|
|
288
289
|
ttl=0,
|
289
290
|
)
|
290
291
|
|
291
|
-
|
292
|
+
msg = make_message(
|
292
293
|
metadata=metadata,
|
293
294
|
error=Error(
|
294
295
|
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
295
296
|
reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
|
296
297
|
),
|
297
298
|
)
|
299
|
+
msg.metadata.__dict__["_message_id"] = msg.object_id
|
300
|
+
return msg
|
298
301
|
|
299
302
|
|
300
303
|
def message_ttl_has_expired(message_metadata: Metadata, current_time: float) -> bool:
|
flwr/server/superlink/serverappio/serverappio_servicer.py
CHANGED
@@ -16,14 +16,14 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import threading
|
19
|
-
from logging import DEBUG, INFO
|
19
|
+
from logging import DEBUG, ERROR, INFO
|
20
20
|
from typing import Optional
|
21
21
|
|
22
22
|
import grpc
|
23
23
|
|
24
24
|
from flwr.common import Message
|
25
25
|
from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
26
|
-
from flwr.common.inflatable import check_body_len_consistency
|
26
|
+
from flwr.common.inflatable import check_body_len_consistency, get_desdendant_object_ids
|
27
27
|
from flwr.common.logger import log
|
28
28
|
from flwr.common.serde import (
|
29
29
|
context_from_proto,
|
@@ -47,6 +47,7 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
47
47
|
PushLogsResponse,
|
48
48
|
)
|
49
49
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
50
|
+
ObjectIDs,
|
50
51
|
PullObjectRequest,
|
51
52
|
PullObjectResponse,
|
52
53
|
PushObjectRequest,
|
@@ -78,7 +79,7 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
78
79
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
79
80
|
from flwr.server.superlink.utils import abort_if
|
80
81
|
from flwr.server.utils.validator import validate_message
|
81
|
-
from flwr.supercore.object_store import ObjectStoreFactory
|
82
|
+
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
|
82
83
|
|
83
84
|
from ..utils import store_mapping_and_register_objects
|
84
85
|
|
@@ -182,6 +183,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
182
183
|
# Init state
|
183
184
|
state: LinkState = self.state_factory.state()
|
184
185
|
|
186
|
+
# Init store
|
187
|
+
store = self.objectstore_factory.store()
|
188
|
+
|
185
189
|
# Abort if the run is not running
|
186
190
|
abort_if(
|
187
191
|
request.run_id,
|
@@ -195,6 +199,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
195
199
|
message_ids=set(request.message_ids)
|
196
200
|
)
|
197
201
|
|
202
|
+
# Register messages generated by LinkState in the Store for consistency
|
203
|
+
for msg_res in messages_res:
|
204
|
+
if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
|
205
|
+
descendants = list(get_desdendant_object_ids(msg_res))
|
206
|
+
message_obj_id = msg_res.metadata.message_id
|
207
|
+
# Store mapping
|
208
|
+
store.set_message_descendant_ids(
|
209
|
+
msg_object_id=message_obj_id, descendant_ids=descendants
|
210
|
+
)
|
211
|
+
# Preregister
|
212
|
+
store.preregister(descendants + [message_obj_id])
|
213
|
+
|
198
214
|
# Delete the instruction Messages and their replies if found
|
199
215
|
message_ins_ids_to_delete = {
|
200
216
|
msg_res.metadata.reply_to_message_id for msg_res in messages_res
|
@@ -204,6 +220,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
204
220
|
|
205
221
|
# Convert Messages to proto
|
206
222
|
messages_list = []
|
223
|
+
objects_to_pull: dict[str, ObjectIDs] = {}
|
207
224
|
while messages_res:
|
208
225
|
msg = messages_res.pop(0)
|
209
226
|
|
@@ -216,7 +233,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
216
233
|
)
|
217
234
|
messages_list.append(message_to_proto(msg))
|
218
235
|
|
219
|
-
|
236
|
+
try:
|
237
|
+
msg_object_id = msg.metadata.message_id
|
238
|
+
descendants = store.get_message_descendant_ids(msg_object_id)
|
239
|
+
# Include the object_id of the message itself
|
240
|
+
objects_to_pull[msg_object_id] = ObjectIDs(
|
241
|
+
object_ids=descendants + [msg_object_id]
|
242
|
+
)
|
243
|
+
except NoObjectInStoreError as e:
|
244
|
+
log(ERROR, e.message)
|
245
|
+
# Delete message ins from state
|
246
|
+
state.delete_messages(message_ins_ids={msg_object_id})
|
247
|
+
|
248
|
+
return PullResMessagesResponse(
|
249
|
+
messages_list=messages_list, objects_to_pull=objects_to_pull
|
250
|
+
)
|
220
251
|
|
221
252
|
def GetRun(
|
222
253
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
flwr/supercore/object_store/__init__.py
CHANGED
@@ -14,10 +14,11 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Flower ObjectStore."""
|
16
16
|
|
17
|
-
from .object_store import ObjectStore
|
17
|
+
from .object_store import NoObjectInStoreError, ObjectStore
|
18
18
|
from .object_store_factory import ObjectStoreFactory
|
19
19
|
|
20
20
|
__all__ = [
|
21
|
+
"NoObjectInStoreError",
|
21
22
|
"ObjectStore",
|
22
23
|
"ObjectStoreFactory",
|
23
24
|
]
|