flwr 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +82 -57
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +10 -18
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +31 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +4 -4
- flwr/client/grpc_rere_client/connection.py +130 -60
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +173 -67
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +36 -7
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit_handlers.py +30 -0
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -183
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +19 -159
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +68 -186
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +18 -0
- flwr/server/superlink/ffs/__init__.py +2 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
- flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
- flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
- flwr/server/superlink/utils.py +44 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +7 -3
- flwr/superexec/exec_servicer.py +125 -23
- flwr/superexec/exec_user_auth_interceptor.py +37 -8
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
- /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
flwr/common/heartbeat.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Heartbeat sender."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import random
|
|
19
|
+
import threading
|
|
20
|
+
from typing import Callable, Union
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
|
|
24
|
+
# pylint: disable=E0611
|
|
25
|
+
from flwr.proto.heartbeat_pb2 import SendAppHeartbeatRequest
|
|
26
|
+
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
|
|
27
|
+
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub
|
|
28
|
+
|
|
29
|
+
# pylint: enable=E0611
|
|
30
|
+
from .constant import (
|
|
31
|
+
HEARTBEAT_BASE_MULTIPLIER,
|
|
32
|
+
HEARTBEAT_CALL_TIMEOUT,
|
|
33
|
+
HEARTBEAT_DEFAULT_INTERVAL,
|
|
34
|
+
HEARTBEAT_RANDOM_RANGE,
|
|
35
|
+
)
|
|
36
|
+
from .retry_invoker import RetryInvoker, exponential
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class HeartbeatFailure(Exception):
    """Signal that a single heartbeat attempt did not succeed.

    `HeartbeatSender` raises this internally so that its retry logic backs off
    and tries again.
    """
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class HeartbeatSender:
    """Send heartbeats to a server from a background daemon thread.

    Each heartbeat is produced by calling ``heartbeat_fn``. When a call reports
    failure (returns False), the attempt is repeated through a ``RetryInvoker``
    configured with ``exponential(max_delay=20)`` backoff and with both
    ``max_tries`` and ``max_time`` set to None. Waiting between retries goes
    through the internal stop event, so ``stop()`` can cut a backoff short.

    Parameters
    ----------
    heartbeat_fn : Callable[[], bool]
        Callable that sends one heartbeat. It must return True on success and
        False on failure, and should handle transport errors (e.g., gRPC errors)
        itself so that it always returns a boolean.
    """

    def __init__(
        self,
        heartbeat_fn: Callable[[], bool],
    ) -> None:
        self.heartbeat_fn = heartbeat_fn
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        # Only HeartbeatFailure triggers a retry; any other exception escapes.
        # Using the stop event's `wait` lets `stop()` interrupt a backoff delay.
        self._retry_invoker = RetryInvoker(
            lambda: exponential(max_delay=20),
            HeartbeatFailure,
            max_tries=None,
            max_time=None,
            wait_function=self._stop_event.wait,  # type: ignore
        )

    def start(self) -> None:
        """Launch the background heartbeat thread.

        Raises
        ------
        RuntimeError
            If the thread is already running, or if this sender was stopped
            before (a sender cannot be restarted).
        """
        if self._thread.is_alive():
            raise RuntimeError("Heartbeat sender is already running.")
        if self._stop_event.is_set():
            raise RuntimeError("Cannot start a stopped heartbeat sender.")
        self._thread.start()

    def stop(self) -> None:
        """Signal the background thread to finish and block until it exits.

        Raises
        ------
        RuntimeError
            If the background thread is not alive.
        """
        if not self._thread.is_alive():
            raise RuntimeError("Heartbeat sender is not running.")
        self._stop_event.set()
        self._thread.join()

    @property
    def is_running(self) -> bool:
        """Whether the thread is alive and no stop has been requested."""
        return not self._stop_event.is_set() and self._thread.is_alive()

    def _run(self) -> None:
        """Send a heartbeat (retrying on failure), then sleep; repeat until stopped."""
        stop_event = self._stop_event
        while not stop_event.is_set():
            self._retry_invoker.invoke(self._heartbeat)

            # Sleep for a jittered fraction of (interval - call timeout). The
            # multiplier is HEARTBEAT_BASE_MULTIPLIER plus a uniform draw from
            # HEARTBEAT_RANDOM_RANGE (presumably 0.7-0.9 overall; confirm the
            # constants). Waiting on the event lets `stop()` end the sleep early.
            jitter = random.uniform(*HEARTBEAT_RANDOM_RANGE)
            base_interval = HEARTBEAT_DEFAULT_INTERVAL - HEARTBEAT_CALL_TIMEOUT
            stop_event.wait(base_interval * (HEARTBEAT_BASE_MULTIPLIER + jitter))

    def _heartbeat(self) -> None:
        """Send one heartbeat unless stopping; raise `HeartbeatFailure` on failure.

        Raising (rather than returning False) is what makes the retry invoker
        back off and call this method again.
        """
        if self._stop_event.is_set():
            return
        if not self.heartbeat_fn():
            raise HeartbeatFailure
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def get_grpc_app_heartbeat_fn(
    stub: Union[ServerAppIoStub, SimulationIoStub],
    run_id: int,
    *,
    failure_message: str,
) -> Callable[[], bool]:
    """Build a callable that sends one app heartbeat over gRPC.

    Only app heartbeats are sent this way; node heartbeats use a different path.

    Parameters
    ----------
    stub : Union[ServerAppIoStub, SimulationIoStub]
        gRPC stub whose ``SendAppHeartbeat`` RPC is called.
    run_id : int
        Run ID placed in every heartbeat request.
    failure_message : str
        Message of the ``RuntimeError`` raised when the server answers with
        ``success`` set to False.

    Returns
    -------
    Callable[[], bool]
        A callable that returns True when the heartbeat is accepted, and False
        when the RPC fails with ``UNAVAILABLE`` or ``DEADLINE_EXCEEDED``. It
        re-raises any other ``grpc.RpcError`` and raises ``RuntimeError`` if the
        server reports failure.
    """
    # The request never changes, so build it once and reuse it on every call
    heartbeat_request = SendAppHeartbeatRequest(
        run_id=run_id, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
    )
    # Status codes treated as a transient failure (caller may retry)
    transient_codes = (
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DEADLINE_EXCEEDED,
    )

    def send_heartbeat() -> bool:
        try:
            response = stub.SendAppHeartbeat(heartbeat_request)
        except grpc.RpcError as err:
            if err.code() in transient_codes:
                return False
            raise

        if response.success:
            return True
        raise RuntimeError(failure_message)

    return send_heartbeat
|
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""InflatableObject base class."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import hashlib
|
|
21
|
+
import threading
|
|
22
|
+
from collections.abc import Iterator
|
|
23
|
+
from contextlib import contextmanager
|
|
24
|
+
from typing import TypeVar, cast
|
|
25
|
+
|
|
26
|
+
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
|
27
|
+
|
|
28
|
+
from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class UnexpectedObjectContentError(Exception):
    """Raised when object content does not follow the InflatableObject layout.

    The expected layout is a head, a head/body divider and a body, where the head
    holds the object type, the child object IDs and the body length.
    """

    def __init__(self, object_id: str, reason: str):
        message = (
            f"Object with ID '{object_id}' has an unexpected structure. {reason}"
        )
        super().__init__(message)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Per-thread state used by `no_object_id_recompute`: each thread keeps its own
# `recompute_object_id_enabled` flag and `computed_object_ids` set.
_ctx: threading.local = threading.local()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _is_recompute_enabled() -> bool:
    """Tell whether object IDs may be recomputed in the current thread.

    Defaults to True when the current thread has never set the flag.
    """
    try:
        return _ctx.recompute_object_id_enabled
    except AttributeError:
        return True
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _get_computed_object_ids() -> set[str]:
    """Return the current thread's set of object IDs computed without recompute.

    If the current thread has never stored such a set, a new empty set is
    returned (and not stored).
    """
    try:
        return _ctx.computed_object_ids
    except AttributeError:
        return set()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@contextmanager
def no_object_id_recompute() -> Iterator[None]:
    """Disable object-ID recomputation in the current thread for this block.

    While active, `InflatableObject.object_id` returns the cached ID of any object
    whose ID was already computed inside the block. On exit, the previous flag and
    set of computed IDs are restored. A nested block starts with its own empty set.
    """
    previous_state = (_is_recompute_enabled(), _get_computed_object_ids())
    _ctx.recompute_object_id_enabled, _ctx.computed_object_ids = False, set()
    try:
        yield
    finally:
        _ctx.recompute_object_id_enabled, _ctx.computed_object_ids = previous_state
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class InflatableObject:
    """Base class for inflatable objects.

    Subclasses implement `deflate` (serialize the object to bytes) and `inflate`
    (rebuild it from bytes plus its already-inflated children), and may override
    `children` and `is_dirty`. The object ID is the SHA-256 hex digest of the
    deflated content (see `object_id`).
    """

    def deflate(self) -> bytes:
        """Deflate object."""
        raise NotImplementedError()

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> InflatableObject:
        """Inflate the object from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Dictionary of child InflatableObjects mapped to their object IDs. These
            children enable the full inflation of the parent InflatableObject.

        Returns
        -------
        InflatableObject
            The inflated object.
        """
        raise NotImplementedError()

    @property
    def object_id(self) -> str:
        """Get object_id.

        The ID is cached in ``self.__dict__["_object_id"]`` and recomputed from
        `deflate()` when the object is dirty or has no cached ID. Inside
        `no_object_id_recompute`, an ID already computed in that context is
        returned as-is without checking `is_dirty`.
        """
        # If recomputing object ID is disabled and the object ID is already computed,
        # return the cached object ID.
        if (
            not _is_recompute_enabled()
            and (obj_id := self.__dict__.get("_object_id"))
            in _get_computed_object_ids()
        ):
            return cast(str, obj_id)

        if self.is_dirty or "_object_id" not in self.__dict__:
            obj_id = get_object_id(self.deflate())
            self.__dict__["_object_id"] = obj_id

            # If recomputing object ID is disabled, add the object ID to the set of
            # computed object IDs to avoid recomputing it within the context.
            if not _is_recompute_enabled():
                _get_computed_object_ids().add(obj_id)
        return cast(str, self.__dict__["_object_id"])

    @property
    def children(self) -> dict[str, InflatableObject] | None:
        """Get all child objects as a dictionary or None if there are no children."""
        return None

    @property
    def is_dirty(self) -> bool:
        """Check if the object is dirty after the last deflation.

        An object is considered dirty if its content has changed since its object
        ID was last computed. This base implementation always returns True, so
        `object_id` recomputes the ID on every access (except for IDs already
        cached inside `no_object_id_recompute`) unless a subclass overrides it.
        """
        return True
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Type variable bound to InflatableObject; `get_object_body` uses it to type its
# `cls` argument as a subclass of InflatableObject.
T = TypeVar("T", bound=InflatableObject)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_object_id(object_content: bytes) -> str:
    """Compute the object ID: the hex-encoded SHA-256 digest of `object_content`."""
    digest = hashlib.sha256()
    digest.update(object_content)
    return digest.hexdigest()
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_object_body(object_content: bytes, cls: type[T]) -> bytes:
    """Extract the body of `object_content` after checking its declared type.

    Parameters
    ----------
    object_content : bytes
        The deflated object content (head, divider, body).
    cls : type[T]
        The class the content is expected to describe.

    Returns
    -------
    bytes
        The object body (everything after the head/body divider).

    Raises
    ------
    ValueError
        If the object type recorded in the head differs from `cls.__qualname__`.
    """
    expected_type = cls.__qualname__
    actual_type = get_object_type_from_object_content(object_content)
    if actual_type != expected_type:
        raise ValueError(
            f"Class name ({expected_type}) and object type "
            f"({actual_type}) do not match."
        )
    return _get_object_body(object_content)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def add_header_to_object_body(object_body: bytes, obj: InflatableObject) -> bytes:
    """Prefix `object_body` with the head that describes `obj`.

    The result is ``head + HEAD_BODY_DIVIDER + object_body``. The head is UTF-8
    text made of three values joined by HEAD_VALUE_DIVIDER:
    - the class qualname of `obj`,
    - the comma-separated IDs of its children (empty when it has none),
    - the length of `object_body` in bytes.
    """
    type_name = obj.__class__.__qualname__
    child_ids = ",".join((obj.children or {}).keys())
    body_length = str(len(object_body))

    head = HEAD_VALUE_DIVIDER.join([type_name, child_ids, body_length])
    return b"".join([head.encode(encoding="utf-8"), HEAD_BODY_DIVIDER, object_body])
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _get_object_head(object_content: bytes) -> bytes:
    """Return the object head: the bytes before the first HEAD_BODY_DIVIDER.

    Raises
    ------
    ValueError
        If `object_content` does not contain HEAD_BODY_DIVIDER.
    """
    index = object_content.find(HEAD_BODY_DIVIDER)
    # `find` returns -1 when the divider is missing; slicing with -1 would silently
    # drop the last byte instead of reporting malformed content.
    if index == -1:
        raise ValueError(
            "Unexpected format for object content. Head and body could not be split."
        )
    return object_content[:index]
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _get_object_body(object_content: bytes) -> bytes:
    """Return the object body: the bytes after the first HEAD_BODY_DIVIDER.

    Raises
    ------
    ValueError
        If `object_content` does not contain HEAD_BODY_DIVIDER.
    """
    index = object_content.find(HEAD_BODY_DIVIDER)
    # `find` returns -1 when the divider is missing; `-1 + len(divider)` would
    # then slice from an arbitrary position and return garbage as the body.
    if index == -1:
        raise ValueError(
            "Unexpected format for object content. Head and body could not be split."
        )
    return object_content[index + len(HEAD_BODY_DIVIDER) :]
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def is_valid_sha256_hash(object_id: str) -> bool:
    """Check if the given string is a valid SHA-256 hash.

    A valid hash is exactly 64 ASCII hexadecimal characters (either case).

    Parameters
    ----------
    object_id : str
        The string to check.

    Returns
    -------
    bool
        ``True`` if the string is a valid SHA-256 hash, ``False`` otherwise.
    """
    # Check characters explicitly instead of relying on `int(object_id, 16)`.
    # `int` also accepts a leading sign, surrounding whitespace, underscores
    # between digits, a `0x` prefix and non-ASCII Unicode digits. Any of these
    # would let a malformed 64-character string pass as a valid object ID.
    hex_digits = "0123456789abcdefABCDEF"
    return len(object_id) == 64 and all(char in hex_digits for char in object_id)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def get_object_type_from_object_content(object_content: bytes) -> str:
    """Return the object type recorded in the head of `object_content`."""
    object_type, _, _ = get_object_head_values_from_object_content(object_content)
    return object_type
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def get_object_children_ids_from_object_content(object_content: bytes) -> list[str]:
    """Return the child object IDs recorded in the head of `object_content`."""
    _, child_ids, _ = get_object_head_values_from_object_content(object_content)
    return child_ids
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def get_object_body_len_from_object_content(object_content: bytes) -> int:
    """Return the body length recorded in the head of `object_content`."""
    _, _, body_length = get_object_head_values_from_object_content(object_content)
    return body_length
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def get_object_head_values_from_object_content(
    object_content: bytes,
) -> tuple[str, list[str], int]:
    """Return the object type, child object IDs and body length from the head.

    Parameters
    ----------
    object_content : bytes
        The deflated object content.

    Returns
    -------
    tuple[str, list[str], int]
        A tuple containing:
        - The object type as a string.
        - A list of child object IDs as strings (empty if there are no children).
        - The length of the object body as an integer.

    Raises
    ------
    ValueError
        If the head is not valid UTF-8 (``UnicodeDecodeError`` is a subclass of
        ``ValueError``), does not consist of exactly three values separated by
        HEAD_VALUE_DIVIDER, or its body-length value is not an integer.
    """
    head = _get_object_head(object_content).decode(encoding="utf-8")
    head_values = head.split(HEAD_VALUE_DIVIDER)
    # Check explicitly so a malformed head produces a descriptive error instead of
    # the generic "not enough/too many values to unpack" message.
    if len(head_values) != 3:
        raise ValueError(
            "Unexpected format for object head: expected 3 values separated by "
            f"{HEAD_VALUE_DIVIDER!r}, got {len(head_values)}."
        )
    obj_type, children_str, body_len = head_values
    children_ids = children_str.split(",") if children_str else []
    return obj_type, children_ids, int(body_len)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def get_descendant_object_ids(obj: InflatableObject) -> set[str]:
    """Return the object IDs of every object nested below `obj` (excluding `obj`)."""
    nested_ids = set(get_all_nested_objects(obj))
    return nested_ids - {obj.object_id}
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def get_all_nested_objects(obj: InflatableObject) -> dict[str, InflatableObject]:
    """Collect `obj` and every object nested below it, keyed by object ID.

    Objects are visited in post-order, so in the returned (insertion-ordered)
    dictionary child objects appear before their respective parents.
    """
    collected: dict[str, InflatableObject] = {}

    def _visit(node: InflatableObject) -> None:
        # Descend into the children first (post-order), then record the node
        for child in (node.children or {}).values():
            _visit(child)
        collected[node.object_id] = node

    _visit(obj)
    return collected
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def get_object_tree(obj: InflatableObject) -> ObjectTree:
    """Build an `ObjectTree` mirroring the parent/child structure of `obj`."""
    subtrees = [
        get_object_tree(child) for child in (obj.children or {}).values()
    ]
    return ObjectTree(object_id=obj.object_id, children=subtrees)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def iterate_object_tree(
    tree: ObjectTree,
) -> Iterator[ObjectTree]:
    """Yield every node of `tree` in post-order.

    Each node is yielded only after all of its children (and their subtrees) have
    been yielded, and siblings are visited in their stored order.
    """
    # Each entry is (node, expanded); a node is yielded on its second visit, after
    # its children have been pushed and fully processed.
    stack: list[tuple[ObjectTree, bool]] = [(tree, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            yield node
            continue
        stack.append((node, True))
        # Push children in reverse so the first child is processed first
        stack.extend((child, False) for child in reversed(list(node.children)))
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""InflatableObject gRPC utils."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from typing import Callable
|
|
19
|
+
|
|
20
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
21
|
+
PullObjectRequest,
|
|
22
|
+
PullObjectResponse,
|
|
23
|
+
PushObjectRequest,
|
|
24
|
+
PushObjectResponse,
|
|
25
|
+
)
|
|
26
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
27
|
+
|
|
28
|
+
from .inflatable_utils import ObjectIdNotPreregisteredError, ObjectUnavailableError
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def make_pull_object_fn_grpc(
    pull_object_grpc: Callable[[PullObjectRequest], PullObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str], bytes]:
    """Build a function that pulls objects through a gRPC call.

    Parameters
    ----------
    pull_object_grpc : Callable[[PullObjectRequest], PullObjectResponse]
        The gRPC function to pull objects, e.g., `FleetStub.PullObject`.
    node : Node
        The node making the request.
    run_id : int
        The run ID for the current operation.

    Returns
    -------
    Callable[[str], bytes]
        A function mapping an object ID to the object content. It raises
        `ObjectIdNotPreregisteredError` if the object ID is not pre-registered,
        or `ObjectUnavailableError` if the object is not yet available.
    """

    def pull_object_fn(object_id: str) -> bytes:
        response: PullObjectResponse = pull_object_grpc(
            PullObjectRequest(node=node, run_id=run_id, object_id=object_id)
        )
        # "Not found" and "found but not yet available" map to distinct errors
        if not response.object_found:
            raise ObjectIdNotPreregisteredError(object_id)
        if not response.object_available:
            raise ObjectUnavailableError(object_id)
        return response.object_content

    return pull_object_fn
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def make_push_object_fn_grpc(
    push_object_grpc: Callable[[PushObjectRequest], PushObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str, bytes], None]:
    """Build a function that pushes objects through a gRPC call.

    Parameters
    ----------
    push_object_grpc : Callable[[PushObjectRequest], PushObjectResponse]
        The gRPC function to push objects, e.g., `FleetStub.PushObject`.
    node : Node
        The node making the request.
    run_id : int
        The run ID for the current operation.

    Returns
    -------
    Callable[[str, bytes], None]
        A function that sends an object ID and its content to the servicer. It
        raises `ObjectIdNotPreregisteredError` when the servicer reports that the
        object was not stored (i.e., the object ID is not pre-registered).
    """

    def push_object_fn(object_id: str, object_content: bytes) -> None:
        push_request = PushObjectRequest(
            node=node,
            run_id=run_id,
            object_id=object_id,
            object_content=object_content,
        )
        push_response: PushObjectResponse = push_object_grpc(push_request)
        if push_response.stored:
            return
        raise ObjectIdNotPreregisteredError(object_id)

    return push_object_fn
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""InflatableObject REST utils."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from typing import Callable
|
|
19
|
+
|
|
20
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
21
|
+
PullObjectRequest,
|
|
22
|
+
PullObjectResponse,
|
|
23
|
+
PushObjectRequest,
|
|
24
|
+
PushObjectResponse,
|
|
25
|
+
)
|
|
26
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
27
|
+
|
|
28
|
+
from .inflatable_utils import ObjectIdNotPreregisteredError, ObjectUnavailableError
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def make_pull_object_fn_rest(
    pull_object_rest: Callable[[PullObjectRequest], PullObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str], bytes]:
    """Create a pull object function that uses REST to pull objects.

    Parameters
    ----------
    pull_object_rest : Callable[[PullObjectRequest], PullObjectResponse]
        A function that makes a POST request against the `/pull-object` REST
        endpoint.
    node : Node
        The node making the request.
    run_id : int
        The run ID for the current operation.

    Returns
    -------
    Callable[[str], bytes]
        A function that takes an object ID and returns the object content as bytes.
        The function raises `ObjectIdNotPreregisteredError` if the object ID is not
        pre-registered, or `ObjectUnavailableError` if the object is not yet available.
    """

    def pull_object_fn(object_id: str) -> bytes:
        request = PullObjectRequest(node=node, run_id=run_id, object_id=object_id)
        response: PullObjectResponse = pull_object_rest(request)
        # "Not found" and "found but not yet available" map to distinct errors
        if not response.object_found:
            raise ObjectIdNotPreregisteredError(object_id)
        if not response.object_available:
            raise ObjectUnavailableError(object_id)
        return response.object_content

    return pull_object_fn
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def make_push_object_fn_rest(
    push_object_rest: Callable[[PushObjectRequest], PushObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str, bytes], None]:
    """Create a push object function that uses REST to push objects.

    Parameters
    ----------
    push_object_rest : Callable[[PushObjectRequest], PushObjectResponse]
        A function that makes a POST request against the `/push-object` REST
        endpoint.
    node : Node
        The node making the request.
    run_id : int
        The run ID for the current operation.

    Returns
    -------
    Callable[[str, bytes], None]
        A function that takes an object ID and its content as bytes, and pushes it
        to the servicer. The function raises `ObjectIdNotPreregisteredError` if
        the object ID is not pre-registered.
    """

    def push_object_fn(object_id: str, object_content: bytes) -> None:
        request = PushObjectRequest(
            node=node, run_id=run_id, object_id=object_id, object_content=object_content
        )
        response: PushObjectResponse = push_object_rest(request)
        # The servicer only stores content for object IDs it has pre-registered
        if not response.stored:
            raise ObjectIdNotPreregisteredError(object_id)

    return push_object_fn
|