flwr-nightly 1.19.0.dev20250609__py3-none-any.whl → 1.19.0.dev20250611__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/grpc_rere_client/connection.py +4 -1
- flwr/client/rest_client/connection.py +118 -26
- flwr/common/auth_plugin/auth_plugin.py +6 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/inflatable.py +46 -1
- flwr/common/inflatable_grpc_utils.py +3 -266
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +268 -2
- flwr/common/typing.py +3 -3
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +21 -56
- flwr/server/superlink/fleet/message_handler/message_handler.py +57 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +30 -0
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_user_auth_interceptor.py +11 -11
- flwr/supernode/start_client_internal.py +101 -59
- {flwr_nightly-1.19.0.dev20250609.dist-info → flwr_nightly-1.19.0.dev20250611.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250609.dist-info → flwr_nightly-1.19.0.dev20250611.dist-info}/RECORD +21 -20
- {flwr_nightly-1.19.0.dev20250609.dist-info → flwr_nightly-1.19.0.dev20250611.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250609.dist-info → flwr_nightly-1.19.0.dev20250611.dist-info}/entry_points.txt +0 -0
@@ -14,6 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Contextmanager for a gRPC request-response channel to the Flower server."""
|
16
16
|
|
17
|
+
|
17
18
|
from collections.abc import Iterator, Sequence
|
18
19
|
from contextlib import contextmanager
|
19
20
|
from copy import copy
|
@@ -32,9 +33,11 @@ from flwr.common.grpc import create_channel, on_channel_state_change
|
|
32
33
|
from flwr.common.heartbeat import HeartbeatSender
|
33
34
|
from flwr.common.inflatable import get_all_nested_objects
|
34
35
|
from flwr.common.inflatable_grpc_utils import (
|
35
|
-
inflate_object_from_contents,
|
36
36
|
make_pull_object_fn_grpc,
|
37
37
|
make_push_object_fn_grpc,
|
38
|
+
)
|
39
|
+
from flwr.common.inflatable_utils import (
|
40
|
+
inflate_object_from_contents,
|
38
41
|
pull_objects,
|
39
42
|
push_objects,
|
40
43
|
)
|
@@ -14,12 +14,11 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Contextmanager for a REST request-response channel to the Flower server."""
|
16
16
|
|
17
|
-
|
18
17
|
from collections.abc import Iterator
|
19
18
|
from contextlib import contextmanager
|
20
19
|
from copy import copy
|
21
|
-
from logging import ERROR, INFO, WARN
|
22
|
-
from typing import Callable, Optional, TypeVar, Union
|
20
|
+
from logging import DEBUG, ERROR, INFO, WARN
|
21
|
+
from typing import Callable, Optional, TypeVar, Union, cast
|
23
22
|
|
24
23
|
from cryptography.hazmat.primitives.asymmetric import ec
|
25
24
|
from google.protobuf.message import Message as GrpcMessage
|
@@ -31,10 +30,20 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
31
30
|
from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
|
32
31
|
from flwr.common.exit import ExitCode, flwr_exit
|
33
32
|
from flwr.common.heartbeat import HeartbeatSender
|
33
|
+
from flwr.common.inflatable import get_all_nested_objects
|
34
|
+
from flwr.common.inflatable_rest_utils import (
|
35
|
+
make_pull_object_fn_rest,
|
36
|
+
make_push_object_fn_rest,
|
37
|
+
)
|
38
|
+
from flwr.common.inflatable_utils import (
|
39
|
+
inflate_object_from_contents,
|
40
|
+
pull_objects,
|
41
|
+
push_objects,
|
42
|
+
)
|
34
43
|
from flwr.common.logger import log
|
35
|
-
from flwr.common.message import Message
|
44
|
+
from flwr.common.message import Message, remove_content_from_message
|
36
45
|
from flwr.common.retry_invoker import RetryInvoker
|
37
|
-
from flwr.common.serde import
|
46
|
+
from flwr.common.serde import message_to_proto, run_from_proto
|
38
47
|
from flwr.common.typing import Fab, Run
|
39
48
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
40
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
@@ -51,6 +60,13 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
51
60
|
SendNodeHeartbeatRequest,
|
52
61
|
SendNodeHeartbeatResponse,
|
53
62
|
)
|
63
|
+
from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
|
64
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
65
|
+
PullObjectRequest,
|
66
|
+
PullObjectResponse,
|
67
|
+
PushObjectRequest,
|
68
|
+
PushObjectResponse,
|
69
|
+
)
|
54
70
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
55
71
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
56
72
|
|
@@ -64,6 +80,8 @@ PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
|
|
64
80
|
PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
|
65
81
|
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
66
82
|
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
83
|
+
PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
|
84
|
+
PATH_PUSH_OBJECT: str = "/api/v0/fleet/push-object"
|
67
85
|
PATH_SEND_NODE_HEARTBEAT: str = "api/v0/fleet/send-node-heartbeat"
|
68
86
|
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
|
69
87
|
PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
|
@@ -296,14 +314,48 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
296
314
|
):
|
297
315
|
message_proto = None
|
298
316
|
|
299
|
-
#
|
300
|
-
|
301
|
-
|
302
|
-
if message_proto
|
303
|
-
message = message_from_proto(message_proto)
|
304
|
-
metadata = copy(message.metadata)
|
317
|
+
# Construct the Message
|
318
|
+
in_message: Optional[Message] = None
|
319
|
+
|
320
|
+
if message_proto:
|
305
321
|
log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
|
306
|
-
|
322
|
+
msg_id = message_proto.metadata.message_id
|
323
|
+
|
324
|
+
def fn(request: PullObjectRequest) -> PullObjectResponse:
|
325
|
+
res = _request(
|
326
|
+
req=request, res_type=PullObjectResponse, api_path=PATH_PULL_OBJECT
|
327
|
+
)
|
328
|
+
if res is None:
|
329
|
+
raise ValueError("PushObjectResponse is None.")
|
330
|
+
return res
|
331
|
+
|
332
|
+
try:
|
333
|
+
all_object_contents = pull_objects(
|
334
|
+
list(res.objects_to_pull[msg_id].object_ids) + [msg_id],
|
335
|
+
pull_object_fn=make_pull_object_fn_rest(
|
336
|
+
pull_object_rest=fn,
|
337
|
+
node=node,
|
338
|
+
run_id=message_proto.metadata.run_id,
|
339
|
+
),
|
340
|
+
)
|
341
|
+
except ValueError as e:
|
342
|
+
log(
|
343
|
+
ERROR,
|
344
|
+
"Pulling objects failed. Potential irrecoverable error: %s",
|
345
|
+
str(e),
|
346
|
+
)
|
347
|
+
in_message = cast(
|
348
|
+
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
349
|
+
)
|
350
|
+
# The deflated message doesn't contain the message_id (its own object_id)
|
351
|
+
# Inject
|
352
|
+
in_message.metadata.__dict__["_message_id"] = msg_id
|
353
|
+
|
354
|
+
# Remember `metadata` of the in message
|
355
|
+
nonlocal metadata
|
356
|
+
metadata = copy(in_message.metadata) if in_message else None
|
357
|
+
|
358
|
+
return in_message
|
307
359
|
|
308
360
|
def send(message: Message) -> None:
|
309
361
|
"""Send Message result back to server."""
|
@@ -318,29 +370,69 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
318
370
|
log(ERROR, "No current message")
|
319
371
|
return
|
320
372
|
|
373
|
+
# Set message_id
|
374
|
+
message.metadata.__dict__["_message_id"] = message.object_id
|
321
375
|
# Validate out message
|
322
376
|
if not validate_out_message(message, metadata):
|
323
377
|
log(ERROR, "Invalid out message")
|
324
378
|
return
|
325
|
-
metadata = None
|
326
379
|
|
327
|
-
#
|
328
|
-
|
329
|
-
|
330
|
-
#
|
331
|
-
|
380
|
+
# Get all nested objects
|
381
|
+
all_objects = get_all_nested_objects(message)
|
382
|
+
all_object_ids = list(all_objects.keys())
|
383
|
+
msg_id = all_object_ids[-1] # Last object is the message itself
|
384
|
+
descendant_ids = all_object_ids[:-1] # All but the last object are descendants
|
385
|
+
|
386
|
+
# Serialize Message
|
387
|
+
message_proto = message_to_proto(message=remove_content_from_message(message))
|
388
|
+
req = PushMessagesRequest(
|
389
|
+
node=node,
|
390
|
+
messages_list=[message_proto],
|
391
|
+
msg_to_descendant_mapping={msg_id: ObjectIDs(object_ids=descendant_ids)},
|
392
|
+
)
|
332
393
|
|
333
394
|
# Send the request
|
334
395
|
res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
|
335
|
-
if res
|
336
|
-
|
396
|
+
if res:
|
397
|
+
log(
|
398
|
+
INFO,
|
399
|
+
"[Node] POST /%s: success, created result %s",
|
400
|
+
PATH_PUSH_MESSAGES,
|
401
|
+
res.results, # pylint: disable=no-member
|
402
|
+
)
|
337
403
|
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
404
|
+
if res and res.objects_to_push:
|
405
|
+
objs_to_push = set(res.objects_to_push[message.object_id].object_ids)
|
406
|
+
|
407
|
+
def fn(request: PushObjectRequest) -> PushObjectResponse:
|
408
|
+
res = _request(
|
409
|
+
req=request, res_type=PushObjectResponse, api_path=PATH_PUSH_OBJECT
|
410
|
+
)
|
411
|
+
if res is None:
|
412
|
+
raise ValueError("PushObjectResponse is None.")
|
413
|
+
return res
|
414
|
+
|
415
|
+
try:
|
416
|
+
push_objects(
|
417
|
+
all_objects,
|
418
|
+
push_object_fn=make_push_object_fn_rest(
|
419
|
+
push_object_rest=fn,
|
420
|
+
node=node,
|
421
|
+
run_id=message_proto.metadata.run_id,
|
422
|
+
),
|
423
|
+
object_ids_to_push=objs_to_push,
|
424
|
+
)
|
425
|
+
log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
|
426
|
+
except ValueError as e:
|
427
|
+
log(
|
428
|
+
ERROR,
|
429
|
+
"Pushing objects failed. Potential irrecoverable error: %s",
|
430
|
+
str(e),
|
431
|
+
)
|
432
|
+
log(ERROR, str(e))
|
433
|
+
|
434
|
+
# Cleanup
|
435
|
+
metadata = None
|
344
436
|
|
345
437
|
def get_run(run_id: int) -> Run:
|
346
438
|
# Construct the request
|
@@ -20,7 +20,7 @@ from collections.abc import Sequence
|
|
20
20
|
from pathlib import Path
|
21
21
|
from typing import Optional, Union
|
22
22
|
|
23
|
-
from flwr.common.typing import
|
23
|
+
from flwr.common.typing import AccountInfo
|
24
24
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
25
25
|
|
26
26
|
from ..typing import UserAuthCredentials, UserAuthLoginDetails
|
@@ -53,7 +53,7 @@ class ExecAuthPlugin(ABC):
|
|
53
53
|
@abstractmethod
|
54
54
|
def validate_tokens_in_metadata(
|
55
55
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
56
|
-
) -> tuple[bool, Optional[
|
56
|
+
) -> tuple[bool, Optional[AccountInfo]]:
|
57
57
|
"""Validate authentication tokens in the provided metadata."""
|
58
58
|
|
59
59
|
@abstractmethod
|
@@ -63,7 +63,9 @@ class ExecAuthPlugin(ABC):
|
|
63
63
|
@abstractmethod
|
64
64
|
def refresh_tokens(
|
65
65
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
66
|
-
) -> tuple[
|
66
|
+
) -> tuple[
|
67
|
+
Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[AccountInfo]
|
68
|
+
]:
|
67
69
|
"""Refresh authentication tokens in the provided metadata."""
|
68
70
|
|
69
71
|
|
@@ -84,7 +86,7 @@ class ExecAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
|
|
84
86
|
"""Abstract constructor."""
|
85
87
|
|
86
88
|
@abstractmethod
|
87
|
-
def verify_user_authorization(self,
|
89
|
+
def verify_user_authorization(self, account_info: AccountInfo) -> bool:
|
88
90
|
"""Verify user authorization request."""
|
89
91
|
|
90
92
|
|
@@ -21,7 +21,7 @@ from typing import Optional, Union
|
|
21
21
|
import grpc
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
23
23
|
|
24
|
-
from flwr.common.typing import
|
24
|
+
from flwr.common.typing import AccountInfo, LogEntry
|
25
25
|
|
26
26
|
|
27
27
|
class EventLogWriterPlugin(ABC):
|
@@ -36,7 +36,7 @@ class EventLogWriterPlugin(ABC):
|
|
36
36
|
self,
|
37
37
|
request: GrpcMessage,
|
38
38
|
context: grpc.ServicerContext,
|
39
|
-
|
39
|
+
account_info: Optional[AccountInfo],
|
40
40
|
method_name: str,
|
41
41
|
) -> LogEntry:
|
42
42
|
"""Compose pre-event log entry from the provided request and context."""
|
@@ -46,7 +46,7 @@ class EventLogWriterPlugin(ABC):
|
|
46
46
|
self,
|
47
47
|
request: GrpcMessage,
|
48
48
|
context: grpc.ServicerContext,
|
49
|
-
|
49
|
+
account_info: Optional[AccountInfo],
|
50
50
|
method_name: str,
|
51
51
|
response: Optional[Union[GrpcMessage, BaseException]],
|
52
52
|
) -> LogEntry:
|
flwr/common/inflatable.py
CHANGED
@@ -18,6 +18,9 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import hashlib
|
21
|
+
import threading
|
22
|
+
from collections.abc import Iterator
|
23
|
+
from contextlib import contextmanager
|
21
24
|
from typing import TypeVar, cast
|
22
25
|
|
23
26
|
from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
|
@@ -33,6 +36,33 @@ class UnexpectedObjectContentError(Exception):
|
|
33
36
|
)
|
34
37
|
|
35
38
|
|
39
|
+
_ctx = threading.local()
|
40
|
+
|
41
|
+
|
42
|
+
def _is_recompute_enabled() -> bool:
|
43
|
+
"""Check if recomputing object IDs is enabled."""
|
44
|
+
return getattr(_ctx, "recompute_object_id_enabled", True)
|
45
|
+
|
46
|
+
|
47
|
+
def _get_computed_object_ids() -> set[str]:
|
48
|
+
"""Get the set of computed object IDs."""
|
49
|
+
return getattr(_ctx, "computed_object_ids", set())
|
50
|
+
|
51
|
+
|
52
|
+
@contextmanager
|
53
|
+
def no_object_id_recompute() -> Iterator[None]:
|
54
|
+
"""Context manager to disable recomputing object IDs."""
|
55
|
+
old_value = _is_recompute_enabled()
|
56
|
+
old_set = _get_computed_object_ids()
|
57
|
+
_ctx.recompute_object_id_enabled = False
|
58
|
+
_ctx.computed_object_ids = set()
|
59
|
+
try:
|
60
|
+
yield
|
61
|
+
finally:
|
62
|
+
_ctx.recompute_object_id_enabled = old_value
|
63
|
+
_ctx.computed_object_ids = old_set
|
64
|
+
|
65
|
+
|
36
66
|
class InflatableObject:
|
37
67
|
"""Base class for inflatable objects."""
|
38
68
|
|
@@ -65,8 +95,23 @@ class InflatableObject:
|
|
65
95
|
@property
|
66
96
|
def object_id(self) -> str:
|
67
97
|
"""Get object_id."""
|
98
|
+
# If recomputing object ID is disabled and the object ID is already computed,
|
99
|
+
# return the cached object ID.
|
100
|
+
if (
|
101
|
+
not _is_recompute_enabled()
|
102
|
+
and (obj_id := self.__dict__.get("_object_id"))
|
103
|
+
in _get_computed_object_ids()
|
104
|
+
):
|
105
|
+
return cast(str, obj_id)
|
106
|
+
|
68
107
|
if self.is_dirty or "_object_id" not in self.__dict__:
|
69
|
-
|
108
|
+
obj_id = get_object_id(self.deflate())
|
109
|
+
self.__dict__["_object_id"] = obj_id
|
110
|
+
|
111
|
+
# If recomputing object ID is disabled, add the object ID to the set of
|
112
|
+
# computed object IDs to avoid recomputing it within the context.
|
113
|
+
if not _is_recompute_enabled():
|
114
|
+
_get_computed_object_ids().add(obj_id)
|
70
115
|
return cast(str, self.__dict__["_object_id"])
|
71
116
|
|
72
117
|
@property
|
@@ -12,14 +12,10 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""InflatableObject utils."""
|
15
|
+
"""InflatableObject gRPC utils."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
import random
|
20
|
-
import threading
|
21
|
-
import time
|
22
|
-
from typing import Callable, Optional
|
18
|
+
from typing import Callable
|
23
19
|
|
24
20
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
25
21
|
PullObjectRequest,
|
@@ -29,42 +25,7 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
29
25
|
)
|
30
26
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
31
27
|
|
32
|
-
from .
|
33
|
-
MAX_CONCURRENT_PULLS,
|
34
|
-
MAX_CONCURRENT_PUSHES,
|
35
|
-
PULL_BACKOFF_CAP,
|
36
|
-
PULL_INITIAL_BACKOFF,
|
37
|
-
PULL_MAX_TIME,
|
38
|
-
PULL_MAX_TRIES_PER_OBJECT,
|
39
|
-
)
|
40
|
-
from .inflatable import InflatableObject, get_object_head_values_from_object_content
|
41
|
-
from .message import Message
|
42
|
-
from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
43
|
-
|
44
|
-
# Helper registry that maps names of classes to their type
|
45
|
-
inflatable_class_registry: dict[str, type[InflatableObject]] = {
|
46
|
-
Array.__qualname__: Array,
|
47
|
-
ArrayRecord.__qualname__: ArrayRecord,
|
48
|
-
ConfigRecord.__qualname__: ConfigRecord,
|
49
|
-
Message.__qualname__: Message,
|
50
|
-
MetricRecord.__qualname__: MetricRecord,
|
51
|
-
RecordDict.__qualname__: RecordDict,
|
52
|
-
}
|
53
|
-
|
54
|
-
|
55
|
-
class ObjectUnavailableError(Exception):
|
56
|
-
"""Exception raised when an object has been pre-registered but is not yet
|
57
|
-
available."""
|
58
|
-
|
59
|
-
def __init__(self, object_id: str):
|
60
|
-
super().__init__(f"Object with ID '{object_id}' is not yet available.")
|
61
|
-
|
62
|
-
|
63
|
-
class ObjectIdNotPreregisteredError(Exception):
|
64
|
-
"""Exception raised when an object ID is not pre-registered."""
|
65
|
-
|
66
|
-
def __init__(self, object_id: str):
|
67
|
-
super().__init__(f"Object with ID '{object_id}' could not be found.")
|
28
|
+
from .inflatable_utils import ObjectIdNotPreregisteredError, ObjectUnavailableError
|
68
29
|
|
69
30
|
|
70
31
|
def make_pull_object_fn_grpc(
|
@@ -136,227 +97,3 @@ def make_push_object_fn_grpc(
|
|
136
97
|
raise ObjectIdNotPreregisteredError(object_id)
|
137
98
|
|
138
99
|
return push_object_fn
|
139
|
-
|
140
|
-
|
141
|
-
def push_objects(
|
142
|
-
objects: dict[str, InflatableObject],
|
143
|
-
push_object_fn: Callable[[str, bytes], None],
|
144
|
-
*,
|
145
|
-
object_ids_to_push: Optional[set[str]] = None,
|
146
|
-
keep_objects: bool = False,
|
147
|
-
max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
|
148
|
-
) -> None:
|
149
|
-
"""Push multiple objects to the servicer.
|
150
|
-
|
151
|
-
Parameters
|
152
|
-
----------
|
153
|
-
objects : dict[str, InflatableObject]
|
154
|
-
A dictionary of objects to push, where keys are object IDs and values are
|
155
|
-
`InflatableObject` instances.
|
156
|
-
push_object_fn : Callable[[str, bytes], None]
|
157
|
-
A function that takes an object ID and its content as bytes, and pushes
|
158
|
-
it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
|
159
|
-
if the object ID is not pre-registered.
|
160
|
-
object_ids_to_push : Optional[set[str]] (default: None)
|
161
|
-
A set of object IDs to push. If not provided, all objects will be pushed.
|
162
|
-
keep_objects : bool (default: False)
|
163
|
-
If `True`, the original objects will be kept in the `objects` dictionary
|
164
|
-
after pushing. If `False`, they will be removed from the dictionary to avoid
|
165
|
-
high memory usage.
|
166
|
-
max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
|
167
|
-
The maximum number of concurrent pushes to perform.
|
168
|
-
"""
|
169
|
-
if object_ids_to_push is not None:
|
170
|
-
# Filter objects to push only those with IDs in the set
|
171
|
-
objects = {k: v for k, v in objects.items() if k in object_ids_to_push}
|
172
|
-
|
173
|
-
lock = threading.Lock()
|
174
|
-
|
175
|
-
def push(obj_id: str) -> None:
|
176
|
-
"""Push a single object."""
|
177
|
-
object_content = objects[obj_id].deflate()
|
178
|
-
if not keep_objects:
|
179
|
-
with lock:
|
180
|
-
del objects[obj_id]
|
181
|
-
push_object_fn(obj_id, object_content)
|
182
|
-
|
183
|
-
with concurrent.futures.ThreadPoolExecutor(
|
184
|
-
max_workers=max_concurrent_pushes
|
185
|
-
) as executor:
|
186
|
-
list(executor.map(push, list(objects.keys())))
|
187
|
-
|
188
|
-
|
189
|
-
def pull_objects( # pylint: disable=too-many-arguments
|
190
|
-
object_ids: list[str],
|
191
|
-
pull_object_fn: Callable[[str], bytes],
|
192
|
-
*,
|
193
|
-
max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
|
194
|
-
max_time: Optional[float] = PULL_MAX_TIME,
|
195
|
-
max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
|
196
|
-
initial_backoff: float = PULL_INITIAL_BACKOFF,
|
197
|
-
backoff_cap: float = PULL_BACKOFF_CAP,
|
198
|
-
) -> dict[str, bytes]:
|
199
|
-
"""Pull multiple objects from the servicer.
|
200
|
-
|
201
|
-
Parameters
|
202
|
-
----------
|
203
|
-
object_ids : list[str]
|
204
|
-
A list of object IDs to pull.
|
205
|
-
pull_object_fn : Callable[[str], bytes]
|
206
|
-
A function that takes an object ID and returns the object content as bytes.
|
207
|
-
The function should raise `ObjectUnavailableError` if the object is not yet
|
208
|
-
available, or `ObjectIdNotPreregisteredError` if the object ID is not
|
209
|
-
pre-registered.
|
210
|
-
max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
|
211
|
-
The maximum number of concurrent pulls to perform.
|
212
|
-
max_time : Optional[float] (default: PULL_MAX_TIME)
|
213
|
-
The maximum time to wait for all pulls to complete. If `None`, waits
|
214
|
-
indefinitely.
|
215
|
-
max_tries_per_object : Optional[int] (default: PULL_MAX_TRIES_PER_OBJECT)
|
216
|
-
The maximum number of attempts to pull each object. If `None`, pulls
|
217
|
-
indefinitely until the object is available.
|
218
|
-
initial_backoff : float (default: PULL_INITIAL_BACKOFF)
|
219
|
-
The initial backoff time in seconds for retrying pulls after an
|
220
|
-
`ObjectUnavailableError`.
|
221
|
-
backoff_cap : float (default: PULL_BACKOFF_CAP)
|
222
|
-
The maximum backoff time in seconds. Backoff times will not exceed this value.
|
223
|
-
|
224
|
-
Returns
|
225
|
-
-------
|
226
|
-
dict[str, bytes]
|
227
|
-
A dictionary where keys are object IDs and values are the pulled
|
228
|
-
object contents.
|
229
|
-
"""
|
230
|
-
if max_tries_per_object is None:
|
231
|
-
max_tries_per_object = int(1e9)
|
232
|
-
if max_time is None:
|
233
|
-
max_time = float("inf")
|
234
|
-
|
235
|
-
results: dict[str, bytes] = {}
|
236
|
-
results_lock = threading.Lock()
|
237
|
-
err_to_raise: Optional[Exception] = None
|
238
|
-
early_stop = threading.Event()
|
239
|
-
start = time.monotonic()
|
240
|
-
|
241
|
-
def pull_with_retries(object_id: str) -> None:
|
242
|
-
"""Attempt to pull a single object with retry and backoff."""
|
243
|
-
nonlocal err_to_raise
|
244
|
-
tries = 0
|
245
|
-
delay = initial_backoff
|
246
|
-
|
247
|
-
while not early_stop.is_set():
|
248
|
-
try:
|
249
|
-
object_content = pull_object_fn(object_id)
|
250
|
-
with results_lock:
|
251
|
-
results[object_id] = object_content
|
252
|
-
return
|
253
|
-
|
254
|
-
except ObjectUnavailableError as err:
|
255
|
-
tries += 1
|
256
|
-
if (
|
257
|
-
tries >= max_tries_per_object
|
258
|
-
or time.monotonic() - start >= max_time
|
259
|
-
):
|
260
|
-
# Stop all work if one object exhausts retries
|
261
|
-
early_stop.set()
|
262
|
-
with results_lock:
|
263
|
-
if err_to_raise is None:
|
264
|
-
err_to_raise = err
|
265
|
-
return
|
266
|
-
|
267
|
-
# Apply exponential backoff with ±20% jitter
|
268
|
-
sleep_time = delay * (1 + random.uniform(-0.2, 0.2))
|
269
|
-
early_stop.wait(sleep_time)
|
270
|
-
delay = min(delay * 2, backoff_cap)
|
271
|
-
|
272
|
-
except ObjectIdNotPreregisteredError as err:
|
273
|
-
# Permanent failure: object ID is invalid
|
274
|
-
early_stop.set()
|
275
|
-
with results_lock:
|
276
|
-
if err_to_raise is None:
|
277
|
-
err_to_raise = err
|
278
|
-
return
|
279
|
-
|
280
|
-
# Submit all pull tasks concurrently
|
281
|
-
with concurrent.futures.ThreadPoolExecutor(
|
282
|
-
max_workers=max_concurrent_pulls
|
283
|
-
) as executor:
|
284
|
-
futures = {
|
285
|
-
executor.submit(pull_with_retries, obj_id): obj_id for obj_id in object_ids
|
286
|
-
}
|
287
|
-
|
288
|
-
# Wait for completion
|
289
|
-
concurrent.futures.wait(futures)
|
290
|
-
|
291
|
-
if err_to_raise is not None:
|
292
|
-
raise err_to_raise
|
293
|
-
|
294
|
-
return results
|
295
|
-
|
296
|
-
|
297
|
-
def inflate_object_from_contents(
|
298
|
-
object_id: str,
|
299
|
-
object_contents: dict[str, bytes],
|
300
|
-
*,
|
301
|
-
keep_object_contents: bool = False,
|
302
|
-
objects: Optional[dict[str, InflatableObject]] = None,
|
303
|
-
) -> InflatableObject:
|
304
|
-
"""Inflate an object from object contents.
|
305
|
-
|
306
|
-
Parameters
|
307
|
-
----------
|
308
|
-
object_id : str
|
309
|
-
The ID of the object to inflate.
|
310
|
-
object_contents : dict[str, bytes]
|
311
|
-
A dictionary mapping object IDs to their contents as bytes.
|
312
|
-
All descendant objects must be present in this dictionary.
|
313
|
-
keep_object_contents : bool (default: False)
|
314
|
-
If `True`, the object content will be kept in the `object_contents`
|
315
|
-
dictionary after inflation. If `False`, the object content will be
|
316
|
-
removed from the dictionary to save memory.
|
317
|
-
objects : Optional[dict[str, InflatableObject]] (default: None)
|
318
|
-
No need to provide this parameter. A dictionary to store already
|
319
|
-
inflated objects, mapping object IDs to their corresponding
|
320
|
-
`InflatableObject` instances.
|
321
|
-
|
322
|
-
Returns
|
323
|
-
-------
|
324
|
-
InflatableObject
|
325
|
-
The inflated object.
|
326
|
-
"""
|
327
|
-
if objects is None:
|
328
|
-
# Initialize objects dictionary
|
329
|
-
objects = {}
|
330
|
-
|
331
|
-
if object_id in objects:
|
332
|
-
# If the object is already in the objects dictionary, return it
|
333
|
-
return objects[object_id]
|
334
|
-
|
335
|
-
# Extract object class and object_ids of children
|
336
|
-
object_content = object_contents[object_id]
|
337
|
-
obj_type, children_obj_ids, _ = get_object_head_values_from_object_content(
|
338
|
-
object_content=object_contents[object_id]
|
339
|
-
)
|
340
|
-
|
341
|
-
# Remove the object content from the dictionary to save memory
|
342
|
-
if not keep_object_contents:
|
343
|
-
del object_contents[object_id]
|
344
|
-
|
345
|
-
# Resolve object class
|
346
|
-
cls_type = inflatable_class_registry[obj_type]
|
347
|
-
|
348
|
-
# Inflate all children objects
|
349
|
-
children: dict[str, InflatableObject] = {}
|
350
|
-
for child_obj_id in children_obj_ids:
|
351
|
-
children[child_obj_id] = inflate_object_from_contents(
|
352
|
-
child_obj_id,
|
353
|
-
object_contents,
|
354
|
-
keep_object_contents=keep_object_contents,
|
355
|
-
objects=objects,
|
356
|
-
)
|
357
|
-
|
358
|
-
# Inflate object passing its children
|
359
|
-
obj = cls_type.inflate(object_content, children=children)
|
360
|
-
del object_content # Free memory after inflation
|
361
|
-
objects[object_id] = obj
|
362
|
-
return obj
|