flwr-nightly 1.19.0.dev20250609__py3-none-any.whl → 1.19.0.dev20250611__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/grpc_rere_client/connection.py +4 -1
- flwr/client/rest_client/connection.py +118 -26
- flwr/common/auth_plugin/auth_plugin.py +6 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/inflatable.py +46 -1
- flwr/common/inflatable_grpc_utils.py +3 -266
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +268 -2
- flwr/common/typing.py +3 -3
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +21 -56
- flwr/server/superlink/fleet/message_handler/message_handler.py +57 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +30 -0
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_user_auth_interceptor.py +11 -11
- flwr/supernode/start_client_internal.py +101 -59
- {flwr_nightly-1.19.0.dev20250609.dist-info → flwr_nightly-1.19.0.dev20250611.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250609.dist-info → flwr_nightly-1.19.0.dev20250611.dist-info}/RECORD +21 -20
- {flwr_nightly-1.19.0.dev20250609.dist-info → flwr_nightly-1.19.0.dev20250611.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250609.dist-info → flwr_nightly-1.19.0.dev20250611.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,99 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""InflatableObject REST utils."""
|
16
|
+
|
17
|
+
|
18
|
+
from typing import Callable
|
19
|
+
|
20
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
21
|
+
PullObjectRequest,
|
22
|
+
PullObjectResponse,
|
23
|
+
PushObjectRequest,
|
24
|
+
PushObjectResponse,
|
25
|
+
)
|
26
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
27
|
+
|
28
|
+
from .inflatable_utils import ObjectIdNotPreregisteredError, ObjectUnavailableError
|
29
|
+
|
30
|
+
|
31
|
+
def make_pull_object_fn_rest(
|
32
|
+
pull_object_rest: Callable[[PullObjectRequest], PullObjectResponse],
|
33
|
+
node: Node,
|
34
|
+
run_id: int,
|
35
|
+
) -> Callable[[str], bytes]:
|
36
|
+
"""Create a pull object function that uses REST to pull objects.
|
37
|
+
|
38
|
+
Parameters
|
39
|
+
----------
|
40
|
+
pull_object_rest : Callable[[PullObjectRequest], PullObjectResponse]
|
41
|
+
A function that makes a POST request against the `/pull-object` REST endpoint
|
42
|
+
node : Node
|
43
|
+
The node making the request.
|
44
|
+
run_id : int
|
45
|
+
The run ID for the current operation.
|
46
|
+
|
47
|
+
Returns
|
48
|
+
-------
|
49
|
+
Callable[[str], bytes]
|
50
|
+
A function that takes an object ID and returns the object content as bytes.
|
51
|
+
The function raises `ObjectIdNotPreregisteredError` if the object ID is not
|
52
|
+
pre-registered, or `ObjectUnavailableError` if the object is not yet available.
|
53
|
+
"""
|
54
|
+
|
55
|
+
def pull_object_fn(object_id: str) -> bytes:
|
56
|
+
request = PullObjectRequest(node=node, run_id=run_id, object_id=object_id)
|
57
|
+
response: PullObjectResponse = pull_object_rest(request)
|
58
|
+
if not response.object_found:
|
59
|
+
raise ObjectIdNotPreregisteredError(object_id)
|
60
|
+
if not response.object_available:
|
61
|
+
raise ObjectUnavailableError(object_id)
|
62
|
+
return response.object_content
|
63
|
+
|
64
|
+
return pull_object_fn
|
65
|
+
|
66
|
+
|
67
|
+
def make_push_object_fn_rest(
|
68
|
+
push_object_rest: Callable[[PushObjectRequest], PushObjectResponse],
|
69
|
+
node: Node,
|
70
|
+
run_id: int,
|
71
|
+
) -> Callable[[str, bytes], None]:
|
72
|
+
"""Create a push object function that uses REST to push objects.
|
73
|
+
|
74
|
+
Parameters
|
75
|
+
----------
|
76
|
+
push_object_rest : Callable[[PushObjectRequest], PushObjectResponse]
|
77
|
+
A function that makes a POST request against the `/push-object` REST endpoint
|
78
|
+
node : Node
|
79
|
+
The node making the request.
|
80
|
+
run_id : int
|
81
|
+
The run ID for the current operation.
|
82
|
+
|
83
|
+
Returns
|
84
|
+
-------
|
85
|
+
Callable[[str, bytes], None]
|
86
|
+
A function that takes an object ID and its content as bytes, and pushes it
|
87
|
+
to the servicer. The function raises `ObjectIdNotPreregisteredError` if
|
88
|
+
the object ID is not pre-registered.
|
89
|
+
"""
|
90
|
+
|
91
|
+
def push_object_fn(object_id: str, object_content: bytes) -> None:
|
92
|
+
request = PushObjectRequest(
|
93
|
+
node=node, run_id=run_id, object_id=object_id, object_content=object_content
|
94
|
+
)
|
95
|
+
response: PushObjectResponse = push_object_rest(request)
|
96
|
+
if not response.stored:
|
97
|
+
raise ObjectIdNotPreregisteredError(object_id)
|
98
|
+
|
99
|
+
return push_object_fn
|
flwr/common/inflatable_utils.py
CHANGED
@@ -14,15 +14,281 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""InflatableObject utilities."""
|
16
16
|
|
17
|
+
import concurrent.futures
|
18
|
+
import random
|
19
|
+
import threading
|
20
|
+
import time
|
21
|
+
from typing import Callable, Optional
|
17
22
|
|
18
|
-
from .constant import
|
23
|
+
from .constant import (
|
24
|
+
HEAD_BODY_DIVIDER,
|
25
|
+
HEAD_VALUE_DIVIDER,
|
26
|
+
MAX_CONCURRENT_PULLS,
|
27
|
+
MAX_CONCURRENT_PUSHES,
|
28
|
+
PULL_BACKOFF_CAP,
|
29
|
+
PULL_INITIAL_BACKOFF,
|
30
|
+
PULL_MAX_TIME,
|
31
|
+
PULL_MAX_TRIES_PER_OBJECT,
|
32
|
+
)
|
19
33
|
from .inflatable import (
|
34
|
+
InflatableObject,
|
20
35
|
UnexpectedObjectContentError,
|
21
36
|
_get_object_head,
|
37
|
+
get_object_head_values_from_object_content,
|
22
38
|
get_object_id,
|
23
39
|
is_valid_sha256_hash,
|
24
40
|
)
|
25
|
-
from .
|
41
|
+
from .message import Message
|
42
|
+
from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
43
|
+
|
44
|
+
# Helper registry that maps names of classes to their type
|
45
|
+
inflatable_class_registry: dict[str, type[InflatableObject]] = {
|
46
|
+
Array.__qualname__: Array,
|
47
|
+
ArrayRecord.__qualname__: ArrayRecord,
|
48
|
+
ConfigRecord.__qualname__: ConfigRecord,
|
49
|
+
Message.__qualname__: Message,
|
50
|
+
MetricRecord.__qualname__: MetricRecord,
|
51
|
+
RecordDict.__qualname__: RecordDict,
|
52
|
+
}
|
53
|
+
|
54
|
+
|
55
|
+
class ObjectUnavailableError(Exception):
|
56
|
+
"""Exception raised when an object has been pre-registered but is not yet
|
57
|
+
available."""
|
58
|
+
|
59
|
+
def __init__(self, object_id: str):
|
60
|
+
super().__init__(f"Object with ID '{object_id}' is not yet available.")
|
61
|
+
|
62
|
+
|
63
|
+
class ObjectIdNotPreregisteredError(Exception):
|
64
|
+
"""Exception raised when an object ID is not pre-registered."""
|
65
|
+
|
66
|
+
def __init__(self, object_id: str):
|
67
|
+
super().__init__(f"Object with ID '{object_id}' could not be found.")
|
68
|
+
|
69
|
+
|
70
|
+
def push_objects(
|
71
|
+
objects: dict[str, InflatableObject],
|
72
|
+
push_object_fn: Callable[[str, bytes], None],
|
73
|
+
*,
|
74
|
+
object_ids_to_push: Optional[set[str]] = None,
|
75
|
+
keep_objects: bool = False,
|
76
|
+
max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
|
77
|
+
) -> None:
|
78
|
+
"""Push multiple objects to the servicer.
|
79
|
+
|
80
|
+
Parameters
|
81
|
+
----------
|
82
|
+
objects : dict[str, InflatableObject]
|
83
|
+
A dictionary of objects to push, where keys are object IDs and values are
|
84
|
+
`InflatableObject` instances.
|
85
|
+
push_object_fn : Callable[[str, bytes], None]
|
86
|
+
A function that takes an object ID and its content as bytes, and pushes
|
87
|
+
it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
|
88
|
+
if the object ID is not pre-registered.
|
89
|
+
object_ids_to_push : Optional[set[str]] (default: None)
|
90
|
+
A set of object IDs to push. If not provided, all objects will be pushed.
|
91
|
+
keep_objects : bool (default: False)
|
92
|
+
If `True`, the original objects will be kept in the `objects` dictionary
|
93
|
+
after pushing. If `False`, they will be removed from the dictionary to avoid
|
94
|
+
high memory usage.
|
95
|
+
max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
|
96
|
+
The maximum number of concurrent pushes to perform.
|
97
|
+
"""
|
98
|
+
if object_ids_to_push is not None:
|
99
|
+
# Filter objects to push only those with IDs in the set
|
100
|
+
objects = {k: v for k, v in objects.items() if k in object_ids_to_push}
|
101
|
+
|
102
|
+
lock = threading.Lock()
|
103
|
+
|
104
|
+
def push(obj_id: str) -> None:
|
105
|
+
"""Push a single object."""
|
106
|
+
object_content = objects[obj_id].deflate()
|
107
|
+
if not keep_objects:
|
108
|
+
with lock:
|
109
|
+
del objects[obj_id]
|
110
|
+
push_object_fn(obj_id, object_content)
|
111
|
+
|
112
|
+
with concurrent.futures.ThreadPoolExecutor(
|
113
|
+
max_workers=max_concurrent_pushes
|
114
|
+
) as executor:
|
115
|
+
list(executor.map(push, list(objects.keys())))
|
116
|
+
|
117
|
+
|
118
|
+
def pull_objects( # pylint: disable=too-many-arguments
|
119
|
+
object_ids: list[str],
|
120
|
+
pull_object_fn: Callable[[str], bytes],
|
121
|
+
*,
|
122
|
+
max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
|
123
|
+
max_time: Optional[float] = PULL_MAX_TIME,
|
124
|
+
max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
|
125
|
+
initial_backoff: float = PULL_INITIAL_BACKOFF,
|
126
|
+
backoff_cap: float = PULL_BACKOFF_CAP,
|
127
|
+
) -> dict[str, bytes]:
|
128
|
+
"""Pull multiple objects from the servicer.
|
129
|
+
|
130
|
+
Parameters
|
131
|
+
----------
|
132
|
+
object_ids : list[str]
|
133
|
+
A list of object IDs to pull.
|
134
|
+
pull_object_fn : Callable[[str], bytes]
|
135
|
+
A function that takes an object ID and returns the object content as bytes.
|
136
|
+
The function should raise `ObjectUnavailableError` if the object is not yet
|
137
|
+
available, or `ObjectIdNotPreregisteredError` if the object ID is not
|
138
|
+
pre-registered.
|
139
|
+
max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
|
140
|
+
The maximum number of concurrent pulls to perform.
|
141
|
+
max_time : Optional[float] (default: PULL_MAX_TIME)
|
142
|
+
The maximum time in seconds to wait for all pulls to complete. If `None`, waits
|
143
|
+
indefinitely.
|
144
|
+
max_tries_per_object : Optional[int] (default: PULL_MAX_TRIES_PER_OBJECT)
|
145
|
+
The maximum number of attempts to pull each object. If `None`, pulls
|
146
|
+
indefinitely until the object is available.
|
147
|
+
initial_backoff : float (default: PULL_INITIAL_BACKOFF)
|
148
|
+
The initial backoff time in seconds for retrying pulls after an
|
149
|
+
`ObjectUnavailableError`.
|
150
|
+
backoff_cap : float (default: PULL_BACKOFF_CAP)
|
151
|
+
The maximum backoff time in seconds. Backoff times will not exceed this value.
|
152
|
+
|
153
|
+
Returns
|
154
|
+
-------
|
155
|
+
dict[str, bytes]
|
156
|
+
A dictionary where keys are object IDs and values are the pulled
|
157
|
+
object contents.
|
158
|
+
"""
|
159
|
+
if max_tries_per_object is None:
|
160
|
+
max_tries_per_object = int(1e9)
|
161
|
+
if max_time is None:
|
162
|
+
max_time = float("inf")
|
163
|
+
|
164
|
+
results: dict[str, bytes] = {}
|
165
|
+
results_lock = threading.Lock()
|
166
|
+
err_to_raise: Optional[Exception] = None
|
167
|
+
early_stop = threading.Event()
|
168
|
+
start = time.monotonic()
|
169
|
+
|
170
|
+
def pull_with_retries(object_id: str) -> None:
|
171
|
+
"""Attempt to pull a single object with retry and backoff."""
|
172
|
+
nonlocal err_to_raise
|
173
|
+
tries = 0
|
174
|
+
delay = initial_backoff
|
175
|
+
|
176
|
+
while not early_stop.is_set():
|
177
|
+
try:
|
178
|
+
object_content = pull_object_fn(object_id)
|
179
|
+
with results_lock:
|
180
|
+
results[object_id] = object_content
|
181
|
+
return
|
182
|
+
|
183
|
+
except ObjectUnavailableError as err:
|
184
|
+
tries += 1
|
185
|
+
if (
|
186
|
+
tries >= max_tries_per_object
|
187
|
+
or time.monotonic() - start >= max_time
|
188
|
+
):
|
189
|
+
# Stop all work if one object exhausts its retries or the time limit is reached
|
190
|
+
early_stop.set()
|
191
|
+
with results_lock:
|
192
|
+
if err_to_raise is None:
|
193
|
+
err_to_raise = err
|
194
|
+
return
|
195
|
+
|
196
|
+
# Apply exponential backoff with ±20% jitter
|
197
|
+
sleep_time = delay * (1 + random.uniform(-0.2, 0.2))
|
198
|
+
early_stop.wait(sleep_time)
|
199
|
+
delay = min(delay * 2, backoff_cap)
|
200
|
+
|
201
|
+
except ObjectIdNotPreregisteredError as err:
|
202
|
+
# Permanent failure: object ID is invalid
|
203
|
+
early_stop.set()
|
204
|
+
with results_lock:
|
205
|
+
if err_to_raise is None:
|
206
|
+
err_to_raise = err
|
207
|
+
return
|
208
|
+
|
209
|
+
# Submit all pull tasks concurrently
|
210
|
+
with concurrent.futures.ThreadPoolExecutor(
|
211
|
+
max_workers=max_concurrent_pulls
|
212
|
+
) as executor:
|
213
|
+
futures = {
|
214
|
+
executor.submit(pull_with_retries, obj_id): obj_id for obj_id in object_ids
|
215
|
+
}
|
216
|
+
|
217
|
+
# Wait for completion
|
218
|
+
concurrent.futures.wait(futures)
|
219
|
+
|
220
|
+
if err_to_raise is not None:
|
221
|
+
raise err_to_raise
|
222
|
+
|
223
|
+
return results
|
224
|
+
|
225
|
+
|
226
|
+
def inflate_object_from_contents(
|
227
|
+
object_id: str,
|
228
|
+
object_contents: dict[str, bytes],
|
229
|
+
*,
|
230
|
+
keep_object_contents: bool = False,
|
231
|
+
objects: Optional[dict[str, InflatableObject]] = None,
|
232
|
+
) -> InflatableObject:
|
233
|
+
"""Inflate an object from object contents.
|
234
|
+
|
235
|
+
Parameters
|
236
|
+
----------
|
237
|
+
object_id : str
|
238
|
+
The ID of the object to inflate.
|
239
|
+
object_contents : dict[str, bytes]
|
240
|
+
A dictionary mapping object IDs to their contents as bytes.
|
241
|
+
All descendant objects must be present in this dictionary.
|
242
|
+
keep_object_contents : bool (default: False)
|
243
|
+
If `True`, the object content will be kept in the `object_contents`
|
244
|
+
dictionary after inflation. If `False`, the object content will be
|
245
|
+
removed from the dictionary to save memory.
|
246
|
+
objects : Optional[dict[str, InflatableObject]] (default: None)
|
247
|
+
No need to provide this parameter. A dictionary to store already
|
248
|
+
inflated objects, mapping object IDs to their corresponding
|
249
|
+
`InflatableObject` instances.
|
250
|
+
|
251
|
+
Returns
|
252
|
+
-------
|
253
|
+
InflatableObject
|
254
|
+
The inflated object.
|
255
|
+
"""
|
256
|
+
if objects is None:
|
257
|
+
# Initialize objects dictionary
|
258
|
+
objects = {}
|
259
|
+
|
260
|
+
if object_id in objects:
|
261
|
+
# If the object is already in the objects dictionary, return it
|
262
|
+
return objects[object_id]
|
263
|
+
|
264
|
+
# Extract object class and object_ids of children
|
265
|
+
object_content = object_contents[object_id]
|
266
|
+
obj_type, children_obj_ids, _ = get_object_head_values_from_object_content(
|
267
|
+
object_content=object_contents[object_id]
|
268
|
+
)
|
269
|
+
|
270
|
+
# Remove the object content from the dictionary to save memory
|
271
|
+
if not keep_object_contents:
|
272
|
+
del object_contents[object_id]
|
273
|
+
|
274
|
+
# Resolve object class
|
275
|
+
cls_type = inflatable_class_registry[obj_type]
|
276
|
+
|
277
|
+
# Inflate all children objects
|
278
|
+
children: dict[str, InflatableObject] = {}
|
279
|
+
for child_obj_id in children_obj_ids:
|
280
|
+
children[child_obj_id] = inflate_object_from_contents(
|
281
|
+
child_obj_id,
|
282
|
+
object_contents,
|
283
|
+
keep_object_contents=keep_object_contents,
|
284
|
+
objects=objects,
|
285
|
+
)
|
286
|
+
|
287
|
+
# Inflate object passing its children
|
288
|
+
obj = cls_type.inflate(object_content, children=children)
|
289
|
+
del object_content # Free memory after inflation
|
290
|
+
objects[object_id] = obj
|
291
|
+
return obj
|
26
292
|
|
27
293
|
|
28
294
|
def validate_object_content(content: bytes) -> None:
|
flwr/common/typing.py
CHANGED
@@ -289,11 +289,11 @@ class UserAuthCredentials:
|
|
289
289
|
|
290
290
|
|
291
291
|
@dataclass
|
292
|
-
class
|
292
|
+
class AccountInfo:
|
293
293
|
"""User information for event log."""
|
294
294
|
|
295
|
-
|
296
|
-
|
295
|
+
flwr_aid: Optional[str]
|
296
|
+
account_name: Optional[str]
|
297
297
|
|
298
298
|
|
299
299
|
@dataclass
|
@@ -59,7 +59,7 @@ class FleetEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
59
59
|
log_entry = self.log_plugin.compose_log_before_event(
|
60
60
|
request=request,
|
61
61
|
context=context,
|
62
|
-
|
62
|
+
account_info=None,
|
63
63
|
method_name=method_name,
|
64
64
|
)
|
65
65
|
self.log_plugin.write_log(log_entry)
|
@@ -75,7 +75,7 @@ class FleetEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
75
75
|
log_entry = self.log_plugin.compose_log_after_event(
|
76
76
|
request=request,
|
77
77
|
context=context,
|
78
|
-
|
78
|
+
account_info=None,
|
79
79
|
method_name=method_name,
|
80
80
|
response=unary_response or error,
|
81
81
|
)
|
flwr/server/grid/grpc_grid.py
CHANGED
@@ -30,9 +30,11 @@ from flwr.common.constant import (
|
|
30
30
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
31
31
|
from flwr.common.inflatable import get_all_nested_objects
|
32
32
|
from flwr.common.inflatable_grpc_utils import (
|
33
|
-
inflate_object_from_contents,
|
34
33
|
make_pull_object_fn_grpc,
|
35
34
|
make_push_object_fn_grpc,
|
35
|
+
)
|
36
|
+
from flwr.common.inflatable_utils import (
|
37
|
+
inflate_object_from_contents,
|
36
38
|
pull_objects,
|
37
39
|
push_objects,
|
38
40
|
)
|
@@ -15,12 +15,11 @@
|
|
15
15
|
"""Fleet API gRPC request-response servicer."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import DEBUG,
|
18
|
+
from logging import DEBUG, INFO
|
19
19
|
|
20
20
|
import grpc
|
21
21
|
from google.protobuf.json_format import MessageToDict
|
22
22
|
|
23
|
-
from flwr.common.constant import Status
|
24
23
|
from flwr.common.inflatable import UnexpectedObjectContentError
|
25
24
|
from flwr.common.logger import log
|
26
25
|
from flwr.common.typing import InvalidRunStatusException
|
@@ -50,9 +49,8 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
|
|
50
49
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
51
50
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
52
51
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
53
|
-
from flwr.server.superlink.utils import abort_grpc_context
|
52
|
+
from flwr.server.superlink.utils import abort_grpc_context
|
54
53
|
from flwr.supercore.object_store import ObjectStoreFactory
|
55
|
-
from flwr.supercore.object_store.object_store import NoObjectInStoreError
|
56
54
|
|
57
55
|
|
58
56
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
@@ -185,36 +183,20 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
185
183
|
request.object_id,
|
186
184
|
)
|
187
185
|
|
188
|
-
state = self.state_factory.state()
|
189
|
-
|
190
|
-
# Abort if the run is not running
|
191
|
-
abort_msg = check_abort(
|
192
|
-
request.run_id,
|
193
|
-
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
194
|
-
state,
|
195
|
-
)
|
196
|
-
if abort_msg:
|
197
|
-
abort_grpc_context(abort_msg, context)
|
198
|
-
|
199
|
-
if request.node.node_id not in state.get_nodes(run_id=request.run_id):
|
200
|
-
# Cancel insertion in ObjectStore
|
201
|
-
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
202
|
-
|
203
|
-
# Init store
|
204
|
-
store = self.objectstore_factory.store()
|
205
|
-
|
206
|
-
# Insert in store
|
207
|
-
stored = False
|
208
186
|
try:
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
187
|
+
# Insert in Store
|
188
|
+
res = message_handler.push_object(
|
189
|
+
request=request,
|
190
|
+
state=self.state_factory.state(),
|
191
|
+
store=self.objectstore_factory.store(),
|
192
|
+
)
|
193
|
+
except InvalidRunStatusException as e:
|
194
|
+
abort_grpc_context(e.message, context)
|
213
195
|
except UnexpectedObjectContentError as e:
|
214
196
|
# Object content is not valid
|
215
197
|
context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
|
216
198
|
|
217
|
-
return
|
199
|
+
return res
|
218
200
|
|
219
201
|
def PullObject(
|
220
202
|
self, request: PullObjectRequest, context: grpc.ServicerContext
|
@@ -226,31 +208,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
226
208
|
request.object_id,
|
227
209
|
)
|
228
210
|
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
state,
|
236
|
-
)
|
237
|
-
if abort_msg:
|
238
|
-
abort_grpc_context(abort_msg, context)
|
239
|
-
|
240
|
-
if request.node.node_id not in state.get_nodes(run_id=request.run_id):
|
241
|
-
# Cancel insertion in ObjectStore
|
242
|
-
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
243
|
-
|
244
|
-
# Init store
|
245
|
-
store = self.objectstore_factory.store()
|
246
|
-
|
247
|
-
# Fetch from store
|
248
|
-
content = store.get(request.object_id)
|
249
|
-
if content is not None:
|
250
|
-
object_available = content != b""
|
251
|
-
return PullObjectResponse(
|
252
|
-
object_found=True,
|
253
|
-
object_available=object_available,
|
254
|
-
object_content=content,
|
211
|
+
try:
|
212
|
+
# Fetch from store
|
213
|
+
res = message_handler.pull_object(
|
214
|
+
request=request,
|
215
|
+
state=self.state_factory.state(),
|
216
|
+
store=self.objectstore_factory.store(),
|
255
217
|
)
|
256
|
-
|
218
|
+
except InvalidRunStatusException as e:
|
219
|
+
abort_grpc_context(e.message, context)
|
220
|
+
|
221
|
+
return res
|
@@ -19,6 +19,7 @@ from typing import Optional
|
|
19
19
|
|
20
20
|
from flwr.common import Message, log
|
21
21
|
from flwr.common.constant import Status
|
22
|
+
from flwr.common.inflatable import UnexpectedObjectContentError
|
22
23
|
from flwr.common.serde import (
|
23
24
|
fab_to_proto,
|
24
25
|
message_from_proto,
|
@@ -42,7 +43,13 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
42
43
|
SendNodeHeartbeatRequest,
|
43
44
|
SendNodeHeartbeatResponse,
|
44
45
|
)
|
45
|
-
from flwr.proto.message_pb2 import
|
46
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
47
|
+
ObjectIDs,
|
48
|
+
PullObjectRequest,
|
49
|
+
PullObjectResponse,
|
50
|
+
PushObjectRequest,
|
51
|
+
PushObjectResponse,
|
52
|
+
)
|
46
53
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
47
54
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
48
55
|
GetRunRequest,
|
@@ -203,3 +210,52 @@ def get_fab(
|
|
203
210
|
return GetFabResponse(fab=fab_to_proto(fab))
|
204
211
|
|
205
212
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
213
|
+
|
214
|
+
|
215
|
+
def push_object(
|
216
|
+
request: PushObjectRequest, state: LinkState, store: ObjectStore
|
217
|
+
) -> PushObjectResponse:
|
218
|
+
"""Push Object."""
|
219
|
+
abort_msg = check_abort(
|
220
|
+
request.run_id,
|
221
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
222
|
+
state,
|
223
|
+
)
|
224
|
+
if abort_msg:
|
225
|
+
raise InvalidRunStatusException(abort_msg)
|
226
|
+
|
227
|
+
stored = False
|
228
|
+
try:
|
229
|
+
store.put(request.object_id, request.object_content)
|
230
|
+
stored = True
|
231
|
+
except (NoObjectInStoreError, ValueError) as e:
|
232
|
+
log(ERROR, str(e))
|
233
|
+
except UnexpectedObjectContentError as e:
|
234
|
+
# Object content is not valid
|
235
|
+
log(ERROR, str(e))
|
236
|
+
raise
|
237
|
+
return PushObjectResponse(stored=stored)
|
238
|
+
|
239
|
+
|
240
|
+
def pull_object(
|
241
|
+
request: PullObjectRequest, state: LinkState, store: ObjectStore
|
242
|
+
) -> PullObjectResponse:
|
243
|
+
"""Pull Object."""
|
244
|
+
abort_msg = check_abort(
|
245
|
+
request.run_id,
|
246
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
247
|
+
state,
|
248
|
+
)
|
249
|
+
if abort_msg:
|
250
|
+
raise InvalidRunStatusException(abort_msg)
|
251
|
+
|
252
|
+
# Fetch from store
|
253
|
+
content = store.get(request.object_id)
|
254
|
+
if content is not None:
|
255
|
+
object_available = content != b""
|
256
|
+
return PullObjectResponse(
|
257
|
+
object_found=True,
|
258
|
+
object_available=object_available,
|
259
|
+
object_content=content,
|
260
|
+
)
|
261
|
+
return PullObjectResponse(object_found=False, object_available=False)
|
@@ -38,6 +38,12 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
38
38
|
SendNodeHeartbeatRequest,
|
39
39
|
SendNodeHeartbeatResponse,
|
40
40
|
)
|
41
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
42
|
+
PullObjectRequest,
|
43
|
+
PullObjectResponse,
|
44
|
+
PushObjectRequest,
|
45
|
+
PushObjectResponse,
|
46
|
+
)
|
41
47
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
42
48
|
from flwr.server.superlink.ffs.ffs import Ffs
|
43
49
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
@@ -131,6 +137,28 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
|
131
137
|
return message_handler.push_messages(request=request, state=state, store=store)
|
132
138
|
|
133
139
|
|
140
|
+
@rest_request_response(PullObjectRequest)
|
141
|
+
async def pull_object(request: PullObjectRequest) -> PullObjectResponse:
|
142
|
+
"""Pull PullObject."""
|
143
|
+
# Get state from app
|
144
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
145
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
146
|
+
|
147
|
+
# Handle message
|
148
|
+
return message_handler.pull_object(request=request, state=state, store=store)
|
149
|
+
|
150
|
+
|
151
|
+
@rest_request_response(PushObjectRequest)
|
152
|
+
async def push_object(request: PushObjectRequest) -> PushObjectResponse:
|
153
|
+
"""Push PushObject."""
|
154
|
+
# Get state from app
|
155
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
156
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
157
|
+
|
158
|
+
# Handle message
|
159
|
+
return message_handler.push_object(request=request, state=state, store=store)
|
160
|
+
|
161
|
+
|
134
162
|
@rest_request_response(SendNodeHeartbeatRequest)
|
135
163
|
async def send_node_heartbeat(
|
136
164
|
request: SendNodeHeartbeatRequest,
|
@@ -171,6 +199,8 @@ routes = [
|
|
171
199
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
172
200
|
Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
|
173
201
|
Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
|
202
|
+
Route("/api/v0/fleet/pull-object", pull_object, methods=["POST"]),
|
203
|
+
Route("/api/v0/fleet/push-object", push_object, methods=["POST"]),
|
174
204
|
Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
|
175
205
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
176
206
|
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|