flwr 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/build.py +15 -5
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +23 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
- flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +53 -50
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +29 -11
- flwr/client/grpc_adapter_client/connection.py +11 -4
- flwr/client/grpc_rere_client/connection.py +93 -129
- flwr/client/rest_client/connection.py +134 -164
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +26 -0
- flwr/clientapp/mod/centraldp_mods.py +132 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -5
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +42 -8
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +1 -1
- flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
- flwr/common/inflatable_utils.py +191 -24
- flwr/common/logger.py +1 -1
- flwr/common/record/array.py +101 -22
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/serde.py +0 -28
- flwr/common/telemetry.py +4 -0
- flwr/compat/client/app.py +14 -31
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +51 -0
- flwr/proto/appio_pb2.pyi +195 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +4 -19
- flwr/proto/clientappio_pb2.pyi +0 -125
- flwr/proto/clientappio_pb2_grpc.py +269 -29
- flwr/proto/clientappio_pb2_grpc.pyi +114 -21
- flwr/proto/control_pb2.py +62 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
- flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
- flwr/proto/fleet_pb2.py +12 -20
- flwr/proto/fleet_pb2.pyi +6 -36
- flwr/proto/serverappio_pb2.py +8 -31
- flwr/proto/serverappio_pb2.pyi +0 -152
- flwr/proto/serverappio_pb2_grpc.py +107 -38
- flwr/proto/serverappio_pb2_grpc.pyi +47 -20
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +130 -153
- flwr/server/fleet_event_log_interceptor.py +4 -0
- flwr/server/grid/grpc_grid.py +94 -54
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +165 -144
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +2 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +95 -48
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +98 -22
- flwr/server/superlink/utils.py +0 -35
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/dp_fixed_clipping.py +352 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +38 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
- flwr/serverapp/strategy/fedadagrad.py +162 -0
- flwr/serverapp/strategy/fedadam.py +181 -0
- flwr/serverapp/strategy/fedavg.py +295 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedyogi.py +173 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +251 -0
- flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
- flwr/simulation/app.py +159 -154
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/app_utils.py +58 -0
- flwr/supercore/cli/__init__.py +22 -0
- flwr/supercore/cli/flower_superexec.py +141 -0
- flwr/supercore/corestate/__init__.py +22 -0
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +25 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/in_memory_object_store.py +31 -31
- flwr/supercore/object_store/object_store.py +20 -42
- flwr/supercore/object_store/utils.py +43 -0
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +53 -0
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +71 -0
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +185 -0
- flwr/supercore/utils.py +32 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +9 -5
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +39 -28
- flwr/superlink/servicer/control/control_license_interceptor.py +82 -0
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +79 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +18 -10
- flwr/supernode/cli/flower_supernode.py +3 -7
- flwr/supernode/cli/flwr_clientapp.py +20 -16
- flwr/supernode/nodestate/in_memory_nodestate.py +13 -4
- flwr/supernode/nodestate/nodestate.py +3 -44
- flwr/supernode/runtime/run_clientapp.py +129 -115
- flwr/supernode/servicer/clientappio/__init__.py +1 -3
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +217 -165
- flwr/supernode/start_client_internal.py +205 -148
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/METADATA +5 -3
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/RECORD +161 -117
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
- flwr/common/inflatable_rest_utils.py +0 -99
- flwr/proto/exec_pb2.py +0 -62
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -192
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -130
- /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
- /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
flwr/common/inflatable_utils.py
CHANGED

@@ -15,10 +15,14 @@
 """InflatableObject utilities."""
 
 import concurrent.futures
+import os
 import random
 import threading
 import time
-from
+from collections.abc import Iterable, Iterator
+from typing import Callable, Optional, TypeVar
+
+from flwr.proto.message_pb2 import ObjectTree  # pylint: disable=E0611
 
 from .constant import (
     HEAD_BODY_DIVIDER,
@@ -30,6 +34,7 @@ from .constant import (
     PULL_MAX_TIME,
     PULL_MAX_TRIES_PER_OBJECT,
 )
+from .exit import add_exit_handler
 from .inflatable import (
     InflatableObject,
     UnexpectedObjectContentError,
@@ -37,12 +42,15 @@ from .inflatable import (
     get_object_head_values_from_object_content,
     get_object_id,
     is_valid_sha256_hash,
+    iterate_object_tree,
 )
 from .message import Message
 from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
+from .record.arraychunk import ArrayChunk
 
 # Helper registry that maps names of classes to their type
 inflatable_class_registry: dict[str, type[InflatableObject]] = {
+    ArrayChunk.__qualname__: ArrayChunk,
     Array.__qualname__: Array,
     ArrayRecord.__qualname__: ArrayRecord,
     ConfigRecord.__qualname__: ConfigRecord,
@@ -51,6 +59,36 @@ inflatable_class_registry: dict[str, type[InflatableObject]] = {
     RecordDict.__qualname__: RecordDict,
 }
 
+T = TypeVar("T", bound=InflatableObject)
+
+
+# Allow thread pool executors to be shut down gracefully
+_thread_pool_executors: set[concurrent.futures.ThreadPoolExecutor] = set()
+_lock = threading.Lock()
+
+
+def _shutdown_thread_pool_executors() -> None:
+    """Shutdown all thread pool executors gracefully."""
+    with _lock:
+        for executor in _thread_pool_executors:
+            executor.shutdown(wait=False, cancel_futures=True)
+        _thread_pool_executors.clear()
+
+
+def _track_executor(executor: concurrent.futures.ThreadPoolExecutor) -> None:
+    """Track a thread pool executor for graceful shutdown."""
+    with _lock:
+        _thread_pool_executors.add(executor)
+
+
+def _untrack_executor(executor: concurrent.futures.ThreadPoolExecutor) -> None:
+    """Untrack a thread pool executor."""
+    with _lock:
+        _thread_pool_executors.discard(executor)
+
+
+add_exit_handler(_shutdown_thread_pool_executors)
+
 
 class ObjectUnavailableError(Exception):
     """Exception raised when an object has been pre-registered but is not yet
@@ -67,6 +105,13 @@ class ObjectIdNotPreregisteredError(Exception):
         super().__init__(f"Object with ID '{object_id}' could not be found.")
 
 
+def get_num_workers(max_concurrent: int) -> int:
+    """Get number of workers based on the number of CPU cores and the maximum
+    allowed."""
+    num_cores = os.cpu_count() or 1
+    return min(max_concurrent, num_cores)
+
+
 def push_objects(
     objects: dict[str, InflatableObject],
     push_object_fn: Callable[[str, bytes], None],
@@ -95,27 +140,73 @@ def push_objects(
     max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
         The maximum number of concurrent pushes to perform.
     """
-    if object_ids_to_push is not None:
-        # Filter objects to push only those with IDs in the set
-        objects = {k: v for k, v in objects.items() if k in object_ids_to_push}
-
     lock = threading.Lock()
 
-    def
+    def iter_dict_items() -> Iterator[tuple[str, bytes]]:
+        """Iterate over the dictionary items."""
+        for obj_id in list(objects.keys()):
+            # Skip the object if no need to push it
+            if object_ids_to_push is not None and obj_id not in object_ids_to_push:
+                continue
+
+            # Deflate the object content
+            object_content = objects[obj_id].deflate()
+            if not keep_objects:
+                with lock:
+                    del objects[obj_id]
+
+            yield obj_id, object_content
+
+    push_object_contents_from_iterable(
+        iter_dict_items(),
+        push_object_fn,
+        max_concurrent_pushes=max_concurrent_pushes,
+    )
+
+
+def push_object_contents_from_iterable(
+    object_contents: Iterable[tuple[str, bytes]],
+    push_object_fn: Callable[[str, bytes], None],
+    *,
+    max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
+) -> None:
+    """Push multiple object contents to the servicer.
+
+    Parameters
+    ----------
+    object_contents : Iterable[tuple[str, bytes]]
+        An iterable of `(object_id, object_content)` pairs.
+        `object_id` is the object ID, and `object_content` is the object content.
+    push_object_fn : Callable[[str, bytes], None]
+        A function that takes an object ID and its content as bytes, and pushes
+        it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
+        if the object ID is not pre-registered.
+    max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
+        The maximum number of concurrent pushes to perform.
+    """
+
+    def push(args: tuple[str, bytes]) -> None:
         """Push a single object."""
-
-
-
-
-
+        obj_id, obj_content = args
+        # Push the object using the provided function
+        push_object_fn(obj_id, obj_content)
+
+    # Push all object contents concurrently
+    num_workers = get_num_workers(max_concurrent_pushes)
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+        # Ensure that the thread pool executors are tracked for graceful shutdown
+        _track_executor(executor)
 
-
-
-    ) as executor:
-        list(executor.map(push, list(objects.keys())))
+        # Submit push tasks for each object content
+        executor.map(push, object_contents)  # Non-blocking map
 
+    # The context manager will block until all submitted tasks have completed
 
-def pull_objects(  # pylint: disable=too-many-arguments
+    # Remove the executor from the list of tracked executors
+    _untrack_executor(executor)
+
+
+def pull_objects(  # pylint: disable=too-many-arguments,too-many-locals
     object_ids: list[str],
     pull_object_fn: Callable[[str], bytes],
     *,
@@ -207,16 +298,20 @@ def pull_objects(  # pylint: disable=too-many-arguments
                 return
 
     # Submit all pull tasks concurrently
-
-
-
-
-
-
+    num_workers = get_num_workers(max_concurrent_pulls)
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+        # Ensure that the thread pool executors are tracked for graceful shutdown
+        _track_executor(executor)
+
+        # Submit pull tasks for each object ID
+        executor.map(pull_with_retries, object_ids)  # Non-blocking map
+
+    # The context manager will block until all submitted tasks have completed
 
-
-
+    # Remove the executor from the list of tracked executors
+    _untrack_executor(executor)
 
+    # If an error occurred during pulling, raise it
     if err_to_raise is not None:
         raise err_to_raise
 
@@ -339,3 +434,75 @@ def validate_object_content(content: bytes) -> None:
         raise UnexpectedObjectContentError(
             object_id=get_object_id(content), reason=str(err)
         ) from err
+
+
+def pull_and_inflate_object_from_tree(  # pylint: disable=R0913
+    object_tree: ObjectTree,
+    pull_object_fn: Callable[[str], bytes],
+    confirm_object_received_fn: Callable[[str], None],
+    *,
+    return_type: type[T] = InflatableObject,  # type: ignore
+    max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
+    max_time: Optional[float] = PULL_MAX_TIME,
+    max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
+    initial_backoff: float = PULL_INITIAL_BACKOFF,
+    backoff_cap: float = PULL_BACKOFF_CAP,
+) -> T:
+    """Pull and inflate the head object from the provided object tree.
+
+    Parameters
+    ----------
+    object_tree : ObjectTree
+        The object tree containing the object ID and its descendants.
+    pull_object_fn : Callable[[str], bytes]
+        A function that takes an object ID and returns the object content as bytes.
+    confirm_object_received_fn : Callable[[str], None]
+        A function to confirm that the object has been received.
+    return_type : type[T] (default: InflatableObject)
+        The type of the object to return. Must be a subclass of `InflatableObject`.
+    max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
+        The maximum number of concurrent pulls to perform.
+    max_time : Optional[float] (default: PULL_MAX_TIME)
+        The maximum time to wait for all pulls to complete. If `None`, waits
+        indefinitely.
+    max_tries_per_object : Optional[int] (default: PULL_MAX_TRIES_PER_OBJECT)
+        The maximum number of attempts to pull each object. If `None`, pulls
+        indefinitely until the object is available.
+    initial_backoff : float (default: PULL_INITIAL_BACKOFF)
+        The initial backoff time in seconds for retrying pulls after an
+        `ObjectUnavailableError`.
+    backoff_cap : float (default: PULL_BACKOFF_CAP)
+        The maximum backoff time in seconds. Backoff times will not exceed this value.
+
+    Returns
+    -------
+    T
+        An instance of the specified return type containing the inflated object.
+    """
+    # Pull the main object and all its descendants
+    pulled_object_contents = pull_objects(
+        [tree.object_id for tree in iterate_object_tree(object_tree)],
+        pull_object_fn,
+        max_concurrent_pulls=max_concurrent_pulls,
+        max_time=max_time,
+        max_tries_per_object=max_tries_per_object,
+        initial_backoff=initial_backoff,
+        backoff_cap=backoff_cap,
+    )
+
+    # Confirm that all objects were pulled
+    confirm_object_received_fn(object_tree.object_id)
+
+    # Inflate the main object
+    inflated_object = inflate_object_from_contents(
+        object_tree.object_id, pulled_object_contents, keep_object_contents=False
+    )
+
+    # Check if the inflated object is of the expected type
+    if not isinstance(inflated_object, return_type):
+        raise TypeError(
+            f"Expected object of type {return_type.__name__}, "
+            f"but got {type(inflated_object).__name__}."
+        )
+
+    return inflated_object

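The helpers above are internal flwr APIs, but the new `push_object_contents_from_iterable` entry point is easy to exercise in isolation. The sketch below is illustrative only and assumes flwr 1.21.0 is installed: the in-memory `store` dict and the sample object IDs are invented, and a real `push_object_fn` backed by a servicer would raise `ObjectIdNotPreregisteredError` for unknown IDs.

# Hypothetical usage sketch of the push helpers added above (internal APIs;
# the in-memory store and the object IDs are invented for illustration).
from flwr.common.inflatable_utils import (
    get_num_workers,
    push_object_contents_from_iterable,
)

store: dict[str, bytes] = {}  # stands in for a pre-registered object store


def push_object_fn(object_id: str, content: bytes) -> None:
    # A real servicer-backed implementation would reject unknown object IDs
    store[object_id] = content


# Any iterable of (object_id, content) pairs can be pushed concurrently
push_object_contents_from_iterable(
    [("id-1", b"alpha"), ("id-2", b"beta")],
    push_object_fn,
)

# Worker count is bounded by both the requested concurrency and os.cpu_count()
print(get_num_workers(max_concurrent=8))  # never larger than 8

Because `push_objects` now just wraps an iterator around its dict and delegates to this function, the same thread-pool tracking and exit handling applies to both code paths.
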
flwr/common/logger.py
CHANGED

@@ -132,13 +132,13 @@ if log_level := os.getenv("FLWR_LOG_LEVEL"):
     log_level = log_level.upper()
     try:
         is_debug = log_level == "DEBUG"
+        update_console_handler(level=log_level, timestamps=is_debug, colored=True)
         if is_debug:
             log(
                 WARN,
                 "DEBUG logs enabled. Do not use this in production, as it may expose "
                 "sensitive details.",
             )
-        update_console_handler(level=log_level, timestamps=is_debug, colored=True)
     except Exception:  # pylint: disable=broad-exception-caught
         # Alert user but don't raise exception
         log(

flwr/common/record/array.py
CHANGED

@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import json
 import sys
 from dataclasses import dataclass
 from io import BytesIO
@@ -24,11 +25,15 @@ from typing import TYPE_CHECKING, Any, cast, overload
 
 import numpy as np
 
-from
-
-
-
+from ..constant import MAX_ARRAY_CHUNK_SIZE, SType
+from ..inflatable import (
+    InflatableObject,
+    add_header_to_object_body,
+    get_object_body,
+    get_object_children_ids_from_object_content,
+)
 from ..typing import NDArray
+from .arraychunk import ArrayChunk
 
 if TYPE_CHECKING:
     import torch
@@ -252,16 +257,56 @@ class Array(InflatableObject):
         ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
         return cast(NDArray, ndarray_deserialized)
 
+    @property
+    def children(self) -> dict[str, InflatableObject]:
+        """Return a dictionary of ArrayChunks with their Object IDs as keys."""
+        return dict(self.slice_array())
+
+    def slice_array(self) -> list[tuple[str, InflatableObject]]:
+        """Slice Array data and construct a list of ArrayChunks."""
+        # Return cached chunks if they exist
+        if "_chunks" in self.__dict__:
+            return cast(list[tuple[str, InflatableObject]], self.__dict__["_chunks"])
+
+        # Chunks are not children as some of them may be identical
+        chunks: list[tuple[str, InflatableObject]] = []
+        # memoryview allows for zero-copy slicing
+        data_view = memoryview(self.data)
+        for start in range(0, len(data_view), MAX_ARRAY_CHUNK_SIZE):
+            end = min(start + MAX_ARRAY_CHUNK_SIZE, len(data_view))
+            ac = ArrayChunk(data_view[start:end])
+            chunks.append((ac.object_id, ac))
+
+        # Cache the chunks for future use
+        self.__dict__["_chunks"] = chunks
+        return chunks
+
     def deflate(self) -> bytes:
         """Deflate the Array."""
-
-
-
-
-
-        )
-
-
+        array_metadata: dict[str, str | tuple[int, ...] | list[int]] = {}
+
+        # We want to record all object_id even if repeated
+        # it can happend that chunks carry the exact same data
+        # for example when the array has only zeros
+        children_list = self.slice_array()
+        # Let's not save the entire object_id but a mapping to those
+        # that will be carried in the object head
+        # (replace a long object_id with a single scalar)
+        unique_children = list(self.children.keys())
+        arraychunk_ids = [unique_children.index(ch_id) for ch_id, _ in children_list]
+
+        # The deflated Array carries everything but the data
+        # The `arraychunk_ids` will be used during Array inflation
+        # to rematerialize the data from ArrayChunk objects.
+        array_metadata = {
+            "dtype": self.dtype,
+            "shape": self.shape,
+            "stype": self.stype,
+            "arraychunk_ids": arraychunk_ids,
+        }
+
+        # Serialize metadata dict
+        obj_body = json.dumps(array_metadata).encode("utf-8")
         return add_header_to_object_body(object_body=obj_body, obj=self)
 
     @classmethod
@@ -276,26 +321,55 @@ class Array(InflatableObject):
            The deflated object content of the Array.
 
         children : Optional[dict[str, InflatableObject]] (default: None)
-            Must be ``None``. ``Array``
-            Providing
+            Must be ``None``. ``Array`` must have child objects.
+            Providing no children will raise a ``ValueError``.
 
         Returns
         -------
         Array
             The inflated Array.
         """
-        if children:
-
+        if children is None:
+            children = {}
 
         obj_body = get_object_body(object_content, cls)
-
-
-
-
-
-
+
+        # Extract children IDs from head
+        children_ids = get_object_children_ids_from_object_content(object_content)
+        # Decode the Array body
+        array_metadata: dict[str, str | tuple[int, ...] | list[int]] = json.loads(
+            obj_body.decode(encoding="utf-8")
        )
 
+        # Verify children ids in body match those passed for inflation
+        chunk_ids_indices = cast(list[int], array_metadata["arraychunk_ids"])
+        # Convert indices back to IDs
+        chunk_ids = [children_ids[i] for i in chunk_ids_indices]
+        # Check consistency
+        unique_arrayschunks = set(chunk_ids)
+        children_obj_ids = set(children.keys())
+        if unique_arrayschunks != children_obj_ids:
+            raise ValueError(
+                "Unexpected set of `children`. "
+                f"Expected {unique_arrayschunks} but got {children_obj_ids}."
+            )
+
+        # Materialize Array with empty data
+        array = cls(
+            dtype=cast(str, array_metadata["dtype"]),
+            shape=cast(tuple[int], tuple(array_metadata["shape"])),
+            stype=cast(str, array_metadata["stype"]),
+            data=b"",
+        )
+
+        # Now inject data from chunks
+        buff = bytearray()
+        for ch_id in chunk_ids:
+            buff += cast(ArrayChunk, children[ch_id]).data
+
+        array.data = bytes(buff)
+        return array
+
     @property
     def object_id(self) -> str:
         """Get object ID."""
@@ -320,4 +394,9 @@ class Array(InflatableObject):
         if name in ("dtype", "shape", "stype", "data"):
             # Mark as dirty if any of the main attributes are set
             self.is_dirty = True
+            # Clear cached object ID
+            self.__dict__.pop("_object_id", None)
+            # Clear cached chunks if data is set
+            if name == "data":
+                self.__dict__.pop("_chunks", None)
         super().__setattr__(name, value)

flwr/common/record/arraychunk.py
ADDED

@@ -0,0 +1,59 @@
+# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ArrayChunk."""
+
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
+
+
+@dataclass
+class ArrayChunk(InflatableObject):
+    """ArrayChunk type."""
+
+    data: memoryview
+
+    def deflate(self) -> bytes:
+        """Deflate the ArrayChunk."""
+        return add_header_to_object_body(object_body=self.data, obj=self)
+
+    @classmethod
+    def inflate(
+        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
+    ) -> ArrayChunk:
+        """Inflate an ArrayChunk from bytes.
+
+        Parameters
+        ----------
+        object_content : bytes
+            The deflated object content of the ArrayChunk.
+
+        children : Optional[dict[str, InflatableObject]] (default: None)
+            Must be ``None``. ``ArrayChunk`` does not support child objects.
+            Providing any children will raise a ``ValueError``.
+
+        Returns
+        -------
+        ArrayChunk
+            The inflated ArrayChunk.
+        """
+        if children:
+            raise ValueError("`ArrayChunk` objects do not have children.")
+
+        obj_body = get_object_body(object_content, cls)
+        return cls(data=memoryview(obj_body))

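Taken together, `Array.slice_array`, `Array.deflate`, and `Array.inflate` above implement a split, dedupe, and reassemble scheme over the raw `data` buffer: the data is cut into `MAX_ARRAY_CHUNK_SIZE` slices, identical slices collapse into a single child object, and the deflated head stores one small index per slice instead of a full object ID. A minimal standalone sketch of that scheme (illustration only, not flwr code; `CHUNK_SIZE` and the SHA-256 helper stand in for `MAX_ARRAY_CHUNK_SIZE` and flwr's object-ID computation):

import hashlib

CHUNK_SIZE = 4  # stand-in for MAX_ARRAY_CHUNK_SIZE


def chunk_id(content: bytes) -> str:
    # Stand-in for an object ID; flwr hashes the deflated object content
    return hashlib.sha256(content).hexdigest()


def slice_chunks(data: bytes) -> list[bytes]:
    view = memoryview(data)  # zero-copy slicing, as in Array.slice_array
    return [bytes(view[i : i + CHUNK_SIZE]) for i in range(0, len(view), CHUNK_SIZE)]


data = b"\x00\x00\x00\x00" * 2 + b"\x01\x02\x03\x04"  # first two chunks are identical
chunks = slice_chunks(data)
chunk_ids = [chunk_id(c) for c in chunks]

# Metadata: indices into the unique children, as in Array.deflate
unique_children = list(dict.fromkeys(chunk_ids))  # order-preserving dedup
arraychunk_ids = [unique_children.index(cid) for cid in chunk_ids]
assert arraychunk_ids == [0, 0, 1]

# Reassembly walks the indices in order, as in Array.inflate
children = dict(zip(chunk_ids, chunks))
restored = b"".join(children[unique_children[i]] for i in arraychunk_ids)
assert restored == data

Only one copy of the identical chunk is kept in `children`, yet the original buffer is restored in full because the index list preserves both order and repetition.
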
flwr/common/retry_invoker.py
CHANGED

@@ -17,6 +17,7 @@
 
 import itertools
 import random
+import threading
 import time
 from collections.abc import Generator, Iterable
 from dataclasses import dataclass
@@ -319,8 +320,12 @@ class RetryInvoker:
 
 def _make_simple_grpc_retry_invoker() -> RetryInvoker:
     """Create a simple gRPC retry invoker."""
+    lock = threading.Lock()
+    system_healthy = threading.Event()
+    system_healthy.set()  # Initially, the connection is healthy
 
-    def
+    def _on_success(retry_state: RetryState) -> None:
+        system_healthy.set()
         if retry_state.tries > 1:
             log(
                 INFO,
@@ -329,17 +334,11 @@ def _make_simple_grpc_retry_invoker() -> RetryInvoker:
                 retry_state.tries,
             )
 
-    def _on_backoff(
-
-            log(WARN, "Connection attempt failed, retrying...")
-        else:
-            log(
-                WARN,
-                "Connection attempt failed, retrying in %.2f seconds",
-                retry_state.actual_wait,
-            )
+    def _on_backoff(_: RetryState) -> None:
+        system_healthy.clear()
 
     def _on_giveup(retry_state: RetryState) -> None:
+        system_healthy.clear()
         if retry_state.tries > 1:
             log(
                 WARN,
@@ -355,15 +354,35 @@ def _make_simple_grpc_retry_invoker() -> RetryInvoker:
             return False
         return True
 
+    def _wait(wait_time: float) -> None:
+        # Use a lock to prevent multiple gRPC calls from retrying concurrently,
+        # which is unnecessary since they are all likely to fail.
+        with lock:
+            # Log the wait time
+            log(
+                WARN,
+                "Connection attempt failed, retrying in %.2f seconds",
+                wait_time,
+            )
+
+            start = time.monotonic()
+            # Avoid sequential waits if the system is healthy
+            system_healthy.wait(wait_time)
+
+        remaining_time = wait_time - (time.monotonic() - start)
+        if remaining_time > 0:
+            time.sleep(remaining_time)
+
     return RetryInvoker(
         wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
         recoverable_exceptions=grpc.RpcError,
         max_tries=None,
        max_time=None,
-        on_success=
+        on_success=_on_success,
         on_backoff=_on_backoff,
         on_giveup=_on_giveup,
         should_giveup=_should_giveup_fn,
+        wait_function=_wait,
     )
 
 
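The retry changes above lean on two standard-library behaviours: a shared `threading.Event` that `_on_success` sets and `_on_backoff`/`_on_giveup` clear, and the fact that `Event.wait(timeout)` returns as soon as the event is set. A minimal standalone demonstration of that early-release behaviour (illustration only, not flwr code):

import threading
import time

system_healthy = threading.Event()  # cleared: the connection is considered unhealthy


def backoff_wait(wait_time: float) -> None:
    start = time.monotonic()
    system_healthy.wait(wait_time)  # returns early once the event is set
    print(f"released after {time.monotonic() - start:.2f}s (budget {wait_time:.2f}s)")


waiter = threading.Thread(target=backoff_wait, args=(5.0,))
waiter.start()
time.sleep(0.5)
system_healthy.set()  # e.g. another gRPC call just succeeded
waiter.join()  # prints roughly 0.50s instead of 5.00s

In the `_wait` function above, this wait additionally happens under a lock, so that while the connection is unhealthy concurrent calls back off one at a time instead of retrying against the server simultaneously.
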
flwr/common/serde.py
CHANGED

@@ -19,7 +19,6 @@ from collections import OrderedDict
 from typing import Any, cast
 
 # pylint: disable=E0611
-from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
 from flwr.proto.fab_pb2 import Fab as ProtoFab
 from flwr.proto.message_pb2 import Context as ProtoContext
 from flwr.proto.message_pb2 import Message as ProtoMessage
@@ -653,33 +652,6 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
     return run
 
 
-# === ClientApp status messages ===
-
-
-def clientappstatus_to_proto(
-    status: typing.ClientAppOutputStatus,
-) -> ClientAppOutputStatus:
-    """Serialize `ClientAppOutputStatus` to ProtoBuf."""
-    code = ClientAppOutputCode.SUCCESS
-    if status.code == typing.ClientAppOutputCode.DEADLINE_EXCEEDED:
-        code = ClientAppOutputCode.DEADLINE_EXCEEDED
-    if status.code == typing.ClientAppOutputCode.UNKNOWN_ERROR:
-        code = ClientAppOutputCode.UNKNOWN_ERROR
-    return ClientAppOutputStatus(code=code, message=status.message)
-
-
-def clientappstatus_from_proto(
-    msg: ClientAppOutputStatus,
-) -> typing.ClientAppOutputStatus:
-    """Deserialize `ClientAppOutputStatus` from ProtoBuf."""
-    code = typing.ClientAppOutputCode.SUCCESS
-    if msg.code == ClientAppOutputCode.DEADLINE_EXCEEDED:
-        code = typing.ClientAppOutputCode.DEADLINE_EXCEEDED
-    if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
-        code = typing.ClientAppOutputCode.UNKNOWN_ERROR
-    return typing.ClientAppOutputStatus(code=code, message=msg.message)
-
-
 # === Run status ===
 
 

flwr/common/telemetry.py
CHANGED

@@ -181,6 +181,10 @@ class EventType(str, Enum):
     RUN_SUPERNODE_ENTER = auto()
     RUN_SUPERNODE_LEAVE = auto()
 
+    # CLI: `flower-superexec`
+    RUN_SUPEREXEC_ENTER = auto()
+    RUN_SUPEREXEC_LEAVE = auto()
+
 
 # Use the ThreadPoolExecutor with max_workers=1 to have a queue
 # and also ensure that telemetry calls are not blocking.