flwr-nightly 1.19.0.dev20250610__py3-none-any.whl → 1.19.0.dev20250612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/grpc_rere_client/connection.py +48 -29
- flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
- flwr/client/rest_client/connection.py +138 -27
- flwr/common/auth_plugin/auth_plugin.py +6 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/inflatable.py +70 -1
- flwr/common/inflatable_grpc_utils.py +1 -1
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/serde.py +2 -0
- flwr/common/typing.py +5 -3
- flwr/proto/fleet_pb2.py +12 -16
- flwr/proto/fleet_pb2.pyi +4 -19
- flwr/proto/fleet_pb2_grpc.py +34 -0
- flwr/proto/fleet_pb2_grpc.pyi +13 -0
- flwr/proto/message_pb2.py +15 -9
- flwr/proto/message_pb2.pyi +41 -0
- flwr/proto/run_pb2.py +24 -24
- flwr/proto/run_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +22 -26
- flwr/proto/serverappio_pb2.pyi +4 -19
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +20 -9
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +56 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +17 -2
- flwr/server/superlink/linkstate/linkstate.py +6 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
- flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
- flwr/server/superlink/utils.py +23 -10
- flwr/supercore/object_store/in_memory_object_store.py +160 -33
- flwr/supercore/object_store/object_store.py +54 -7
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/exec_user_auth_interceptor.py +11 -11
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/RECORD +45 -44
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/entry_points.txt +0 -0
@@ -14,6 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Contextmanager for a gRPC request-response channel to the Flower server."""
|
16
16
|
|
17
|
+
|
17
18
|
from collections.abc import Iterator, Sequence
|
18
19
|
from contextlib import contextmanager
|
19
20
|
from copy import copy
|
@@ -30,7 +31,11 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
30
31
|
from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
|
31
32
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
32
33
|
from flwr.common.heartbeat import HeartbeatSender
|
33
|
-
from flwr.common.inflatable import
|
34
|
+
from flwr.common.inflatable import (
|
35
|
+
get_all_nested_objects,
|
36
|
+
get_object_tree,
|
37
|
+
no_object_id_recompute,
|
38
|
+
)
|
34
39
|
from flwr.common.inflatable_grpc_utils import (
|
35
40
|
make_pull_object_fn_grpc,
|
36
41
|
make_push_object_fn_grpc,
|
@@ -62,7 +67,9 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
62
67
|
SendNodeHeartbeatRequest,
|
63
68
|
SendNodeHeartbeatResponse,
|
64
69
|
)
|
65
|
-
from flwr.proto.message_pb2 import
|
70
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
71
|
+
ConfirmMessageReceivedRequest,
|
72
|
+
)
|
66
73
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
67
74
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
68
75
|
|
@@ -269,14 +276,23 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
269
276
|
|
270
277
|
if message_proto:
|
271
278
|
msg_id = message_proto.metadata.message_id
|
279
|
+
run_id = message_proto.metadata.run_id
|
272
280
|
all_object_contents = pull_objects(
|
273
281
|
list(response.objects_to_pull[msg_id].object_ids) + [msg_id],
|
274
282
|
pull_object_fn=make_pull_object_fn_grpc(
|
275
283
|
pull_object_grpc=stub.PullObject,
|
276
284
|
node=node,
|
277
|
-
run_id=
|
285
|
+
run_id=run_id,
|
278
286
|
),
|
279
287
|
)
|
288
|
+
|
289
|
+
# Confirm that the message has been received
|
290
|
+
stub.ConfirmMessageReceived(
|
291
|
+
ConfirmMessageReceivedRequest(
|
292
|
+
node=node, run_id=run_id, message_object_id=msg_id
|
293
|
+
)
|
294
|
+
)
|
295
|
+
|
280
296
|
in_message = cast(
|
281
297
|
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
282
298
|
)
|
@@ -311,33 +327,36 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
311
327
|
log(ERROR, "Invalid out message")
|
312
328
|
return
|
313
329
|
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
response: PushMessagesResponse = stub.PushMessages(request=request)
|
328
|
-
|
329
|
-
if response.objects_to_push:
|
330
|
-
objs_to_push = set(response.objects_to_push[message.object_id].object_ids)
|
331
|
-
push_objects(
|
332
|
-
all_objects,
|
333
|
-
push_object_fn=make_push_object_fn_grpc(
|
334
|
-
push_object_grpc=stub.PushObject,
|
335
|
-
node=node,
|
336
|
-
run_id=message.metadata.run_id,
|
337
|
-
),
|
338
|
-
object_ids_to_push=objs_to_push,
|
330
|
+
with no_object_id_recompute():
|
331
|
+
# Get all nested objects
|
332
|
+
all_objects = get_all_nested_objects(message)
|
333
|
+
object_tree = get_object_tree(message)
|
334
|
+
|
335
|
+
# Serialize Message
|
336
|
+
message_proto = message_to_proto(
|
337
|
+
message=remove_content_from_message(message)
|
338
|
+
)
|
339
|
+
request = PushMessagesRequest(
|
340
|
+
node=node,
|
341
|
+
messages_list=[message_proto],
|
342
|
+
message_object_trees=[object_tree],
|
339
343
|
)
|
340
|
-
|
344
|
+
response: PushMessagesResponse = stub.PushMessages(request=request)
|
345
|
+
|
346
|
+
if response.objects_to_push:
|
347
|
+
objs_to_push = set(
|
348
|
+
response.objects_to_push[message.object_id].object_ids
|
349
|
+
)
|
350
|
+
push_objects(
|
351
|
+
all_objects,
|
352
|
+
push_object_fn=make_push_object_fn_grpc(
|
353
|
+
push_object_grpc=stub.PushObject,
|
354
|
+
node=node,
|
355
|
+
run_id=message.metadata.run_id,
|
356
|
+
),
|
357
|
+
object_ids_to_push=objs_to_push,
|
358
|
+
)
|
359
|
+
log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
|
341
360
|
|
342
361
|
# Cleanup
|
343
362
|
metadata = None
|
@@ -50,6 +50,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
50
50
|
SendNodeHeartbeatResponse,
|
51
51
|
)
|
52
52
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
53
|
+
ConfirmMessageReceivedRequest,
|
54
|
+
ConfirmMessageReceivedResponse,
|
53
55
|
PullObjectRequest,
|
54
56
|
PullObjectResponse,
|
55
57
|
PushObjectRequest,
|
@@ -169,3 +171,9 @@ class GrpcAdapter:
|
|
169
171
|
) -> PullObjectResponse:
|
170
172
|
"""."""
|
171
173
|
return self._send_and_receive(request, PullObjectResponse, **kwargs)
|
174
|
+
|
175
|
+
def ConfirmMessageReceived( # pylint: disable=C0103
|
176
|
+
self, request: ConfirmMessageReceivedRequest, **kwargs: Any
|
177
|
+
) -> ConfirmMessageReceivedResponse:
|
178
|
+
"""."""
|
179
|
+
return self._send_and_receive(request, ConfirmMessageReceivedResponse, **kwargs)
|
@@ -14,12 +14,11 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Contextmanager for a REST request-response channel to the Flower server."""
|
16
16
|
|
17
|
-
|
18
17
|
from collections.abc import Iterator
|
19
18
|
from contextlib import contextmanager
|
20
19
|
from copy import copy
|
21
|
-
from logging import ERROR, INFO, WARN
|
22
|
-
from typing import Callable, Optional, TypeVar, Union
|
20
|
+
from logging import DEBUG, ERROR, INFO, WARN
|
21
|
+
from typing import Callable, Optional, TypeVar, Union, cast
|
23
22
|
|
24
23
|
from cryptography.hazmat.primitives.asymmetric import ec
|
25
24
|
from google.protobuf.message import Message as GrpcMessage
|
@@ -31,10 +30,24 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
31
30
|
from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
|
32
31
|
from flwr.common.exit import ExitCode, flwr_exit
|
33
32
|
from flwr.common.heartbeat import HeartbeatSender
|
33
|
+
from flwr.common.inflatable import (
|
34
|
+
get_all_nested_objects,
|
35
|
+
get_object_tree,
|
36
|
+
no_object_id_recompute,
|
37
|
+
)
|
38
|
+
from flwr.common.inflatable_rest_utils import (
|
39
|
+
make_pull_object_fn_rest,
|
40
|
+
make_push_object_fn_rest,
|
41
|
+
)
|
42
|
+
from flwr.common.inflatable_utils import (
|
43
|
+
inflate_object_from_contents,
|
44
|
+
pull_objects,
|
45
|
+
push_objects,
|
46
|
+
)
|
34
47
|
from flwr.common.logger import log
|
35
|
-
from flwr.common.message import Message
|
48
|
+
from flwr.common.message import Message, remove_content_from_message
|
36
49
|
from flwr.common.retry_invoker import RetryInvoker
|
37
|
-
from flwr.common.serde import
|
50
|
+
from flwr.common.serde import message_to_proto, run_from_proto
|
38
51
|
from flwr.common.typing import Fab, Run
|
39
52
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
40
53
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
@@ -51,6 +64,14 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
51
64
|
SendNodeHeartbeatRequest,
|
52
65
|
SendNodeHeartbeatResponse,
|
53
66
|
)
|
67
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
68
|
+
ConfirmMessageReceivedRequest,
|
69
|
+
ConfirmMessageReceivedResponse,
|
70
|
+
PullObjectRequest,
|
71
|
+
PullObjectResponse,
|
72
|
+
PushObjectRequest,
|
73
|
+
PushObjectResponse,
|
74
|
+
)
|
54
75
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
55
76
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
56
77
|
|
@@ -64,9 +85,12 @@ PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
|
|
64
85
|
PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
|
65
86
|
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
66
87
|
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
88
|
+
PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
|
89
|
+
PATH_PUSH_OBJECT: str = "/api/v0/fleet/push-object"
|
67
90
|
PATH_SEND_NODE_HEARTBEAT: str = "api/v0/fleet/send-node-heartbeat"
|
68
91
|
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
|
69
92
|
PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
|
93
|
+
PATH_CONFIRM_MESSAGE_RECEIVED: str = "/api/v0/fleet/confirm-message-received"
|
70
94
|
|
71
95
|
T = TypeVar("T", bound=GrpcMessage)
|
72
96
|
|
@@ -296,14 +320,58 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
296
320
|
):
|
297
321
|
message_proto = None
|
298
322
|
|
299
|
-
#
|
300
|
-
|
301
|
-
|
302
|
-
if message_proto
|
303
|
-
message = message_from_proto(message_proto)
|
304
|
-
metadata = copy(message.metadata)
|
323
|
+
# Construct the Message
|
324
|
+
in_message: Optional[Message] = None
|
325
|
+
|
326
|
+
if message_proto:
|
305
327
|
log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
|
306
|
-
|
328
|
+
msg_id = message_proto.metadata.message_id
|
329
|
+
run_id = message_proto.metadata.run_id
|
330
|
+
|
331
|
+
def fn(request: PullObjectRequest) -> PullObjectResponse:
|
332
|
+
res = _request(
|
333
|
+
req=request, res_type=PullObjectResponse, api_path=PATH_PULL_OBJECT
|
334
|
+
)
|
335
|
+
if res is None:
|
336
|
+
raise ValueError("PushObjectResponse is None.")
|
337
|
+
return res
|
338
|
+
|
339
|
+
try:
|
340
|
+
all_object_contents = pull_objects(
|
341
|
+
list(res.objects_to_pull[msg_id].object_ids) + [msg_id],
|
342
|
+
pull_object_fn=make_pull_object_fn_rest(
|
343
|
+
pull_object_rest=fn,
|
344
|
+
node=node,
|
345
|
+
run_id=run_id,
|
346
|
+
),
|
347
|
+
)
|
348
|
+
|
349
|
+
# Confirm that the message has been received
|
350
|
+
_request(
|
351
|
+
req=ConfirmMessageReceivedRequest(
|
352
|
+
node=node, run_id=run_id, message_object_id=msg_id
|
353
|
+
),
|
354
|
+
res_type=ConfirmMessageReceivedResponse,
|
355
|
+
api_path=PATH_CONFIRM_MESSAGE_RECEIVED,
|
356
|
+
)
|
357
|
+
except ValueError as e:
|
358
|
+
log(
|
359
|
+
ERROR,
|
360
|
+
"Pulling objects failed. Potential irrecoverable error: %s",
|
361
|
+
str(e),
|
362
|
+
)
|
363
|
+
in_message = cast(
|
364
|
+
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
365
|
+
)
|
366
|
+
# The deflated message doesn't contain the message_id (its own object_id)
|
367
|
+
# Inject
|
368
|
+
in_message.metadata.__dict__["_message_id"] = msg_id
|
369
|
+
|
370
|
+
# Remember `metadata` of the in message
|
371
|
+
nonlocal metadata
|
372
|
+
metadata = copy(in_message.metadata) if in_message else None
|
373
|
+
|
374
|
+
return in_message
|
307
375
|
|
308
376
|
def send(message: Message) -> None:
|
309
377
|
"""Send Message result back to server."""
|
@@ -318,29 +386,72 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
318
386
|
log(ERROR, "No current message")
|
319
387
|
return
|
320
388
|
|
389
|
+
# Set message_id
|
390
|
+
message.metadata.__dict__["_message_id"] = message.object_id
|
321
391
|
# Validate out message
|
322
392
|
if not validate_out_message(message, metadata):
|
323
393
|
log(ERROR, "Invalid out message")
|
324
394
|
return
|
325
|
-
metadata = None
|
326
395
|
|
327
|
-
|
328
|
-
|
396
|
+
with no_object_id_recompute():
|
397
|
+
# Get all nested objects
|
398
|
+
all_objects = get_all_nested_objects(message)
|
399
|
+
object_tree = get_object_tree(message)
|
329
400
|
|
330
|
-
|
331
|
-
|
401
|
+
# Serialize Message
|
402
|
+
message_proto = message_to_proto(
|
403
|
+
message=remove_content_from_message(message)
|
404
|
+
)
|
405
|
+
req = PushMessagesRequest(
|
406
|
+
node=node,
|
407
|
+
messages_list=[message_proto],
|
408
|
+
message_object_trees=[object_tree],
|
409
|
+
)
|
332
410
|
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
411
|
+
# Send the request
|
412
|
+
res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
|
413
|
+
if res:
|
414
|
+
log(
|
415
|
+
INFO,
|
416
|
+
"[Node] POST /%s: success, created result %s",
|
417
|
+
PATH_PUSH_MESSAGES,
|
418
|
+
res.results, # pylint: disable=no-member
|
419
|
+
)
|
420
|
+
|
421
|
+
if res and res.objects_to_push:
|
422
|
+
objs_to_push = set(res.objects_to_push[message.object_id].object_ids)
|
423
|
+
|
424
|
+
def fn(request: PushObjectRequest) -> PushObjectResponse:
|
425
|
+
res = _request(
|
426
|
+
req=request,
|
427
|
+
res_type=PushObjectResponse,
|
428
|
+
api_path=PATH_PUSH_OBJECT,
|
429
|
+
)
|
430
|
+
if res is None:
|
431
|
+
raise ValueError("PushObjectResponse is None.")
|
432
|
+
return res
|
433
|
+
|
434
|
+
try:
|
435
|
+
push_objects(
|
436
|
+
all_objects,
|
437
|
+
push_object_fn=make_push_object_fn_rest(
|
438
|
+
push_object_rest=fn,
|
439
|
+
node=node,
|
440
|
+
run_id=message_proto.metadata.run_id,
|
441
|
+
),
|
442
|
+
object_ids_to_push=objs_to_push,
|
443
|
+
)
|
444
|
+
log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
|
445
|
+
except ValueError as e:
|
446
|
+
log(
|
447
|
+
ERROR,
|
448
|
+
"Pushing objects failed. Potential irrecoverable error: %s",
|
449
|
+
str(e),
|
450
|
+
)
|
451
|
+
log(ERROR, str(e))
|
337
452
|
|
338
|
-
|
339
|
-
|
340
|
-
"[Node] POST /%s: success, created result %s",
|
341
|
-
PATH_PUSH_MESSAGES,
|
342
|
-
res.results, # pylint: disable=no-member
|
343
|
-
)
|
453
|
+
# Cleanup
|
454
|
+
metadata = None
|
344
455
|
|
345
456
|
def get_run(run_id: int) -> Run:
|
346
457
|
# Construct the request
|
@@ -20,7 +20,7 @@ from collections.abc import Sequence
|
|
20
20
|
from pathlib import Path
|
21
21
|
from typing import Optional, Union
|
22
22
|
|
23
|
-
from flwr.common.typing import
|
23
|
+
from flwr.common.typing import AccountInfo
|
24
24
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
25
25
|
|
26
26
|
from ..typing import UserAuthCredentials, UserAuthLoginDetails
|
@@ -53,7 +53,7 @@ class ExecAuthPlugin(ABC):
|
|
53
53
|
@abstractmethod
|
54
54
|
def validate_tokens_in_metadata(
|
55
55
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
56
|
-
) -> tuple[bool, Optional[
|
56
|
+
) -> tuple[bool, Optional[AccountInfo]]:
|
57
57
|
"""Validate authentication tokens in the provided metadata."""
|
58
58
|
|
59
59
|
@abstractmethod
|
@@ -63,7 +63,9 @@ class ExecAuthPlugin(ABC):
|
|
63
63
|
@abstractmethod
|
64
64
|
def refresh_tokens(
|
65
65
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
66
|
-
) -> tuple[
|
66
|
+
) -> tuple[
|
67
|
+
Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[AccountInfo]
|
68
|
+
]:
|
67
69
|
"""Refresh authentication tokens in the provided metadata."""
|
68
70
|
|
69
71
|
|
@@ -84,7 +86,7 @@ class ExecAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
|
|
84
86
|
"""Abstract constructor."""
|
85
87
|
|
86
88
|
@abstractmethod
|
87
|
-
def verify_user_authorization(self,
|
89
|
+
def verify_user_authorization(self, account_info: AccountInfo) -> bool:
|
88
90
|
"""Verify user authorization request."""
|
89
91
|
|
90
92
|
|
@@ -21,7 +21,7 @@ from typing import Optional, Union
|
|
21
21
|
import grpc
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
23
23
|
|
24
|
-
from flwr.common.typing import
|
24
|
+
from flwr.common.typing import AccountInfo, LogEntry
|
25
25
|
|
26
26
|
|
27
27
|
class EventLogWriterPlugin(ABC):
|
@@ -36,7 +36,7 @@ class EventLogWriterPlugin(ABC):
|
|
36
36
|
self,
|
37
37
|
request: GrpcMessage,
|
38
38
|
context: grpc.ServicerContext,
|
39
|
-
|
39
|
+
account_info: Optional[AccountInfo],
|
40
40
|
method_name: str,
|
41
41
|
) -> LogEntry:
|
42
42
|
"""Compose pre-event log entry from the provided request and context."""
|
@@ -46,7 +46,7 @@ class EventLogWriterPlugin(ABC):
|
|
46
46
|
self,
|
47
47
|
request: GrpcMessage,
|
48
48
|
context: grpc.ServicerContext,
|
49
|
-
|
49
|
+
account_info: Optional[AccountInfo],
|
50
50
|
method_name: str,
|
51
51
|
response: Optional[Union[GrpcMessage, BaseException]],
|
52
52
|
) -> LogEntry:
|
flwr/common/inflatable.py
CHANGED
@@ -18,8 +18,13 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import hashlib
|
21
|
+
import threading
|
22
|
+
from collections.abc import Iterator
|
23
|
+
from contextlib import contextmanager
|
21
24
|
from typing import TypeVar, cast
|
22
25
|
|
26
|
+
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
27
|
+
|
23
28
|
from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
|
24
29
|
|
25
30
|
|
@@ -33,6 +38,33 @@ class UnexpectedObjectContentError(Exception):
|
|
33
38
|
)
|
34
39
|
|
35
40
|
|
41
|
+
_ctx = threading.local()
|
42
|
+
|
43
|
+
|
44
|
+
def _is_recompute_enabled() -> bool:
|
45
|
+
"""Check if recomputing object IDs is enabled."""
|
46
|
+
return getattr(_ctx, "recompute_object_id_enabled", True)
|
47
|
+
|
48
|
+
|
49
|
+
def _get_computed_object_ids() -> set[str]:
|
50
|
+
"""Get the set of computed object IDs."""
|
51
|
+
return getattr(_ctx, "computed_object_ids", set())
|
52
|
+
|
53
|
+
|
54
|
+
@contextmanager
|
55
|
+
def no_object_id_recompute() -> Iterator[None]:
|
56
|
+
"""Context manager to disable recomputing object IDs."""
|
57
|
+
old_value = _is_recompute_enabled()
|
58
|
+
old_set = _get_computed_object_ids()
|
59
|
+
_ctx.recompute_object_id_enabled = False
|
60
|
+
_ctx.computed_object_ids = set()
|
61
|
+
try:
|
62
|
+
yield
|
63
|
+
finally:
|
64
|
+
_ctx.recompute_object_id_enabled = old_value
|
65
|
+
_ctx.computed_object_ids = old_set
|
66
|
+
|
67
|
+
|
36
68
|
class InflatableObject:
|
37
69
|
"""Base class for inflatable objects."""
|
38
70
|
|
@@ -65,8 +97,23 @@ class InflatableObject:
|
|
65
97
|
@property
|
66
98
|
def object_id(self) -> str:
|
67
99
|
"""Get object_id."""
|
100
|
+
# If recomputing object ID is disabled and the object ID is already computed,
|
101
|
+
# return the cached object ID.
|
102
|
+
if (
|
103
|
+
not _is_recompute_enabled()
|
104
|
+
and (obj_id := self.__dict__.get("_object_id"))
|
105
|
+
in _get_computed_object_ids()
|
106
|
+
):
|
107
|
+
return cast(str, obj_id)
|
108
|
+
|
68
109
|
if self.is_dirty or "_object_id" not in self.__dict__:
|
69
|
-
|
110
|
+
obj_id = get_object_id(self.deflate())
|
111
|
+
self.__dict__["_object_id"] = obj_id
|
112
|
+
|
113
|
+
# If recomputing object ID is disabled, add the object ID to the set of
|
114
|
+
# computed object IDs to avoid recomputing it within the context.
|
115
|
+
if not _is_recompute_enabled():
|
116
|
+
_get_computed_object_ids().add(obj_id)
|
70
117
|
return cast(str, self.__dict__["_object_id"])
|
71
118
|
|
72
119
|
@property
|
@@ -219,3 +266,25 @@ def get_all_nested_objects(obj: InflatableObject) -> dict[str, InflatableObject]
|
|
219
266
|
ret[obj.object_id] = obj
|
220
267
|
|
221
268
|
return ret
|
269
|
+
|
270
|
+
|
271
|
+
def get_object_tree(obj: InflatableObject) -> ObjectTree:
|
272
|
+
"""Get a tree representation of the InflatableObject."""
|
273
|
+
tree_children = []
|
274
|
+
if children := obj.children:
|
275
|
+
for child in children.values():
|
276
|
+
tree_children.append(get_object_tree(child))
|
277
|
+
return ObjectTree(object_id=obj.object_id, children=tree_children)
|
278
|
+
|
279
|
+
|
280
|
+
def iterate_object_tree(
|
281
|
+
tree: ObjectTree,
|
282
|
+
) -> Iterator[ObjectTree]:
|
283
|
+
"""Iterate over the object tree and yield its nodes.
|
284
|
+
|
285
|
+
This function performs a post-order traversal of the tree, yielding
|
286
|
+
each node after all its children have been yielded.
|
287
|
+
"""
|
288
|
+
for child in tree.children:
|
289
|
+
yield from iterate_object_tree(child)
|
290
|
+
yield tree
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""InflatableObject utils."""
|
15
|
+
"""InflatableObject gRPC utils."""
|
16
16
|
|
17
17
|
|
18
18
|
from typing import Callable
|
@@ -0,0 +1,99 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""InflatableObject REST utils."""
|
16
|
+
|
17
|
+
|
18
|
+
from typing import Callable
|
19
|
+
|
20
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
21
|
+
PullObjectRequest,
|
22
|
+
PullObjectResponse,
|
23
|
+
PushObjectRequest,
|
24
|
+
PushObjectResponse,
|
25
|
+
)
|
26
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
27
|
+
|
28
|
+
from .inflatable_utils import ObjectIdNotPreregisteredError, ObjectUnavailableError
|
29
|
+
|
30
|
+
|
31
|
+
def make_pull_object_fn_rest(
|
32
|
+
pull_object_rest: Callable[[PullObjectRequest], PullObjectResponse],
|
33
|
+
node: Node,
|
34
|
+
run_id: int,
|
35
|
+
) -> Callable[[str], bytes]:
|
36
|
+
"""Create a pull object function that uses REST to pull objects.
|
37
|
+
|
38
|
+
Parameters
|
39
|
+
----------
|
40
|
+
pull_object_rest : Callable[[PullObjectRequest], PullObjectResponse]
|
41
|
+
A function that makes a POST request against the `/pull-object` REST endpoint
|
42
|
+
node : Node
|
43
|
+
The node making the request.
|
44
|
+
run_id : int
|
45
|
+
The run ID for the current operation.
|
46
|
+
|
47
|
+
Returns
|
48
|
+
-------
|
49
|
+
Callable[[str], bytes]
|
50
|
+
A function that takes an object ID and returns the object content as bytes.
|
51
|
+
The function raises `ObjectIdNotPreregisteredError` if the object ID is not
|
52
|
+
pre-registered, or `ObjectUnavailableError` if the object is not yet available.
|
53
|
+
"""
|
54
|
+
|
55
|
+
def pull_object_fn(object_id: str) -> bytes:
|
56
|
+
request = PullObjectRequest(node=node, run_id=run_id, object_id=object_id)
|
57
|
+
response: PullObjectResponse = pull_object_rest(request)
|
58
|
+
if not response.object_found:
|
59
|
+
raise ObjectIdNotPreregisteredError(object_id)
|
60
|
+
if not response.object_available:
|
61
|
+
raise ObjectUnavailableError(object_id)
|
62
|
+
return response.object_content
|
63
|
+
|
64
|
+
return pull_object_fn
|
65
|
+
|
66
|
+
|
67
|
+
def make_push_object_fn_rest(
|
68
|
+
push_object_rest: Callable[[PushObjectRequest], PushObjectResponse],
|
69
|
+
node: Node,
|
70
|
+
run_id: int,
|
71
|
+
) -> Callable[[str, bytes], None]:
|
72
|
+
"""Create a push object function that uses REST to push objects.
|
73
|
+
|
74
|
+
Parameters
|
75
|
+
----------
|
76
|
+
push_object_rest : Callable[[PushObjectRequest], PushObjectResponse]
|
77
|
+
A function that makes a POST request against the `/push-object` REST endpoint
|
78
|
+
node : Node
|
79
|
+
The node making the request.
|
80
|
+
run_id : int
|
81
|
+
The run ID for the current operation.
|
82
|
+
|
83
|
+
Returns
|
84
|
+
-------
|
85
|
+
Callable[[str, bytes], None]
|
86
|
+
A function that takes an object ID and its content as bytes, and pushes it
|
87
|
+
to the servicer. The function raises `ObjectIdNotPreregisteredError` if
|
88
|
+
the object ID is not pre-registered.
|
89
|
+
"""
|
90
|
+
|
91
|
+
def push_object_fn(object_id: str, object_content: bytes) -> None:
|
92
|
+
request = PushObjectRequest(
|
93
|
+
node=node, run_id=run_id, object_id=object_id, object_content=object_content
|
94
|
+
)
|
95
|
+
response: PushObjectResponse = push_object_rest(request)
|
96
|
+
if not response.stored:
|
97
|
+
raise ObjectIdNotPreregisteredError(object_id)
|
98
|
+
|
99
|
+
return push_object_fn
|
flwr/common/serde.py
CHANGED
@@ -630,6 +630,7 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
|
|
630
630
|
running_at=run.running_at,
|
631
631
|
finished_at=run.finished_at,
|
632
632
|
status=run_status_to_proto(run.status),
|
633
|
+
flwr_aid=run.flwr_aid,
|
633
634
|
)
|
634
635
|
return proto
|
635
636
|
|
@@ -647,6 +648,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
|
|
647
648
|
running_at=run_proto.running_at,
|
648
649
|
finished_at=run_proto.finished_at,
|
649
650
|
status=run_status_from_proto(run_proto.status),
|
651
|
+
flwr_aid=run_proto.flwr_aid,
|
650
652
|
)
|
651
653
|
return run
|
652
654
|
|
flwr/common/typing.py
CHANGED
@@ -230,6 +230,7 @@ class Run: # pylint: disable=too-many-instance-attributes
|
|
230
230
|
running_at: str
|
231
231
|
finished_at: str
|
232
232
|
status: RunStatus
|
233
|
+
flwr_aid: str
|
233
234
|
|
234
235
|
@classmethod
|
235
236
|
def create_empty(cls, run_id: int) -> "Run":
|
@@ -245,6 +246,7 @@ class Run: # pylint: disable=too-many-instance-attributes
|
|
245
246
|
running_at="",
|
246
247
|
finished_at="",
|
247
248
|
status=RunStatus(status="", sub_status="", details=""),
|
249
|
+
flwr_aid="",
|
248
250
|
)
|
249
251
|
|
250
252
|
|
@@ -289,11 +291,11 @@ class UserAuthCredentials:
|
|
289
291
|
|
290
292
|
|
291
293
|
@dataclass
|
292
|
-
class
|
294
|
+
class AccountInfo:
|
293
295
|
"""User information for event log."""
|
294
296
|
|
295
|
-
|
296
|
-
|
297
|
+
flwr_aid: Optional[str]
|
298
|
+
account_name: Optional[str]
|
297
299
|
|
298
300
|
|
299
301
|
@dataclass
|