flwr-nightly 1.19.0.dev20250526__py3-none-any.whl → 1.19.0.dev20250528__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +3 -3
- flwr/cli/run/run.py +2 -6
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +5 -4
- flwr/client/grpc_rere_client/connection.py +2 -0
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +18 -0
- flwr/common/constant.py +3 -0
- flwr/common/inflatable.py +33 -2
- flwr/common/message.py +5 -1
- flwr/common/record/array.py +38 -1
- flwr/common/record/arrayrecord.py +34 -0
- flwr/common/serde.py +6 -1
- flwr/compat/client/app.py +9 -151
- flwr/proto/fleet_pb2.py +25 -13
- flwr/proto/fleet_pb2.pyi +60 -3
- flwr/proto/message_pb2.py +22 -19
- flwr/proto/message_pb2.pyi +25 -2
- flwr/proto/serverappio_pb2.py +31 -19
- flwr/proto/serverappio_pb2.pyi +60 -3
- flwr/server/app.py +44 -1
- flwr/server/grid/grpc_grid.py +2 -1
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -2
- flwr/server/superlink/fleet/vce/vce_api.py +3 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -25
- flwr/server/superlink/linkstate/linkstate.py +9 -10
- flwr/server/superlink/linkstate/sqlite_linkstate.py +11 -21
- flwr/server/superlink/linkstate/utils.py +23 -23
- flwr/server/superlink/serverappio/serverappio_servicer.py +6 -10
- flwr/server/utils/validator.py +2 -2
- flwr/supercore/object_store/in_memory_object_store.py +30 -4
- flwr/supercore/object_store/object_store.py +48 -1
- flwr/superexec/exec_servicer.py +1 -2
- {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/RECORD +41 -41
- {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/entry_points.txt +0 -0
flwr/cli/log.py
CHANGED
@@ -35,7 +35,7 @@ from flwr.common.logger import log as logger
|
|
35
35
|
from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
|
36
36
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
37
37
|
|
38
|
-
from .utils import init_channel, try_obtain_cli_auth_plugin
|
38
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
39
39
|
|
40
40
|
|
41
41
|
class AllLogsRetrieved(BaseException):
|
@@ -95,7 +95,7 @@ def stream_logs(
|
|
95
95
|
latest_timestamp = 0.0
|
96
96
|
res = None
|
97
97
|
try:
|
98
|
-
with
|
98
|
+
with flwr_cli_grpc_exc_handler():
|
99
99
|
for res in stub.StreamLogs(req, timeout=duration):
|
100
100
|
print(res.log_output, end="")
|
101
101
|
raise AllLogsRetrieved()
|
@@ -116,7 +116,7 @@ def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
|
|
116
116
|
req = StreamLogsRequest(run_id=run_id, after_timestamp=0.0)
|
117
117
|
|
118
118
|
try:
|
119
|
-
with
|
119
|
+
with flwr_cli_grpc_exc_handler():
|
120
120
|
# Enforce timeout for graceful exit
|
121
121
|
for res in stub.StreamLogs(req, timeout=timeout):
|
122
122
|
print(res.log_output)
|
flwr/cli/login/login.py
CHANGED
@@ -35,11 +35,7 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
|
|
35
35
|
)
|
36
36
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
37
37
|
|
38
|
-
from ..utils import
|
39
|
-
init_channel,
|
40
|
-
try_obtain_cli_auth_plugin,
|
41
|
-
unauthenticated_exc_handler,
|
42
|
-
)
|
38
|
+
from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
43
39
|
|
44
40
|
|
45
41
|
def login( # pylint: disable=R0914
|
@@ -96,7 +92,7 @@ def login( # pylint: disable=R0914
|
|
96
92
|
stub = ExecStub(channel)
|
97
93
|
|
98
94
|
login_request = GetLoginDetailsRequest()
|
99
|
-
with
|
95
|
+
with flwr_cli_grpc_exc_handler():
|
100
96
|
login_response: GetLoginDetailsResponse = stub.GetLoginDetails(login_request)
|
101
97
|
|
102
98
|
# Get the auth plugin
|
@@ -120,7 +116,7 @@ def login( # pylint: disable=R0914
|
|
120
116
|
expires_in=login_response.expires_in,
|
121
117
|
interval=login_response.interval,
|
122
118
|
)
|
123
|
-
with
|
119
|
+
with flwr_cli_grpc_exc_handler():
|
124
120
|
credentials = auth_plugin.login(details, stub)
|
125
121
|
|
126
122
|
# Store the tokens
|
flwr/cli/ls.py
CHANGED
@@ -44,7 +44,7 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
|
|
44
44
|
)
|
45
45
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
46
46
|
|
47
|
-
from .utils import init_channel, try_obtain_cli_auth_plugin
|
47
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
48
48
|
|
49
49
|
_RunListType = tuple[int, str, str, str, str, str, str, str, str]
|
50
50
|
|
@@ -305,7 +305,7 @@ def _list_runs(
|
|
305
305
|
output_format: str = CliOutputFormat.DEFAULT,
|
306
306
|
) -> None:
|
307
307
|
"""List all runs."""
|
308
|
-
with
|
308
|
+
with flwr_cli_grpc_exc_handler():
|
309
309
|
res: ListRunsResponse = stub.ListRuns(ListRunsRequest())
|
310
310
|
run_dict = {run_id: run_from_proto(proto) for run_id, proto in res.run_dict.items()}
|
311
311
|
|
@@ -322,7 +322,7 @@ def _display_one_run(
|
|
322
322
|
output_format: str = CliOutputFormat.DEFAULT,
|
323
323
|
) -> None:
|
324
324
|
"""Display information about a specific run."""
|
325
|
-
with
|
325
|
+
with flwr_cli_grpc_exc_handler():
|
326
326
|
res: ListRunsResponse = stub.ListRuns(ListRunsRequest(run_id=run_id))
|
327
327
|
if not res.run_dict:
|
328
328
|
raise ValueError(f"Run ID {run_id} not found")
|
flwr/cli/run/run.py
CHANGED
@@ -45,11 +45,7 @@ from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
|
|
45
45
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
46
46
|
|
47
47
|
from ..log import start_stream
|
48
|
-
from ..utils import
|
49
|
-
init_channel,
|
50
|
-
try_obtain_cli_auth_plugin,
|
51
|
-
unauthenticated_exc_handler,
|
52
|
-
)
|
48
|
+
from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
53
49
|
|
54
50
|
CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
|
55
51
|
|
@@ -172,7 +168,7 @@ def _run_with_exec_api(
|
|
172
168
|
override_config=user_config_to_proto(parse_config_args(config_overrides)),
|
173
169
|
federation_options=config_record_to_proto(c_record),
|
174
170
|
)
|
175
|
-
with
|
171
|
+
with flwr_cli_grpc_exc_handler():
|
176
172
|
res = stub.StartRun(req)
|
177
173
|
|
178
174
|
if res.HasField("run_id"):
|
flwr/cli/stop.py
CHANGED
@@ -35,7 +35,7 @@ from flwr.common.logger import print_json_error, redirect_output, restore_output
|
|
35
35
|
from flwr.proto.exec_pb2 import StopRunRequest, StopRunResponse # pylint: disable=E0611
|
36
36
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
37
37
|
|
38
|
-
from .utils import init_channel, try_obtain_cli_auth_plugin
|
38
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
39
39
|
|
40
40
|
|
41
41
|
def stop( # pylint: disable=R0914
|
@@ -122,7 +122,7 @@ def stop( # pylint: disable=R0914
|
|
122
122
|
|
123
123
|
def _stop_run(stub: ExecStub, run_id: int, output_format: str) -> None:
|
124
124
|
"""Stop a run."""
|
125
|
-
with
|
125
|
+
with flwr_cli_grpc_exc_handler():
|
126
126
|
response: StopRunResponse = stub.StopRun(request=StopRunRequest(run_id=run_id))
|
127
127
|
if response.success:
|
128
128
|
typer.secho(f"✅ Run {run_id} successfully stopped.", fg=typer.colors.GREEN)
|
flwr/cli/utils.py
CHANGED
@@ -288,11 +288,12 @@ def init_channel(
|
|
288
288
|
|
289
289
|
|
290
290
|
@contextmanager
|
291
|
-
def
|
292
|
-
"""Context manager to handle gRPC
|
291
|
+
def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
292
|
+
"""Context manager to handle specific gRPC errors.
|
293
293
|
|
294
|
-
It catches grpc.RpcError exceptions with UNAUTHENTICATED
|
295
|
-
and exits the application. All other exceptions will be allowed to
|
294
|
+
It catches grpc.RpcError exceptions with UNAUTHENTICATED and UNIMPLEMENTED statuses,
|
295
|
+
informs the user, and exits the application. All other exceptions will be allowed to
|
296
|
+
escape.
|
296
297
|
"""
|
297
298
|
try:
|
298
299
|
yield
|
@@ -279,6 +279,8 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
279
279
|
log(ERROR, "No current message")
|
280
280
|
return
|
281
281
|
|
282
|
+
# Set message_id
|
283
|
+
message.metadata.__dict__["_message_id"] = message.object_id
|
282
284
|
# Validate out message
|
283
285
|
if not validate_out_message(message, metadata):
|
284
286
|
log(ERROR, "Invalid out message")
|
@@ -164,7 +164,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
|
|
164
164
|
in_meta = in_message_metadata
|
165
165
|
if ( # pylint: disable-next=too-many-boolean-expressions
|
166
166
|
out_meta.run_id == in_meta.run_id
|
167
|
-
and out_meta.message_id ==
|
167
|
+
and out_meta.message_id == out_message.object_id # Should match the object id
|
168
168
|
and out_meta.src_node_id == in_meta.dst_node_id
|
169
169
|
and out_meta.dst_node_id == in_meta.src_node_id
|
170
170
|
and out_meta.reply_to_message_id == in_meta.message_id
|
@@ -17,8 +17,10 @@
|
|
17
17
|
|
18
18
|
from .auth_plugin import CliAuthPlugin as CliAuthPlugin
|
19
19
|
from .auth_plugin import ExecAuthPlugin as ExecAuthPlugin
|
20
|
+
from .auth_plugin import ExecAuthzPlugin as ExecAuthzPlugin
|
20
21
|
|
21
22
|
__all__ = [
|
22
23
|
"CliAuthPlugin",
|
23
24
|
"ExecAuthPlugin",
|
25
|
+
"ExecAuthzPlugin",
|
24
26
|
]
|
@@ -64,6 +64,24 @@ class ExecAuthPlugin(ABC):
|
|
64
64
|
"""Refresh authentication tokens in the provided metadata."""
|
65
65
|
|
66
66
|
|
67
|
+
class ExecAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
|
68
|
+
"""Abstract Flower Authorization Plugin class for ExecServicer.
|
69
|
+
|
70
|
+
Parameters
|
71
|
+
----------
|
72
|
+
user_authz_config_path : Path
|
73
|
+
Path to the YAML file containing the authorization configuration.
|
74
|
+
"""
|
75
|
+
|
76
|
+
@abstractmethod
|
77
|
+
def __init__(self, user_authz_config_path: Path, verify_tls_cert: bool):
|
78
|
+
"""Abstract constructor."""
|
79
|
+
|
80
|
+
@abstractmethod
|
81
|
+
def verify_user_authorization(self, user_info: UserInfo) -> bool:
|
82
|
+
"""Verify user authorization request."""
|
83
|
+
|
84
|
+
|
67
85
|
class CliAuthPlugin(ABC):
|
68
86
|
"""Abstract Flower Auth Plugin class for CLI.
|
69
87
|
|
flwr/common/constant.py
CHANGED
@@ -115,6 +115,9 @@ AUTH_TYPE_YAML_KEY = "auth_type" # For key name in YAML file
|
|
115
115
|
ACCESS_TOKEN_KEY = "flwr-oidc-access-token"
|
116
116
|
REFRESH_TOKEN_KEY = "flwr-oidc-refresh-token"
|
117
117
|
|
118
|
+
# Constants for user authorization
|
119
|
+
AUTHZ_TYPE_YAML_KEY = "authz_type" # For key name in YAML file
|
120
|
+
|
118
121
|
# Constants for node authentication
|
119
122
|
PUBLIC_KEY_HEADER = "flwr-public-key-bin" # Must end with "-bin" for binary data
|
120
123
|
SIGNATURE_HEADER = "flwr-signature-bin" # Must end with "-bin" for binary data
|
flwr/common/inflatable.py
CHANGED
@@ -18,7 +18,7 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import hashlib
|
21
|
-
from typing import TypeVar
|
21
|
+
from typing import TypeVar, cast
|
22
22
|
|
23
23
|
from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
|
24
24
|
|
@@ -55,13 +55,24 @@ class InflatableObject:
|
|
55
55
|
@property
|
56
56
|
def object_id(self) -> str:
|
57
57
|
"""Get object_id."""
|
58
|
-
|
58
|
+
if self.is_dirty or "_object_id" not in self.__dict__:
|
59
|
+
self.__dict__["_object_id"] = get_object_id(self.deflate())
|
60
|
+
return cast(str, self.__dict__["_object_id"])
|
59
61
|
|
60
62
|
@property
|
61
63
|
def children(self) -> dict[str, InflatableObject] | None:
|
62
64
|
"""Get all child objects as a dictionary or None if there are no children."""
|
63
65
|
return None
|
64
66
|
|
67
|
+
@property
|
68
|
+
def is_dirty(self) -> bool:
|
69
|
+
"""Check if the object is dirty after the last deflation.
|
70
|
+
|
71
|
+
An object is considered dirty if its content has changed since the last its
|
72
|
+
object ID was computed.
|
73
|
+
"""
|
74
|
+
return True
|
75
|
+
|
65
76
|
|
66
77
|
T = TypeVar("T", bound=InflatableObject)
|
67
78
|
|
@@ -178,3 +189,23 @@ def get_object_head_values_from_object_content(
|
|
178
189
|
obj_type, children_str, body_len = head.split(HEAD_VALUE_DIVIDER)
|
179
190
|
children_ids = children_str.split(",") if children_str else []
|
180
191
|
return obj_type, children_ids, int(body_len)
|
192
|
+
|
193
|
+
|
194
|
+
def _get_descendants_object_ids_recursively(obj: InflatableObject) -> set[str]:
|
195
|
+
|
196
|
+
descendants: set[str] = set()
|
197
|
+
if children := obj.children:
|
198
|
+
for child in children.values():
|
199
|
+
descendants |= _get_descendants_object_ids_recursively(child)
|
200
|
+
|
201
|
+
descendants.add(obj.object_id)
|
202
|
+
|
203
|
+
return descendants
|
204
|
+
|
205
|
+
|
206
|
+
def get_desdendant_object_ids(obj: InflatableObject) -> set[str]:
|
207
|
+
"""Get a set of object IDs of all descendants."""
|
208
|
+
descendants = _get_descendants_object_ids_recursively(obj)
|
209
|
+
# Exclude Object ID of parent object
|
210
|
+
descendants.discard(obj.object_id)
|
211
|
+
return descendants
|
flwr/common/message.py
CHANGED
@@ -23,6 +23,7 @@ from typing import Any, cast, overload
|
|
23
23
|
from flwr.common.date import now
|
24
24
|
from flwr.common.logger import warn_deprecated_feature
|
25
25
|
from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
|
26
|
+
from flwr.proto.message_pb2 import Metadata as ProtoMetadata # pylint: disable=E0611
|
26
27
|
|
27
28
|
from ..app.error import Error
|
28
29
|
from ..app.metadata import Metadata
|
@@ -351,9 +352,12 @@ class Message(InflatableObject):
|
|
351
352
|
|
352
353
|
def deflate(self) -> bytes:
|
353
354
|
"""Deflate message."""
|
355
|
+
# Exclude message_id from serialization
|
356
|
+
proto_metadata: ProtoMetadata = metadata_to_proto(self.metadata)
|
357
|
+
proto_metadata.message_id = ""
|
354
358
|
# Store message metadata and error in object body
|
355
359
|
obj_body = ProtoMessage(
|
356
|
-
metadata=
|
360
|
+
metadata=proto_metadata,
|
357
361
|
content=None,
|
358
362
|
error=error_to_proto(self.error) if self.has_error() else None,
|
359
363
|
).SerializeToString(deterministic=True)
|
flwr/common/record/array.py
CHANGED
@@ -107,10 +107,21 @@ class Array(InflatableObject):
|
|
107
107
|
"""
|
108
108
|
|
109
109
|
dtype: str
|
110
|
-
shape: list[int]
|
111
110
|
stype: str
|
112
111
|
data: bytes
|
113
112
|
|
113
|
+
@property
|
114
|
+
def shape(self) -> list[int]:
|
115
|
+
"""Get the shape of the array."""
|
116
|
+
self.is_dirty = True # Mark as dirty when shape is accessed
|
117
|
+
return cast(list[int], self.__dict__["_shape"])
|
118
|
+
|
119
|
+
@shape.setter
|
120
|
+
def shape(self, value: list[int]) -> None:
|
121
|
+
"""Set the shape of the array."""
|
122
|
+
self.is_dirty = True # Mark as dirty when shape is set
|
123
|
+
self.__dict__["_shape"] = value
|
124
|
+
|
114
125
|
@overload
|
115
126
|
def __init__( # noqa: E704
|
116
127
|
self, dtype: str, shape: list[int], stype: str, data: bytes
|
@@ -295,3 +306,29 @@ class Array(InflatableObject):
|
|
295
306
|
stype=proto_array.stype,
|
296
307
|
data=proto_array.data,
|
297
308
|
)
|
309
|
+
|
310
|
+
@property
|
311
|
+
def object_id(self) -> str:
|
312
|
+
"""Get object ID."""
|
313
|
+
ret = super().object_id
|
314
|
+
self.is_dirty = False # Reset dirty flag
|
315
|
+
return ret
|
316
|
+
|
317
|
+
@property
|
318
|
+
def is_dirty(self) -> bool:
|
319
|
+
"""Check if the object is dirty after the last deflation."""
|
320
|
+
if "_is_dirty" not in self.__dict__:
|
321
|
+
self.__dict__["_is_dirty"] = True
|
322
|
+
return cast(bool, self.__dict__["_is_dirty"])
|
323
|
+
|
324
|
+
@is_dirty.setter
|
325
|
+
def is_dirty(self, value: bool) -> None:
|
326
|
+
"""Set the dirty flag."""
|
327
|
+
self.__dict__["_is_dirty"] = value
|
328
|
+
|
329
|
+
def __setattr__(self, name: str, value: Any) -> None:
|
330
|
+
"""Set attribute with special handling for dirty state."""
|
331
|
+
if name in ("dtype", "stype", "data"):
|
332
|
+
# Mark as dirty if any of the main attributes are set
|
333
|
+
self.is_dirty = True
|
334
|
+
super().__setattr__(name, value)
|
@@ -429,6 +429,40 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
429
429
|
)
|
430
430
|
)
|
431
431
|
|
432
|
+
@property
|
433
|
+
def object_id(self) -> str:
|
434
|
+
"""Get object ID."""
|
435
|
+
ret = super().object_id
|
436
|
+
self.is_dirty = False # Reset dirty flag
|
437
|
+
return ret
|
438
|
+
|
439
|
+
@property
|
440
|
+
def is_dirty(self) -> bool:
|
441
|
+
"""Check if the object is dirty after the last deflation."""
|
442
|
+
if "_is_dirty" not in self.__dict__:
|
443
|
+
self.__dict__["_is_dirty"] = True
|
444
|
+
|
445
|
+
if not self.__dict__["_is_dirty"]:
|
446
|
+
if any(v.is_dirty for v in self.values()):
|
447
|
+
# If any Array is dirty, mark the record as dirty
|
448
|
+
self.__dict__["_is_dirty"] = True
|
449
|
+
return cast(bool, self.__dict__["_is_dirty"])
|
450
|
+
|
451
|
+
@is_dirty.setter
|
452
|
+
def is_dirty(self, value: bool) -> None:
|
453
|
+
"""Set the dirty flag."""
|
454
|
+
self.__dict__["_is_dirty"] = value
|
455
|
+
|
456
|
+
def __setitem__(self, key: str, value: Array) -> None:
|
457
|
+
"""Set item and mark the record as dirty."""
|
458
|
+
self.is_dirty = True # Mark as dirty when setting an item
|
459
|
+
super().__setitem__(key, value)
|
460
|
+
|
461
|
+
def __delitem__(self, key: str) -> None:
|
462
|
+
"""Delete item and mark the record as dirty."""
|
463
|
+
self.is_dirty = True # Mark as dirty when deleting an item
|
464
|
+
super().__delitem__(key)
|
465
|
+
|
432
466
|
|
433
467
|
class ParametersRecord(ArrayRecord):
|
434
468
|
"""Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.
|
flwr/common/serde.py
CHANGED
@@ -378,7 +378,12 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
|
|
378
378
|
|
379
379
|
def array_to_proto(array: Array) -> ProtoArray:
|
380
380
|
"""Serialize Array to ProtoBuf."""
|
381
|
-
return ProtoArray(
|
381
|
+
return ProtoArray(
|
382
|
+
dtype=array.dtype,
|
383
|
+
shape=array.shape,
|
384
|
+
stype=array.stype,
|
385
|
+
data=array.data,
|
386
|
+
)
|
382
387
|
|
383
388
|
|
384
389
|
def array_from_proto(array_proto: ProtoArray) -> Array:
|
flwr/compat/client/app.py
CHANGED
@@ -15,18 +15,12 @@
|
|
15
15
|
"""Flower client app."""
|
16
16
|
|
17
17
|
|
18
|
-
import multiprocessing
|
19
|
-
import os
|
20
|
-
import sys
|
21
|
-
import threading
|
22
18
|
import time
|
23
19
|
from contextlib import AbstractContextManager
|
24
20
|
from logging import ERROR, INFO, WARN
|
25
|
-
from os import urandom
|
26
21
|
from pathlib import Path
|
27
|
-
from typing import Callable, Optional, Union
|
22
|
+
from typing import Callable, Optional, Union
|
28
23
|
|
29
|
-
import grpc
|
30
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
31
25
|
from grpc import RpcError
|
32
26
|
|
@@ -35,11 +29,6 @@ from flwr.cli.config_utils import get_fab_metadata
|
|
35
29
|
from flwr.cli.install import install_from_fab
|
36
30
|
from flwr.client.client import Client
|
37
31
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
38
|
-
from flwr.client.clientapp.app import flwr_clientapp
|
39
|
-
from flwr.client.clientapp.clientappio_servicer import (
|
40
|
-
ClientAppInputs,
|
41
|
-
ClientAppIoServicer,
|
42
|
-
)
|
43
32
|
from flwr.client.grpc_adapter_client.connection import grpc_adapter
|
44
33
|
from flwr.client.grpc_rere_client.connection import grpc_request_response
|
45
34
|
from flwr.client.message_handler.message_handler import handle_control_message
|
@@ -49,13 +38,7 @@ from flwr.client.typing import ClientFnExt
|
|
49
38
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
|
50
39
|
from flwr.common.address import parse_address
|
51
40
|
from flwr.common.constant import (
|
52
|
-
CLIENT_OCTET,
|
53
|
-
CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS,
|
54
|
-
ISOLATION_MODE_PROCESS,
|
55
|
-
ISOLATION_MODE_SUBPROCESS,
|
56
41
|
MAX_RETRY_DELAY,
|
57
|
-
RUN_ID_NUM_BYTES,
|
58
|
-
SERVER_OCTET,
|
59
42
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
60
43
|
TRANSPORT_TYPE_GRPC_BIDI,
|
61
44
|
TRANSPORT_TYPE_GRPC_RERE,
|
@@ -64,12 +47,10 @@ from flwr.common.constant import (
|
|
64
47
|
ErrorCode,
|
65
48
|
)
|
66
49
|
from flwr.common.exit import ExitCode, flwr_exit
|
67
|
-
from flwr.common.grpc import generic_create_grpc_server
|
68
50
|
from flwr.common.logger import log, warn_deprecated_feature
|
69
51
|
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
70
52
|
from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
|
71
53
|
from flwr.compat.client.grpc_client.connection import grpc_connection
|
72
|
-
from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
|
73
54
|
from flwr.supernode.nodestate import NodeStateFactory
|
74
55
|
|
75
56
|
|
@@ -238,8 +219,6 @@ def start_client_internal(
|
|
238
219
|
max_retries: Optional[int] = None,
|
239
220
|
max_wait_time: Optional[float] = None,
|
240
221
|
flwr_path: Optional[Path] = None,
|
241
|
-
isolation: Optional[str] = None,
|
242
|
-
clientappio_api_address: Optional[str] = CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS,
|
243
222
|
) -> None:
|
244
223
|
"""Start a Flower client node which connects to a Flower server.
|
245
224
|
|
@@ -292,17 +271,6 @@ def start_client_internal(
|
|
292
271
|
If set to None, there is no limit to the total time.
|
293
272
|
flwr_path: Optional[Path] (default: None)
|
294
273
|
The fully resolved path containing installed Flower Apps.
|
295
|
-
isolation : Optional[str] (default: None)
|
296
|
-
Isolation mode for `ClientApp`. Possible values are `subprocess` and
|
297
|
-
`process`. Defaults to `None`, which runs the `ClientApp` in the same process
|
298
|
-
as the SuperNode. If `subprocess`, the `ClientApp` runs in a subprocess started
|
299
|
-
by the SueprNode and communicates using gRPC at the address
|
300
|
-
`clientappio_api_address`. If `process`, the `ClientApp` runs in a separate
|
301
|
-
isolated process and communicates using gRPC at the address
|
302
|
-
`clientappio_api_address`.
|
303
|
-
clientappio_api_address : Optional[str]
|
304
|
-
(default: `CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS`)
|
305
|
-
The SuperNode gRPC server address.
|
306
274
|
"""
|
307
275
|
if insecure is None:
|
308
276
|
insecure = root_certificates is None
|
@@ -328,18 +296,6 @@ def start_client_internal(
|
|
328
296
|
|
329
297
|
load_client_app_fn = _load_client_app
|
330
298
|
|
331
|
-
if isolation:
|
332
|
-
if clientappio_api_address is None:
|
333
|
-
raise ValueError(
|
334
|
-
f"`clientappio_api_address` required when `isolation` is "
|
335
|
-
f"{ISOLATION_MODE_SUBPROCESS} or {ISOLATION_MODE_PROCESS}",
|
336
|
-
)
|
337
|
-
_clientappio_grpc_server, clientappio_servicer = run_clientappio_api_grpc(
|
338
|
-
address=clientappio_api_address,
|
339
|
-
certificates=None,
|
340
|
-
)
|
341
|
-
clientappio_api_address = cast(str, clientappio_api_address)
|
342
|
-
|
343
299
|
# At this point, only `load_client_app_fn` should be used
|
344
300
|
# Both `client` and `client_fn` must not be used directly
|
345
301
|
|
@@ -390,7 +346,6 @@ def start_client_internal(
|
|
390
346
|
run_info_store: Optional[DeprecatedRunInfoStore] = None
|
391
347
|
state_factory = NodeStateFactory()
|
392
348
|
state = state_factory.state()
|
393
|
-
mp_spawn_context = multiprocessing.get_context("spawn")
|
394
349
|
|
395
350
|
runs: dict[int, Run] = {}
|
396
351
|
|
@@ -475,9 +430,8 @@ def start_client_internal(
|
|
475
430
|
run: Run = runs[run_id]
|
476
431
|
if get_fab is not None and run.fab_hash:
|
477
432
|
fab = get_fab(run.fab_hash, run_id)
|
478
|
-
|
479
|
-
|
480
|
-
install_from_fab(fab.content, flwr_path, True)
|
433
|
+
# If `ClientApp` runs in the same process, install the FAB
|
434
|
+
install_from_fab(fab.content, flwr_path, True)
|
481
435
|
fab_id, fab_version = get_fab_metadata(fab.content)
|
482
436
|
else:
|
483
437
|
fab = None
|
@@ -504,73 +458,13 @@ def start_client_internal(
|
|
504
458
|
|
505
459
|
# Handle app loading and task message
|
506
460
|
try:
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
# 2. `process`: ClientApp process gets started separately
|
512
|
-
# (via `flwr-clientapp`), for example, in a separate
|
513
|
-
# Docker container.
|
514
|
-
|
515
|
-
# Generate SuperNode token
|
516
|
-
token = int.from_bytes(urandom(RUN_ID_NUM_BYTES), "little")
|
517
|
-
|
518
|
-
# Mode 1: SuperNode starts ClientApp as subprocess
|
519
|
-
start_subprocess = isolation == ISOLATION_MODE_SUBPROCESS
|
520
|
-
|
521
|
-
# Share Message and Context with servicer
|
522
|
-
clientappio_servicer.set_inputs(
|
523
|
-
clientapp_input=ClientAppInputs(
|
524
|
-
message=message,
|
525
|
-
context=context,
|
526
|
-
run=run,
|
527
|
-
fab=fab,
|
528
|
-
token=token,
|
529
|
-
),
|
530
|
-
token_returned=start_subprocess,
|
531
|
-
)
|
532
|
-
|
533
|
-
if start_subprocess:
|
534
|
-
_octet, _colon, _port = (
|
535
|
-
clientappio_api_address.rpartition(":")
|
536
|
-
)
|
537
|
-
io_address = (
|
538
|
-
f"{CLIENT_OCTET}:{_port}"
|
539
|
-
if _octet == SERVER_OCTET
|
540
|
-
else clientappio_api_address
|
541
|
-
)
|
542
|
-
# Start ClientApp subprocess
|
543
|
-
command = [
|
544
|
-
"flwr-clientapp",
|
545
|
-
"--clientappio-api-address",
|
546
|
-
io_address,
|
547
|
-
"--token",
|
548
|
-
str(token),
|
549
|
-
]
|
550
|
-
command.append("--insecure")
|
551
|
-
|
552
|
-
proc = mp_spawn_context.Process(
|
553
|
-
target=_run_flwr_clientapp,
|
554
|
-
args=(command, os.getpid()),
|
555
|
-
daemon=True,
|
556
|
-
)
|
557
|
-
proc.start()
|
558
|
-
proc.join()
|
559
|
-
else:
|
560
|
-
# Wait for output to become available
|
561
|
-
while not clientappio_servicer.has_outputs():
|
562
|
-
time.sleep(0.1)
|
563
|
-
|
564
|
-
outputs = clientappio_servicer.get_outputs()
|
565
|
-
reply_message, context = outputs.message, outputs.context
|
566
|
-
else:
|
567
|
-
# Load ClientApp instance
|
568
|
-
client_app: ClientApp = load_client_app_fn(
|
569
|
-
fab_id, fab_version, run.fab_hash
|
570
|
-
)
|
461
|
+
# Load ClientApp instance
|
462
|
+
client_app: ClientApp = load_client_app_fn(
|
463
|
+
fab_id, fab_version, run.fab_hash
|
464
|
+
)
|
571
465
|
|
572
|
-
|
573
|
-
|
466
|
+
# Execute ClientApp
|
467
|
+
reply_message = client_app(message=message, context=context)
|
574
468
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
575
469
|
|
576
470
|
# Legacy grpc-bidi
|
@@ -801,39 +695,3 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
|
|
801
695
|
)
|
802
696
|
|
803
697
|
return connection, address, error_type
|
804
|
-
|
805
|
-
|
806
|
-
def _run_flwr_clientapp(args: list[str], main_pid: int) -> None:
|
807
|
-
# Monitor the main process in case of SIGKILL
|
808
|
-
def main_process_monitor() -> None:
|
809
|
-
while True:
|
810
|
-
time.sleep(1)
|
811
|
-
if os.getppid() != main_pid:
|
812
|
-
os.kill(os.getpid(), 9)
|
813
|
-
|
814
|
-
threading.Thread(target=main_process_monitor, daemon=True).start()
|
815
|
-
|
816
|
-
# Run the command
|
817
|
-
sys.argv = args
|
818
|
-
flwr_clientapp()
|
819
|
-
|
820
|
-
|
821
|
-
def run_clientappio_api_grpc(
|
822
|
-
address: str,
|
823
|
-
certificates: Optional[tuple[bytes, bytes, bytes]],
|
824
|
-
) -> tuple[grpc.Server, ClientAppIoServicer]:
|
825
|
-
"""Run ClientAppIo API gRPC server."""
|
826
|
-
clientappio_servicer: grpc.Server = ClientAppIoServicer()
|
827
|
-
clientappio_add_servicer_to_server_fn = add_ClientAppIoServicer_to_server
|
828
|
-
clientappio_grpc_server = generic_create_grpc_server(
|
829
|
-
servicer_and_add_fn=(
|
830
|
-
clientappio_servicer,
|
831
|
-
clientappio_add_servicer_to_server_fn,
|
832
|
-
),
|
833
|
-
server_address=address,
|
834
|
-
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
835
|
-
certificates=certificates,
|
836
|
-
)
|
837
|
-
log(INFO, "Starting Flower ClientAppIo gRPC server on %s", address)
|
838
|
-
clientappio_grpc_server.start()
|
839
|
-
return clientappio_grpc_server, clientappio_servicer
|