flwr 1.19.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +15 -5
- flwr/cli/new/new.py +12 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +45 -38
- flwr/cli/utils.py +12 -5
- flwr/client/grpc_adapter_client/connection.py +11 -4
- flwr/client/grpc_rere_client/connection.py +92 -117
- flwr/client/rest_client/connection.py +131 -164
- flwr/common/constant.py +3 -1
- flwr/common/exit/exit_code.py +16 -1
- flwr/common/grpc.py +12 -1
- flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
- flwr/common/inflatable_utils.py +191 -24
- flwr/common/record/array.py +101 -22
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/serde.py +0 -28
- flwr/compat/client/app.py +14 -31
- flwr/proto/appio_pb2.py +43 -0
- flwr/proto/appio_pb2.pyi +151 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +12 -19
- flwr/proto/clientappio_pb2.pyi +23 -101
- flwr/proto/clientappio_pb2_grpc.py +269 -28
- flwr/proto/clientappio_pb2_grpc.pyi +114 -20
- flwr/proto/fleet_pb2.py +12 -20
- flwr/proto/fleet_pb2.pyi +6 -36
- flwr/proto/serverappio_pb2.py +8 -31
- flwr/proto/serverappio_pb2.pyi +0 -152
- flwr/proto/serverappio_pb2_grpc.py +39 -38
- flwr/proto/serverappio_pb2_grpc.pyi +21 -20
- flwr/server/app.py +1 -1
- flwr/server/fleet_event_log_interceptor.py +4 -0
- flwr/server/grid/grpc_grid.py +91 -54
- flwr/server/serverapp/app.py +27 -17
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +35 -43
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/superlink/utils.py +0 -35
- flwr/simulation/app.py +8 -0
- flwr/simulation/run_simulation.py +17 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/in_memory_object_store.py +31 -31
- flwr/supercore/object_store/object_store.py +20 -42
- flwr/supercore/object_store/utils.py +43 -0
- flwr/supercore/scheduler/__init__.py +22 -0
- flwr/supercore/scheduler/plugin.py +71 -0
- flwr/supercore/utils.py +32 -0
- flwr/superexec/deployment.py +1 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -0
- flwr/superexec/exec_grpc.py +18 -2
- flwr/superexec/exec_license_interceptor.py +82 -0
- flwr/superexec/exec_servicer.py +10 -1
- flwr/superexec/exec_user_auth_interceptor.py +10 -2
- flwr/superexec/executor.py +1 -1
- flwr/superexec/simulation.py +1 -2
- flwr/supernode/cli/flower_supernode.py +0 -7
- flwr/supernode/cli/flwr_clientapp.py +10 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +11 -2
- flwr/supernode/nodestate/nodestate.py +15 -0
- flwr/supernode/runtime/run_clientapp.py +110 -33
- flwr/supernode/scheduler/__init__.py +22 -0
- flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
- flwr/supernode/servicer/clientappio/__init__.py +1 -3
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +223 -164
- flwr/supernode/start_client_internal.py +202 -104
- {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/METADATA +2 -1
- {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/RECORD +93 -78
- flwr/common/inflatable_rest_utils.py +0 -99
- /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +0 -0
- {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +0 -0
flwr/common/inflatable_utils.py
CHANGED
@@ -15,10 +15,14 @@
 """InflatableObject utilities."""

 import concurrent.futures
+import os
 import random
 import threading
 import time
-from
+from collections.abc import Iterable, Iterator
+from typing import Callable, Optional, TypeVar
+
+from flwr.proto.message_pb2 import ObjectTree  # pylint: disable=E0611

 from .constant import (
     HEAD_BODY_DIVIDER,
@@ -30,6 +34,7 @@ from .constant import (
     PULL_MAX_TIME,
     PULL_MAX_TRIES_PER_OBJECT,
 )
+from .exit_handlers import add_exit_handler
 from .inflatable import (
     InflatableObject,
     UnexpectedObjectContentError,
@@ -37,12 +42,15 @@ from .inflatable import (
     get_object_head_values_from_object_content,
     get_object_id,
     is_valid_sha256_hash,
+    iterate_object_tree,
 )
 from .message import Message
 from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
+from .record.arraychunk import ArrayChunk

 # Helper registry that maps names of classes to their type
 inflatable_class_registry: dict[str, type[InflatableObject]] = {
+    ArrayChunk.__qualname__: ArrayChunk,
     Array.__qualname__: Array,
     ArrayRecord.__qualname__: ArrayRecord,
     ConfigRecord.__qualname__: ConfigRecord,
@@ -51,6 +59,36 @@ inflatable_class_registry: dict[str, type[InflatableObject]] = {
     RecordDict.__qualname__: RecordDict,
 }

+T = TypeVar("T", bound=InflatableObject)
+
+
+# Allow thread pool executors to be shut down gracefully
+_thread_pool_executors: set[concurrent.futures.ThreadPoolExecutor] = set()
+_lock = threading.Lock()
+
+
+def _shutdown_thread_pool_executors() -> None:
+    """Shutdown all thread pool executors gracefully."""
+    with _lock:
+        for executor in _thread_pool_executors:
+            executor.shutdown(wait=False, cancel_futures=True)
+        _thread_pool_executors.clear()
+
+
+def _track_executor(executor: concurrent.futures.ThreadPoolExecutor) -> None:
+    """Track a thread pool executor for graceful shutdown."""
+    with _lock:
+        _thread_pool_executors.add(executor)
+
+
+def _untrack_executor(executor: concurrent.futures.ThreadPoolExecutor) -> None:
+    """Untrack a thread pool executor."""
+    with _lock:
+        _thread_pool_executors.discard(executor)
+
+
+add_exit_handler(_shutdown_thread_pool_executors)
+

 class ObjectUnavailableError(Exception):
     """Exception raised when an object has been pre-registered but is not yet
@@ -67,6 +105,13 @@ class ObjectIdNotPreregisteredError(Exception):
         super().__init__(f"Object with ID '{object_id}' could not be found.")


+def get_num_workers(max_concurrent: int) -> int:
+    """Get number of workers based on the number of CPU cores and the maximum
+    allowed."""
+    num_cores = os.cpu_count() or 1
+    return min(max_concurrent, num_cores)
+
+
 def push_objects(
     objects: dict[str, InflatableObject],
     push_object_fn: Callable[[str, bytes], None],
@@ -95,27 +140,73 @@ def push_objects(
     max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
         The maximum number of concurrent pushes to perform.
     """
-    if object_ids_to_push is not None:
-        # Filter objects to push only those with IDs in the set
-        objects = {k: v for k, v in objects.items() if k in object_ids_to_push}
-
     lock = threading.Lock()

-    def
+    def iter_dict_items() -> Iterator[tuple[str, bytes]]:
+        """Iterate over the dictionary items."""
+        for obj_id in list(objects.keys()):
+            # Skip the object if no need to push it
+            if object_ids_to_push is not None and obj_id not in object_ids_to_push:
+                continue
+
+            # Deflate the object content
+            object_content = objects[obj_id].deflate()
+            if not keep_objects:
+                with lock:
+                    del objects[obj_id]
+
+            yield obj_id, object_content
+
+    push_object_contents_from_iterable(
+        iter_dict_items(),
+        push_object_fn,
+        max_concurrent_pushes=max_concurrent_pushes,
+    )
+
+
+def push_object_contents_from_iterable(
+    object_contents: Iterable[tuple[str, bytes]],
+    push_object_fn: Callable[[str, bytes], None],
+    *,
+    max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
+) -> None:
+    """Push multiple object contents to the servicer.
+
+    Parameters
+    ----------
+    object_contents : Iterable[tuple[str, bytes]]
+        An iterable of `(object_id, object_content)` pairs.
+        `object_id` is the object ID, and `object_content` is the object content.
+    push_object_fn : Callable[[str, bytes], None]
+        A function that takes an object ID and its content as bytes, and pushes
+        it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
+        if the object ID is not pre-registered.
+    max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
+        The maximum number of concurrent pushes to perform.
+    """
+
+    def push(args: tuple[str, bytes]) -> None:
         """Push a single object."""
-
-
-
-
-
+        obj_id, obj_content = args
+        # Push the object using the provided function
+        push_object_fn(obj_id, obj_content)
+
+    # Push all object contents concurrently
+    num_workers = get_num_workers(max_concurrent_pushes)
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+        # Ensure that the thread pool executors are tracked for graceful shutdown
+        _track_executor(executor)

-
-
-    ) as executor:
-        list(executor.map(push, list(objects.keys())))
+        # Submit push tasks for each object content
+        executor.map(push, object_contents)  # Non-blocking map

+        # The context manager will block until all submitted tasks have completed

-
+    # Remove the executor from the list of tracked executors
+    _untrack_executor(executor)
+
+
+def pull_objects(  # pylint: disable=too-many-arguments,too-many-locals
     object_ids: list[str],
     pull_object_fn: Callable[[str], bytes],
     *,
@@ -207,16 +298,20 @@ def pull_objects(  # pylint: disable=too-many-arguments
         return

     # Submit all pull tasks concurrently
-
-
-
-
-
-
+    num_workers = get_num_workers(max_concurrent_pulls)
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+        # Ensure that the thread pool executors are tracked for graceful shutdown
+        _track_executor(executor)
+
+        # Submit pull tasks for each object ID
+        executor.map(pull_with_retries, object_ids)  # Non-blocking map
+
+        # The context manager will block until all submitted tasks have completed

-
-
+    # Remove the executor from the list of tracked executors
+    _untrack_executor(executor)

+    # If an error occurred during pulling, raise it
     if err_to_raise is not None:
         raise err_to_raise

@@ -339,3 +434,75 @@ def validate_object_content(content: bytes) -> None:
         raise UnexpectedObjectContentError(
             object_id=get_object_id(content), reason=str(err)
         ) from err
+
+
+def pull_and_inflate_object_from_tree(  # pylint: disable=R0913
+    object_tree: ObjectTree,
+    pull_object_fn: Callable[[str], bytes],
+    confirm_object_received_fn: Callable[[str], None],
+    *,
+    return_type: type[T] = InflatableObject,  # type: ignore
+    max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
+    max_time: Optional[float] = PULL_MAX_TIME,
+    max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
+    initial_backoff: float = PULL_INITIAL_BACKOFF,
+    backoff_cap: float = PULL_BACKOFF_CAP,
+) -> T:
+    """Pull and inflate the head object from the provided object tree.
+
+    Parameters
+    ----------
+    object_tree : ObjectTree
+        The object tree containing the object ID and its descendants.
+    pull_object_fn : Callable[[str], bytes]
+        A function that takes an object ID and returns the object content as bytes.
+    confirm_object_received_fn : Callable[[str], None]
+        A function to confirm that the object has been received.
+    return_type : type[T] (default: InflatableObject)
+        The type of the object to return. Must be a subclass of `InflatableObject`.
+    max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
+        The maximum number of concurrent pulls to perform.
+    max_time : Optional[float] (default: PULL_MAX_TIME)
+        The maximum time to wait for all pulls to complete. If `None`, waits
+        indefinitely.
+    max_tries_per_object : Optional[int] (default: PULL_MAX_TRIES_PER_OBJECT)
+        The maximum number of attempts to pull each object. If `None`, pulls
+        indefinitely until the object is available.
+    initial_backoff : float (default: PULL_INITIAL_BACKOFF)
+        The initial backoff time in seconds for retrying pulls after an
+        `ObjectUnavailableError`.
+    backoff_cap : float (default: PULL_BACKOFF_CAP)
+        The maximum backoff time in seconds. Backoff times will not exceed this value.
+
+    Returns
+    -------
+    T
+        An instance of the specified return type containing the inflated object.
+    """
+    # Pull the main object and all its descendants
+    pulled_object_contents = pull_objects(
+        [tree.object_id for tree in iterate_object_tree(object_tree)],
+        pull_object_fn,
+        max_concurrent_pulls=max_concurrent_pulls,
+        max_time=max_time,
+        max_tries_per_object=max_tries_per_object,
+        initial_backoff=initial_backoff,
+        backoff_cap=backoff_cap,
+    )
+
+    # Confirm that all objects were pulled
+    confirm_object_received_fn(object_tree.object_id)
+
+    # Inflate the main object
+    inflated_object = inflate_object_from_contents(
+        object_tree.object_id, pulled_object_contents, keep_object_contents=False
+    )
+
+    # Check if the inflated object is of the expected type
+    if not isinstance(inflated_object, return_type):
+        raise TypeError(
+            f"Expected object of type {return_type.__name__}, "
+            f"but got {type(inflated_object).__name__}."
+        )
+
+    return inflated_object
flwr/common/record/array.py
CHANGED
@@ -17,6 +17,7 @@

 from __future__ import annotations

+import json
 import sys
 from dataclasses import dataclass
 from io import BytesIO
@@ -24,11 +25,15 @@ from typing import TYPE_CHECKING, Any, cast, overload

 import numpy as np

-from
-
-
-
+from ..constant import MAX_ARRAY_CHUNK_SIZE, SType
+from ..inflatable import (
+    InflatableObject,
+    add_header_to_object_body,
+    get_object_body,
+    get_object_children_ids_from_object_content,
+)
 from ..typing import NDArray
+from .arraychunk import ArrayChunk

 if TYPE_CHECKING:
     import torch
@@ -252,16 +257,56 @@ class Array(InflatableObject):
         ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
         return cast(NDArray, ndarray_deserialized)

+    @property
+    def children(self) -> dict[str, InflatableObject]:
+        """Return a dictionary of ArrayChunks with their Object IDs as keys."""
+        return dict(self.slice_array())
+
+    def slice_array(self) -> list[tuple[str, InflatableObject]]:
+        """Slice Array data and construct a list of ArrayChunks."""
+        # Return cached chunks if they exist
+        if "_chunks" in self.__dict__:
+            return cast(list[tuple[str, InflatableObject]], self.__dict__["_chunks"])
+
+        # Chunks are not children as some of them may be identical
+        chunks: list[tuple[str, InflatableObject]] = []
+        # memoryview allows for zero-copy slicing
+        data_view = memoryview(self.data)
+        for start in range(0, len(data_view), MAX_ARRAY_CHUNK_SIZE):
+            end = min(start + MAX_ARRAY_CHUNK_SIZE, len(data_view))
+            ac = ArrayChunk(data_view[start:end])
+            chunks.append((ac.object_id, ac))
+
+        # Cache the chunks for future use
+        self.__dict__["_chunks"] = chunks
+        return chunks
+
     def deflate(self) -> bytes:
         """Deflate the Array."""
-
-
-
-
-
-        )
-
-
+        array_metadata: dict[str, str | tuple[int, ...] | list[int]] = {}
+
+        # We want to record all object_id even if repeated
+        # it can happend that chunks carry the exact same data
+        # for example when the array has only zeros
+        children_list = self.slice_array()
+        # Let's not save the entire object_id but a mapping to those
+        # that will be carried in the object head
+        # (replace a long object_id with a single scalar)
+        unique_children = list(self.children.keys())
+        arraychunk_ids = [unique_children.index(ch_id) for ch_id, _ in children_list]
+
+        # The deflated Array carries everything but the data
+        # The `arraychunk_ids` will be used during Array inflation
+        # to rematerialize the data from ArrayChunk objects.
+        array_metadata = {
+            "dtype": self.dtype,
+            "shape": self.shape,
+            "stype": self.stype,
+            "arraychunk_ids": arraychunk_ids,
+        }
+
+        # Serialize metadata dict
+        obj_body = json.dumps(array_metadata).encode("utf-8")
         return add_header_to_object_body(object_body=obj_body, obj=self)

     @classmethod
@@ -276,26 +321,55 @@ class Array(InflatableObject):
             The deflated object content of the Array.

         children : Optional[dict[str, InflatableObject]] (default: None)
-            Must be ``None``. ``Array``
-            Providing
+            Must be ``None``. ``Array`` must have child objects.
+            Providing no children will raise a ``ValueError``.

         Returns
         -------
         Array
             The inflated Array.
         """
-        if children:
-
+        if children is None:
+            children = {}

         obj_body = get_object_body(object_content, cls)
-
-
-
-
-
-
+
+        # Extract children IDs from head
+        children_ids = get_object_children_ids_from_object_content(object_content)
+        # Decode the Array body
+        array_metadata: dict[str, str | tuple[int, ...] | list[int]] = json.loads(
+            obj_body.decode(encoding="utf-8")
         )

+        # Verify children ids in body match those passed for inflation
+        chunk_ids_indices = cast(list[int], array_metadata["arraychunk_ids"])
+        # Convert indices back to IDs
+        chunk_ids = [children_ids[i] for i in chunk_ids_indices]
+        # Check consistency
+        unique_arrayschunks = set(chunk_ids)
+        children_obj_ids = set(children.keys())
+        if unique_arrayschunks != children_obj_ids:
+            raise ValueError(
+                "Unexpected set of `children`. "
+                f"Expected {unique_arrayschunks} but got {children_obj_ids}."
+            )
+
+        # Materialize Array with empty data
+        array = cls(
+            dtype=cast(str, array_metadata["dtype"]),
+            shape=cast(tuple[int], tuple(array_metadata["shape"])),
+            stype=cast(str, array_metadata["stype"]),
+            data=b"",
+        )
+
+        # Now inject data from chunks
+        buff = bytearray()
+        for ch_id in chunk_ids:
+            buff += cast(ArrayChunk, children[ch_id]).data
+
+        array.data = bytes(buff)
+        return array
+
     @property
     def object_id(self) -> str:
         """Get object ID."""
@@ -320,4 +394,9 @@ class Array(InflatableObject):
         if name in ("dtype", "shape", "stype", "data"):
             # Mark as dirty if any of the main attributes are set
             self.is_dirty = True
+            # Clear cached object ID
+            self.__dict__.pop("_object_id", None)
+            # Clear cached chunks if data is set
+            if name == "data":
+                self.__dict__.pop("_chunks", None)
         super().__setattr__(name, value)
flwr/common/record/arraychunk.py
ADDED

@@ -0,0 +1,59 @@
+# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ArrayChunk."""
+
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
+
+
+@dataclass
+class ArrayChunk(InflatableObject):
+    """ArrayChunk type."""
+
+    data: memoryview
+
+    def deflate(self) -> bytes:
+        """Deflate the ArrayChunk."""
+        return add_header_to_object_body(object_body=self.data, obj=self)
+
+    @classmethod
+    def inflate(
+        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
+    ) -> ArrayChunk:
+        """Inflate an ArrayChunk from bytes.
+
+        Parameters
+        ----------
+        object_content : bytes
+            The deflated object content of the ArrayChunk.
+
+        children : Optional[dict[str, InflatableObject]] (default: None)
+            Must be ``None``. ``ArrayChunk`` does not support child objects.
+            Providing any children will raise a ``ValueError``.
+
+        Returns
+        -------
+        ArrayChunk
+            The inflated ArrayChunk.
+        """
+        if children:
+            raise ValueError("`ArrayChunk` objects do not have children.")
+
+        obj_body = get_object_body(object_content, cls)
+        return cls(data=memoryview(obj_body))
flwr/common/serde.py
CHANGED
@@ -19,7 +19,6 @@ from collections import OrderedDict
 from typing import Any, cast

 # pylint: disable=E0611
-from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
 from flwr.proto.fab_pb2 import Fab as ProtoFab
 from flwr.proto.message_pb2 import Context as ProtoContext
 from flwr.proto.message_pb2 import Message as ProtoMessage
@@ -653,33 +652,6 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
     return run


-# === ClientApp status messages ===
-
-
-def clientappstatus_to_proto(
-    status: typing.ClientAppOutputStatus,
-) -> ClientAppOutputStatus:
-    """Serialize `ClientAppOutputStatus` to ProtoBuf."""
-    code = ClientAppOutputCode.SUCCESS
-    if status.code == typing.ClientAppOutputCode.DEADLINE_EXCEEDED:
-        code = ClientAppOutputCode.DEADLINE_EXCEEDED
-    if status.code == typing.ClientAppOutputCode.UNKNOWN_ERROR:
-        code = ClientAppOutputCode.UNKNOWN_ERROR
-    return ClientAppOutputStatus(code=code, message=status.message)
-
-
-def clientappstatus_from_proto(
-    msg: ClientAppOutputStatus,
-) -> typing.ClientAppOutputStatus:
-    """Deserialize `ClientAppOutputStatus` from ProtoBuf."""
-    code = typing.ClientAppOutputCode.SUCCESS
-    if msg.code == ClientAppOutputCode.DEADLINE_EXCEEDED:
-        code = typing.ClientAppOutputCode.DEADLINE_EXCEEDED
-    if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
-        code = typing.ClientAppOutputCode.UNKNOWN_ERROR
-    return typing.ClientAppOutputStatus(code=code, message=msg.message)
-
-
 # === Run status ===


flwr/compat/client/app.py
CHANGED
@@ -29,8 +29,6 @@ from flwr.cli.config_utils import get_fab_metadata
 from flwr.cli.install import install_from_fab
 from flwr.client.client import Client
 from flwr.client.client_app import ClientApp, LoadClientAppError
-from flwr.client.grpc_adapter_client.connection import grpc_adapter
-from flwr.client.grpc_rere_client.connection import grpc_request_response
 from flwr.client.message_handler.message_handler import handle_control_message
 from flwr.client.numpy_client import NumPyClient
 from flwr.client.run_info_store import DeprecatedRunInfoStore
@@ -39,10 +37,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, ev
 from flwr.common.address import parse_address
 from flwr.common.constant import (
     MAX_RETRY_DELAY,
-    TRANSPORT_TYPE_GRPC_ADAPTER,
     TRANSPORT_TYPE_GRPC_BIDI,
-    TRANSPORT_TYPE_GRPC_RERE,
-    TRANSPORT_TYPE_REST,
     TRANSPORT_TYPES,
     ErrorCode,
 )
@@ -121,10 +116,8 @@ def start_client(
         Starts an insecure gRPC connection when True. Enables HTTPS connection
         when False, using system certificates if `root_certificates` is None.
     transport : Optional[str] (default: None)
-
-
-        - 'grpc-rere': gRPC, request-response (experimental)
-        - 'rest': HTTP (experimental)
+        **[Deprecated]** This argument is no longer supported and will be
+        removed in a future release.
     authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
         Tuple containing the elliptic curve private key and public key for
         authentication from the cryptography library.
@@ -180,6 +173,12 @@ def start_client(
     )
     warn_deprecated_feature(name=msg)

+    if transport is not None and transport != "grpc-bidi":
+        raise ValueError(
+            f"Transport type {transport} is not supported. "
+            "Use 'grpc-bidi' or None (default) instead."
+        )
+
     event(EventType.START_CLIENT_ENTER)
     start_client_internal(
         server_address=server_address,
@@ -429,7 +428,7 @@ def start_client_internal(

     run: Run = runs[run_id]
     if get_fab is not None and run.fab_hash:
-        fab = get_fab(run.fab_hash, run_id)
+        fab = get_fab(run.fab_hash, run_id)  # pylint: disable=E1102
         # If `ClientApp` runs in the same process, install the FAB
         install_from_fab(fab.content, flwr_path, True)
         fab_id, fab_version = get_fab_metadata(fab.content)
@@ -573,10 +572,8 @@ def start_numpy_client(
         Starts an insecure gRPC connection when True. Enables HTTPS connection
         when False, using system certificates if `root_certificates` is None.
     transport : Optional[str] (default: None)
-
-
-        - 'grpc-rere': gRPC, request-response (experimental)
-        - 'rest': HTTP (experimental)
+        **[Deprecated]** This argument is no longer supported and will be
+        removed in a future release.

     Examples
     --------
@@ -672,23 +669,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
     if transport is None:
         transport = TRANSPORT_TYPE_GRPC_BIDI

-    # Use
-    if transport ==
-
-            from requests.exceptions import ConnectionError as RequestsConnectionError
-
-            from flwr.client.rest_client.connection import http_request_response
-        except ModuleNotFoundError:
-            flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
-        if server_address[:4] != "http":
-            flwr_exit(ExitCode.SUPERNODE_REST_ADDRESS_INVALID)
-        connection, error_type = http_request_response, RequestsConnectionError
-    elif transport == TRANSPORT_TYPE_GRPC_RERE:
-        connection, error_type = grpc_request_response, RpcError
-    elif transport == TRANSPORT_TYPE_GRPC_ADAPTER:
-        connection, error_type = grpc_adapter, RpcError
-    elif transport == TRANSPORT_TYPE_GRPC_BIDI:
-        connection, error_type = grpc_connection, RpcError  # type: ignore[assignment]
+    # Use gRPC bidirectional streaming
+    if transport == TRANSPORT_TYPE_GRPC_BIDI:
+        connection, error_type = grpc_connection, RpcError
     else:
         raise ValueError(
             f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"