flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +34 -1
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +94 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
- flwr/cli/build.py +166 -53
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +54 -11
- flwr/cli/login/login.py +41 -27
- flwr/cli/ls.py +177 -133
- flwr/cli/new/new.py +175 -40
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +12 -7
- flwr/cli/run/run.py +82 -31
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +27 -9
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +268 -0
- flwr/cli/supernode/register.py +190 -0
- flwr/cli/supernode/unregister.py +140 -0
- flwr/cli/utils.py +464 -81
- flwr/client/__init__.py +2 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +12 -15
- flwr/client/grpc_rere_client/connection.py +68 -41
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +94 -51
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +1 -2
- flwr/{client → clientapp}/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/{client/clientapp → clientapp}/utils.py +4 -4
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +56 -13
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +39 -10
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +48 -31
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +6 -6
- flwr/common/record/arrayrecord.py +18 -21
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +9 -6
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +59 -43
- flwr/compat/client/app.py +39 -38
- flwr/compat/client/grpc_client/connection.py +13 -13
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +72 -40
- flwr/proto/control_pb2.pyi +319 -87
- flwr/proto/control_pb2_grpc.py +339 -28
- flwr/proto/control_pb2_grpc.pyi +209 -37
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +24 -10
- flwr/proto/fab_pb2.pyi +68 -20
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +45 -27
- flwr/proto/fleet_pb2.pyi +186 -70
- flwr/proto/fleet_pb2_grpc.py +277 -66
- flwr/proto/fleet_pb2_grpc.pyi +201 -55
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +16 -4
- flwr/proto/node_pb2.pyi +77 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +173 -127
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
- flwr/server/superlink/fleet/vce/vce_api.py +32 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
- flwr/server/superlink/linkstate/linkstate.py +161 -62
- flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
- flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
- flwr/server/superlink/linkstate/utils.py +9 -60
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +11 -12
- flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
- flwr/simulation/run_simulation.py +44 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +52 -0
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -6
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +27 -8
- flwr/supercore/object_store/sqlite_object_store.py +253 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
- flwr/supercore/sqlite_mixin.py +159 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +88 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +18 -17
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +239 -63
- flwr/supernode/cli/flower_supernode.py +74 -26
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +43 -24
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +175 -51
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.22.0.dist-info/RECORD +0 -428
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
flwr/common/inflatable_utils.py
CHANGED
|
@@ -14,21 +14,23 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""InflatableObject utilities."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import concurrent.futures
|
|
18
19
|
import os
|
|
19
20
|
import random
|
|
20
21
|
import threading
|
|
21
22
|
import time
|
|
22
|
-
from collections.abc import Iterable, Iterator
|
|
23
|
-
from
|
|
23
|
+
from collections.abc import Callable, Iterable, Iterator
|
|
24
|
+
from queue import Queue
|
|
25
|
+
from typing import TypeVar
|
|
24
26
|
|
|
25
27
|
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
|
26
28
|
|
|
27
29
|
from .constant import (
|
|
30
|
+
FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
|
|
31
|
+
FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
|
|
28
32
|
HEAD_BODY_DIVIDER,
|
|
29
33
|
HEAD_VALUE_DIVIDER,
|
|
30
|
-
MAX_CONCURRENT_PULLS,
|
|
31
|
-
MAX_CONCURRENT_PUSHES,
|
|
32
34
|
PULL_BACKOFF_CAP,
|
|
33
35
|
PULL_INITIAL_BACKOFF,
|
|
34
36
|
PULL_MAX_TIME,
|
|
@@ -116,9 +118,9 @@ def push_objects(
|
|
|
116
118
|
objects: dict[str, InflatableObject],
|
|
117
119
|
push_object_fn: Callable[[str, bytes], None],
|
|
118
120
|
*,
|
|
119
|
-
object_ids_to_push:
|
|
121
|
+
object_ids_to_push: set[str] | None = None,
|
|
120
122
|
keep_objects: bool = False,
|
|
121
|
-
max_concurrent_pushes: int =
|
|
123
|
+
max_concurrent_pushes: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
|
|
122
124
|
) -> None:
|
|
123
125
|
"""Push multiple objects to the servicer.
|
|
124
126
|
|
|
@@ -137,7 +139,7 @@ def push_objects(
|
|
|
137
139
|
If `True`, the original objects will be kept in the `objects` dictionary
|
|
138
140
|
after pushing. If `False`, they will be removed from the dictionary to avoid
|
|
139
141
|
high memory usage.
|
|
140
|
-
max_concurrent_pushes : int (default:
|
|
142
|
+
max_concurrent_pushes : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES)
|
|
141
143
|
The maximum number of concurrent pushes to perform.
|
|
142
144
|
"""
|
|
143
145
|
lock = threading.Lock()
|
|
@@ -168,7 +170,7 @@ def push_object_contents_from_iterable(
|
|
|
168
170
|
object_contents: Iterable[tuple[str, bytes]],
|
|
169
171
|
push_object_fn: Callable[[str, bytes], None],
|
|
170
172
|
*,
|
|
171
|
-
max_concurrent_pushes: int =
|
|
173
|
+
max_concurrent_pushes: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES,
|
|
172
174
|
) -> None:
|
|
173
175
|
"""Push multiple object contents to the servicer.
|
|
174
176
|
|
|
@@ -181,15 +183,24 @@ def push_object_contents_from_iterable(
|
|
|
181
183
|
A function that takes an object ID and its content as bytes, and pushes
|
|
182
184
|
it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
|
|
183
185
|
if the object ID is not pre-registered.
|
|
184
|
-
max_concurrent_pushes : int (default:
|
|
186
|
+
max_concurrent_pushes : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PUSHES)
|
|
185
187
|
The maximum number of concurrent pushes to perform.
|
|
186
188
|
"""
|
|
189
|
+
error_event = threading.Event()
|
|
190
|
+
err_queue: Queue[Exception] = Queue()
|
|
187
191
|
|
|
188
192
|
def push(args: tuple[str, bytes]) -> None:
|
|
189
193
|
"""Push a single object."""
|
|
194
|
+
if error_event.is_set():
|
|
195
|
+
return
|
|
190
196
|
obj_id, obj_content = args
|
|
191
197
|
# Push the object using the provided function
|
|
192
|
-
|
|
198
|
+
try:
|
|
199
|
+
push_object_fn(obj_id, obj_content)
|
|
200
|
+
except Exception as err: # pylint: disable=broad-except
|
|
201
|
+
# Unexpected error during pushing
|
|
202
|
+
error_event.set()
|
|
203
|
+
err_queue.put(err)
|
|
193
204
|
|
|
194
205
|
# Push all object contents concurrently
|
|
195
206
|
num_workers = get_num_workers(max_concurrent_pushes)
|
|
@@ -205,14 +216,18 @@ def push_object_contents_from_iterable(
|
|
|
205
216
|
# Remove the executor from the list of tracked executors
|
|
206
217
|
_untrack_executor(executor)
|
|
207
218
|
|
|
219
|
+
# If an error occurred during pushing, raise it
|
|
220
|
+
if not err_queue.empty():
|
|
221
|
+
raise err_queue.get()
|
|
222
|
+
|
|
208
223
|
|
|
209
224
|
def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
|
|
210
225
|
object_ids: list[str],
|
|
211
226
|
pull_object_fn: Callable[[str], bytes],
|
|
212
227
|
*,
|
|
213
|
-
max_concurrent_pulls: int =
|
|
214
|
-
max_time:
|
|
215
|
-
max_tries_per_object:
|
|
228
|
+
max_concurrent_pulls: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
|
|
229
|
+
max_time: float | None = PULL_MAX_TIME,
|
|
230
|
+
max_tries_per_object: int | None = PULL_MAX_TRIES_PER_OBJECT,
|
|
216
231
|
initial_backoff: float = PULL_INITIAL_BACKOFF,
|
|
217
232
|
backoff_cap: float = PULL_BACKOFF_CAP,
|
|
218
233
|
) -> dict[str, bytes]:
|
|
@@ -227,7 +242,7 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
227
242
|
The function should raise `ObjectUnavailableError` if the object is not yet
|
|
228
243
|
available, or `ObjectIdNotPreregisteredError` if the object ID is not
|
|
229
244
|
pre-registered.
|
|
230
|
-
max_concurrent_pulls : int (default:
|
|
245
|
+
max_concurrent_pulls : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS)
|
|
231
246
|
The maximum number of concurrent pulls to perform.
|
|
232
247
|
max_time : Optional[float] (default: PULL_MAX_TIME)
|
|
233
248
|
The maximum time to wait for all pulls to complete. If `None`, waits
|
|
@@ -254,13 +269,16 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
254
269
|
|
|
255
270
|
results: dict[str, bytes] = {}
|
|
256
271
|
results_lock = threading.Lock()
|
|
257
|
-
|
|
272
|
+
err_queue: Queue[Exception] = Queue()
|
|
258
273
|
early_stop = threading.Event()
|
|
259
274
|
start = time.monotonic()
|
|
260
275
|
|
|
276
|
+
def stop_on_error(err: Exception) -> None:
|
|
277
|
+
early_stop.set()
|
|
278
|
+
err_queue.put(err)
|
|
279
|
+
|
|
261
280
|
def pull_with_retries(object_id: str) -> None:
|
|
262
281
|
"""Attempt to pull a single object with retry and backoff."""
|
|
263
|
-
nonlocal err_to_raise
|
|
264
282
|
tries = 0
|
|
265
283
|
delay = initial_backoff
|
|
266
284
|
|
|
@@ -278,10 +296,7 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
278
296
|
or time.monotonic() - start >= max_time
|
|
279
297
|
):
|
|
280
298
|
# Stop all work if one object exhausts retries
|
|
281
|
-
|
|
282
|
-
with results_lock:
|
|
283
|
-
if err_to_raise is None:
|
|
284
|
-
err_to_raise = err
|
|
299
|
+
stop_on_error(err)
|
|
285
300
|
return
|
|
286
301
|
|
|
287
302
|
# Apply exponential backoff with ±20% jitter
|
|
@@ -291,10 +306,12 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
291
306
|
|
|
292
307
|
except ObjectIdNotPreregisteredError as err:
|
|
293
308
|
# Permanent failure: object ID is invalid
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
309
|
+
stop_on_error(err)
|
|
310
|
+
return
|
|
311
|
+
|
|
312
|
+
except Exception as err: # pylint: disable=broad-except
|
|
313
|
+
# Permanent failure: unexpected error
|
|
314
|
+
stop_on_error(err)
|
|
298
315
|
return
|
|
299
316
|
|
|
300
317
|
# Submit all pull tasks concurrently
|
|
@@ -312,8 +329,8 @@ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
312
329
|
_untrack_executor(executor)
|
|
313
330
|
|
|
314
331
|
# If an error occurred during pulling, raise it
|
|
315
|
-
if
|
|
316
|
-
raise
|
|
332
|
+
if not err_queue.empty():
|
|
333
|
+
raise err_queue.get()
|
|
317
334
|
|
|
318
335
|
return results
|
|
319
336
|
|
|
@@ -323,7 +340,7 @@ def inflate_object_from_contents(
|
|
|
323
340
|
object_contents: dict[str, bytes],
|
|
324
341
|
*,
|
|
325
342
|
keep_object_contents: bool = False,
|
|
326
|
-
objects:
|
|
343
|
+
objects: dict[str, InflatableObject] | None = None,
|
|
327
344
|
) -> InflatableObject:
|
|
328
345
|
"""Inflate an object from object contents.
|
|
329
346
|
|
|
@@ -442,9 +459,9 @@ def pull_and_inflate_object_from_tree( # pylint: disable=R0913
|
|
|
442
459
|
confirm_object_received_fn: Callable[[str], None],
|
|
443
460
|
*,
|
|
444
461
|
return_type: type[T] = InflatableObject, # type: ignore
|
|
445
|
-
max_concurrent_pulls: int =
|
|
446
|
-
max_time:
|
|
447
|
-
max_tries_per_object:
|
|
462
|
+
max_concurrent_pulls: int = FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS,
|
|
463
|
+
max_time: float | None = PULL_MAX_TIME,
|
|
464
|
+
max_tries_per_object: int | None = PULL_MAX_TRIES_PER_OBJECT,
|
|
448
465
|
initial_backoff: float = PULL_INITIAL_BACKOFF,
|
|
449
466
|
backoff_cap: float = PULL_BACKOFF_CAP,
|
|
450
467
|
) -> T:
|
|
@@ -460,7 +477,7 @@ def pull_and_inflate_object_from_tree( # pylint: disable=R0913
|
|
|
460
477
|
A function to confirm that the object has been received.
|
|
461
478
|
return_type : type[T] (default: InflatableObject)
|
|
462
479
|
The type of the object to return. Must be a subclass of `InflatableObject`.
|
|
463
|
-
max_concurrent_pulls : int (default:
|
|
480
|
+
max_concurrent_pulls : int (default: FLWR_PRIVATE_MAX_CONCURRENT_OBJ_PULLS)
|
|
464
481
|
The maximum number of concurrent pulls to perform.
|
|
465
482
|
max_time : Optional[float] (default: PULL_MAX_TIME)
|
|
466
483
|
The maximum time to wait for all pulls to complete. If `None`, waits
|
flwr/common/logger.py
CHANGED
|
@@ -26,7 +26,7 @@ from io import StringIO
|
|
|
26
26
|
from logging import ERROR, WARN, LogRecord
|
|
27
27
|
from logging.handlers import HTTPHandler
|
|
28
28
|
from queue import Empty, Queue
|
|
29
|
-
from typing import TYPE_CHECKING, Any,
|
|
29
|
+
from typing import TYPE_CHECKING, Any, TextIO
|
|
30
30
|
|
|
31
31
|
import grpc
|
|
32
32
|
import typer
|
|
@@ -68,7 +68,7 @@ class ConsoleHandler(StreamHandler):
|
|
|
68
68
|
timestamps: bool = False,
|
|
69
69
|
json: bool = False,
|
|
70
70
|
colored: bool = True,
|
|
71
|
-
stream:
|
|
71
|
+
stream: TextIO | None = None,
|
|
72
72
|
) -> None:
|
|
73
73
|
super().__init__(stream)
|
|
74
74
|
self.timestamps = timestamps
|
|
@@ -103,9 +103,9 @@ class ConsoleHandler(StreamHandler):
|
|
|
103
103
|
|
|
104
104
|
|
|
105
105
|
def update_console_handler(
|
|
106
|
-
level:
|
|
107
|
-
timestamps:
|
|
108
|
-
colored:
|
|
106
|
+
level: int | str | None = None,
|
|
107
|
+
timestamps: bool | None = None,
|
|
108
|
+
colored: bool | None = None,
|
|
109
109
|
) -> None:
|
|
110
110
|
"""Update the logging handler."""
|
|
111
111
|
for handler in logging.getLogger(LOGGER_NAME).handlers:
|
|
@@ -160,7 +160,7 @@ class CustomHTTPHandler(HTTPHandler):
|
|
|
160
160
|
url: str,
|
|
161
161
|
method: str = "GET",
|
|
162
162
|
secure: bool = False,
|
|
163
|
-
credentials:
|
|
163
|
+
credentials: tuple[str, str] | None = None,
|
|
164
164
|
) -> None:
|
|
165
165
|
super().__init__(host, url, method, secure, credentials)
|
|
166
166
|
self.identifier = identifier
|
|
@@ -180,7 +180,7 @@ class CustomHTTPHandler(HTTPHandler):
|
|
|
180
180
|
|
|
181
181
|
|
|
182
182
|
def configure(
|
|
183
|
-
identifier: str, filename:
|
|
183
|
+
identifier: str, filename: str | None = None, host: str | None = None
|
|
184
184
|
) -> None:
|
|
185
185
|
"""Configure logging to file and/or remote log server."""
|
|
186
186
|
# Create formatter
|
|
@@ -298,7 +298,7 @@ def set_logger_propagation(
|
|
|
298
298
|
return child_logger
|
|
299
299
|
|
|
300
300
|
|
|
301
|
-
def mirror_output_to_queue(log_queue: Queue[
|
|
301
|
+
def mirror_output_to_queue(log_queue: Queue[str | None]) -> None:
|
|
302
302
|
"""Mirror stdout and stderr output to the provided queue."""
|
|
303
303
|
|
|
304
304
|
def get_write_fn(stream: TextIO) -> Any:
|
|
@@ -335,7 +335,7 @@ def redirect_output(output_buffer: StringIO) -> None:
|
|
|
335
335
|
|
|
336
336
|
|
|
337
337
|
def _log_uploader(
|
|
338
|
-
log_queue: Queue[
|
|
338
|
+
log_queue: Queue[str | None], node_id: int, run_id: int, stub: ServerAppIoStub
|
|
339
339
|
) -> None:
|
|
340
340
|
"""Upload logs to the SuperLink."""
|
|
341
341
|
exit_flag = False
|
|
@@ -378,10 +378,10 @@ def _log_uploader(
|
|
|
378
378
|
|
|
379
379
|
|
|
380
380
|
def start_log_uploader(
|
|
381
|
-
log_queue: Queue[
|
|
381
|
+
log_queue: Queue[str | None],
|
|
382
382
|
node_id: int,
|
|
383
383
|
run_id: int,
|
|
384
|
-
stub:
|
|
384
|
+
stub: ServerAppIoStub | SimulationIoStub,
|
|
385
385
|
) -> threading.Thread:
|
|
386
386
|
"""Start the log uploader thread and return it."""
|
|
387
387
|
thread = threading.Thread(
|
|
@@ -392,7 +392,7 @@ def start_log_uploader(
|
|
|
392
392
|
|
|
393
393
|
|
|
394
394
|
def stop_log_uploader(
|
|
395
|
-
log_queue: Queue[
|
|
395
|
+
log_queue: Queue[str | None], log_uploader: threading.Thread
|
|
396
396
|
) -> None:
|
|
397
397
|
"""Stop the log uploader thread."""
|
|
398
398
|
log_queue.put(None)
|
|
@@ -403,19 +403,19 @@ def _remove_emojis(text: str) -> str:
|
|
|
403
403
|
"""Remove emojis from the provided text."""
|
|
404
404
|
emoji_pattern = re.compile(
|
|
405
405
|
"["
|
|
406
|
-
"\
|
|
407
|
-
"\
|
|
408
|
-
"\
|
|
409
|
-
"\
|
|
410
|
-
"\U00002702-\
|
|
411
|
-
"\
|
|
406
|
+
"\U0001f600-\U0001f64f" # Emoticons
|
|
407
|
+
"\U0001f300-\U0001f5ff" # Symbols & Pictographs
|
|
408
|
+
"\U0001f680-\U0001f6ff" # Transport & Map Symbols
|
|
409
|
+
"\U0001f1e0-\U0001f1ff" # Flags
|
|
410
|
+
"\U00002702-\U000027b0" # Dingbats
|
|
411
|
+
"\U000024c2-\U0001f251"
|
|
412
412
|
"]+",
|
|
413
413
|
flags=re.UNICODE,
|
|
414
414
|
)
|
|
415
415
|
return emoji_pattern.sub(r"", text)
|
|
416
416
|
|
|
417
417
|
|
|
418
|
-
def print_json_error(msg: str, e:
|
|
418
|
+
def print_json_error(msg: str, e: typer.Exit | Exception) -> None:
|
|
419
419
|
"""Print error message as JSON."""
|
|
420
420
|
Console().print_json(
|
|
421
421
|
_json.dumps(
|
flwr/common/message.py
CHANGED
|
@@ -105,7 +105,7 @@ class Message(InflatableObject):
|
|
|
105
105
|
"""
|
|
106
106
|
|
|
107
107
|
@overload
|
|
108
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
108
|
+
def __init__( # pylint: disable=too-many-arguments
|
|
109
109
|
self,
|
|
110
110
|
content: RecordDict,
|
|
111
111
|
dst_node_id: int,
|
|
@@ -116,12 +116,12 @@ class Message(InflatableObject):
|
|
|
116
116
|
) -> None: ...
|
|
117
117
|
|
|
118
118
|
@overload
|
|
119
|
-
def __init__(
|
|
119
|
+
def __init__(
|
|
120
120
|
self, content: RecordDict, *, reply_to: Message, ttl: float | None = None
|
|
121
121
|
) -> None: ...
|
|
122
122
|
|
|
123
123
|
@overload
|
|
124
|
-
def __init__(
|
|
124
|
+
def __init__(
|
|
125
125
|
self, error: Error, *, reply_to: Message, ttl: float | None = None
|
|
126
126
|
) -> None: ...
|
|
127
127
|
|
|
@@ -511,7 +511,7 @@ def _check_arg_types( # pylint: disable=too-many-arguments, R0917
|
|
|
511
511
|
and (message_type is None or isinstance(message_type, str))
|
|
512
512
|
and (content is None or isinstance(content, RecordDict))
|
|
513
513
|
and (error is None or isinstance(error, Error))
|
|
514
|
-
and (ttl is None or isinstance(ttl, (int
|
|
514
|
+
and (ttl is None or isinstance(ttl, (int | float)))
|
|
515
515
|
and (group_id is None or isinstance(group_id, str))
|
|
516
516
|
and (reply_to is None or isinstance(reply_to, Message))
|
|
517
517
|
and (metadata is None or isinstance(metadata, Metadata))
|
flwr/common/object_ref.py
CHANGED
|
@@ -21,7 +21,7 @@ import sys
|
|
|
21
21
|
from importlib.util import find_spec
|
|
22
22
|
from pathlib import Path
|
|
23
23
|
from threading import Lock
|
|
24
|
-
from typing import Any
|
|
24
|
+
from typing import Any
|
|
25
25
|
|
|
26
26
|
OBJECT_REF_HELP_STR = """
|
|
27
27
|
\n\nThe object reference string should have the form <module>:<attribute>. Valid
|
|
@@ -31,15 +31,15 @@ attribute.
|
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
_current_sys_path:
|
|
34
|
+
_current_sys_path: str | None = None
|
|
35
35
|
_import_lock = Lock()
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
def validate(
|
|
39
39
|
module_attribute_str: str,
|
|
40
40
|
check_module: bool = True,
|
|
41
|
-
project_dir:
|
|
42
|
-
) -> tuple[bool,
|
|
41
|
+
project_dir: str | Path | None = None,
|
|
42
|
+
) -> tuple[bool, str | None]:
|
|
43
43
|
"""Validate object reference.
|
|
44
44
|
|
|
45
45
|
Parameters
|
|
@@ -114,7 +114,7 @@ def validate(
|
|
|
114
114
|
def load_app( # pylint: disable= too-many-branches
|
|
115
115
|
module_attribute_str: str,
|
|
116
116
|
error_type: type[Exception],
|
|
117
|
-
project_dir:
|
|
117
|
+
project_dir: str | Path | None = None,
|
|
118
118
|
) -> Any:
|
|
119
119
|
"""Return the object specified in a module attribute string.
|
|
120
120
|
|
|
@@ -194,12 +194,12 @@ def _unload_modules(project_dir: Path) -> None:
|
|
|
194
194
|
"""Unload modules from the project directory."""
|
|
195
195
|
dir_str = str(project_dir.absolute())
|
|
196
196
|
for name, m in list(sys.modules.items()):
|
|
197
|
-
path:
|
|
197
|
+
path: str | None = getattr(m, "__file__", None)
|
|
198
198
|
if path is not None and path.startswith(dir_str):
|
|
199
199
|
del sys.modules[name]
|
|
200
200
|
|
|
201
201
|
|
|
202
|
-
def _set_sys_path(directory:
|
|
202
|
+
def _set_sys_path(directory: str | Path | None) -> None:
|
|
203
203
|
"""Set the system path."""
|
|
204
204
|
if directory is None:
|
|
205
205
|
directory = Path.cwd()
|
flwr/common/record/array.py
CHANGED
|
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, cast, overload
|
|
|
25
25
|
|
|
26
26
|
import numpy as np
|
|
27
27
|
|
|
28
|
-
from ..constant import
|
|
28
|
+
from ..constant import FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE, SType
|
|
29
29
|
from ..inflatable import (
|
|
30
30
|
InflatableObject,
|
|
31
31
|
add_header_to_object_body,
|
|
@@ -117,15 +117,15 @@ class Array(InflatableObject):
|
|
|
117
117
|
data: bytes
|
|
118
118
|
|
|
119
119
|
@overload
|
|
120
|
-
def __init__(
|
|
120
|
+
def __init__(
|
|
121
121
|
self, dtype: str, shape: tuple[int, ...], stype: str, data: bytes
|
|
122
122
|
) -> None: ...
|
|
123
123
|
|
|
124
124
|
@overload
|
|
125
|
-
def __init__(self, ndarray: NDArray) -> None: ...
|
|
125
|
+
def __init__(self, ndarray: NDArray) -> None: ...
|
|
126
126
|
|
|
127
127
|
@overload
|
|
128
|
-
def __init__(self, torch_tensor: torch.Tensor) -> None: ...
|
|
128
|
+
def __init__(self, torch_tensor: torch.Tensor) -> None: ...
|
|
129
129
|
|
|
130
130
|
def __init__( # pylint: disable=too-many-arguments, too-many-locals
|
|
131
131
|
self,
|
|
@@ -272,8 +272,8 @@ class Array(InflatableObject):
|
|
|
272
272
|
chunks: list[tuple[str, InflatableObject]] = []
|
|
273
273
|
# memoryview allows for zero-copy slicing
|
|
274
274
|
data_view = memoryview(self.data)
|
|
275
|
-
for start in range(0, len(data_view),
|
|
276
|
-
end = min(start +
|
|
275
|
+
for start in range(0, len(data_view), FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE):
|
|
276
|
+
end = min(start + FLWR_PRIVATE_MAX_ARRAY_CHUNK_SIZE, len(data_view))
|
|
277
277
|
ac = ArrayChunk(data_view[start:end])
|
|
278
278
|
chunks.append((ac.object_id, ac))
|
|
279
279
|
|
|
@@ -63,7 +63,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
63
63
|
|
|
64
64
|
A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
|
|
65
65
|
including model parameters, gradients, embeddings or non-parameter arrays.
|
|
66
|
-
Internally, this behaves similarly to an ``
|
|
66
|
+
Internally, this behaves similarly to an ``dict[str, Array]``.
|
|
67
67
|
An ``ArrayRecord`` can be viewed as an equivalent to PyTorch's ``state_dict``,
|
|
68
68
|
but it holds arrays in a serialized form.
|
|
69
69
|
|
|
@@ -80,13 +80,13 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
80
80
|
|
|
81
81
|
Parameters
|
|
82
82
|
----------
|
|
83
|
-
array_dict : Optional[
|
|
83
|
+
array_dict : Optional[dict[str, Array]] (default: None)
|
|
84
84
|
An existing dictionary containing named :class:`Array` instances. If
|
|
85
85
|
provided, these entries will be used directly to populate the record.
|
|
86
86
|
numpy_ndarrays : Optional[list[NDArray]] (default: None)
|
|
87
87
|
A list of NumPy arrays. Each array will be automatically converted
|
|
88
88
|
into an :class:`Array` and stored in this record with generated keys.
|
|
89
|
-
torch_state_dict : Optional[
|
|
89
|
+
torch_state_dict : Optional[dict[str, torch.Tensor]] (default: None)
|
|
90
90
|
A PyTorch ``state_dict`` (``str`` keys to ``torch.Tensor`` values). Each
|
|
91
91
|
tensor will be converted into an :class:`Array` and stored in this record.
|
|
92
92
|
keep_input : bool (default: True)
|
|
@@ -127,22 +127,23 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
127
127
|
"""
|
|
128
128
|
|
|
129
129
|
@overload
|
|
130
|
-
def __init__(self) -> None: ...
|
|
130
|
+
def __init__(self) -> None: ...
|
|
131
131
|
|
|
132
132
|
@overload
|
|
133
|
-
def __init__(
|
|
134
|
-
self, array_dict:
|
|
133
|
+
def __init__(
|
|
134
|
+
self, array_dict: dict[str, Array], *, keep_input: bool = True
|
|
135
135
|
) -> None: ...
|
|
136
136
|
|
|
137
137
|
@overload
|
|
138
|
-
def __init__(
|
|
138
|
+
def __init__(
|
|
139
139
|
self, numpy_ndarrays: list[NDArray], *, keep_input: bool = True
|
|
140
140
|
) -> None: ...
|
|
141
141
|
|
|
142
142
|
@overload
|
|
143
|
-
def __init__(
|
|
143
|
+
def __init__(
|
|
144
144
|
self,
|
|
145
|
-
|
|
145
|
+
# `Any` is required for PyTorch state dict because they are not strongly typed
|
|
146
|
+
torch_state_dict: dict[str, torch.Tensor] | dict[str, Any],
|
|
146
147
|
*,
|
|
147
148
|
keep_input: bool = True,
|
|
148
149
|
) -> None: ...
|
|
@@ -151,15 +152,15 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
151
152
|
self,
|
|
152
153
|
*args: Any,
|
|
153
154
|
numpy_ndarrays: list[NDArray] | None = None,
|
|
154
|
-
torch_state_dict:
|
|
155
|
-
array_dict:
|
|
155
|
+
torch_state_dict: dict[str, torch.Tensor] | dict[str, Any] | None = None,
|
|
156
|
+
array_dict: dict[str, Array] | None = None,
|
|
156
157
|
keep_input: bool = True,
|
|
157
158
|
) -> None:
|
|
158
159
|
super().__init__(_check_key, _check_value)
|
|
159
160
|
|
|
160
161
|
# Determine the initialization method and validates input arguments.
|
|
161
162
|
# Support the following initialization formats:
|
|
162
|
-
# 1. cls(array_dict:
|
|
163
|
+
# 1. cls(array_dict: dict[str, Array], keep_input: bool)
|
|
163
164
|
# 2. cls(numpy_ndarrays: list[NDArray], keep_input: bool)
|
|
164
165
|
# 3. cls(torch_state_dict: dict[str, torch.Tensor], keep_input: bool)
|
|
165
166
|
|
|
@@ -204,7 +205,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
204
205
|
and all(isinstance(k, str) for k in arg.keys())
|
|
205
206
|
and all(isinstance(v, Array) for v in arg.values())
|
|
206
207
|
):
|
|
207
|
-
array_dict = cast(
|
|
208
|
+
array_dict = cast(dict[str, Array], arg)
|
|
208
209
|
converted = self.from_array_dict(array_dict, keep_input=keep_input)
|
|
209
210
|
self.__dict__.update(converted.__dict__)
|
|
210
211
|
return
|
|
@@ -230,9 +231,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
230
231
|
and all(isinstance(k, str) for k in arg.keys())
|
|
231
232
|
and all(isinstance(v, torch.Tensor) for v in arg.values())
|
|
232
233
|
):
|
|
233
|
-
torch_state_dict = cast(
|
|
234
|
-
OrderedDict[str, torch.Tensor], arg # type: ignore
|
|
235
|
-
)
|
|
234
|
+
torch_state_dict = cast(dict[str, torch.Tensor], arg) # type: ignore
|
|
236
235
|
converted = self.from_torch_state_dict(
|
|
237
236
|
torch_state_dict, keep_input=keep_input
|
|
238
237
|
)
|
|
@@ -244,7 +243,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
244
243
|
@classmethod
|
|
245
244
|
def from_array_dict(
|
|
246
245
|
cls,
|
|
247
|
-
array_dict:
|
|
246
|
+
array_dict: dict[str, Array],
|
|
248
247
|
*,
|
|
249
248
|
keep_input: bool = True,
|
|
250
249
|
) -> ArrayRecord:
|
|
@@ -291,7 +290,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
291
290
|
@classmethod
|
|
292
291
|
def from_torch_state_dict(
|
|
293
292
|
cls,
|
|
294
|
-
state_dict:
|
|
293
|
+
state_dict: dict[str, torch.Tensor],
|
|
295
294
|
*,
|
|
296
295
|
keep_input: bool = True,
|
|
297
296
|
) -> ArrayRecord:
|
|
@@ -424,9 +423,7 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
|
424
423
|
|
|
425
424
|
# Instantiate new ArrayRecord
|
|
426
425
|
return ArrayRecord(
|
|
427
|
-
|
|
428
|
-
{name: children[object_id] for name, object_id in array_refs.items()}
|
|
429
|
-
)
|
|
426
|
+
{name: children[object_id] for name, object_id in array_refs.items()}
|
|
430
427
|
)
|
|
431
428
|
|
|
432
429
|
@property
|
|
@@ -142,11 +142,11 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
|
|
|
142
142
|
var_bytes = 0
|
|
143
143
|
if isinstance(value, bool):
|
|
144
144
|
var_bytes = 1
|
|
145
|
-
elif isinstance(value, (int
|
|
145
|
+
elif isinstance(value, (int | float)):
|
|
146
146
|
var_bytes = (
|
|
147
147
|
8 # the profobufing represents int/floats in ConfigRecords as 64bit
|
|
148
148
|
)
|
|
149
|
-
if isinstance(value, (str
|
|
149
|
+
if isinstance(value, (str | bytes)):
|
|
150
150
|
var_bytes = len(value)
|
|
151
151
|
if var_bytes == 0:
|
|
152
152
|
raise ValueError(
|
|
@@ -159,7 +159,7 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
|
|
|
159
159
|
|
|
160
160
|
for k, v in self.items():
|
|
161
161
|
if isinstance(v, list):
|
|
162
|
-
if isinstance(v[0], (bytes
|
|
162
|
+
if isinstance(v[0], (bytes | str)):
|
|
163
163
|
# not all str are of equal length necessarily
|
|
164
164
|
# for both the footprint of each element is 1 Byte
|
|
165
165
|
num_bytes += int(sum(len(s) for s in v)) # type: ignore
|
flwr/common/record/recorddict.py
CHANGED
|
@@ -20,7 +20,7 @@ from __future__ import annotations
|
|
|
20
20
|
import json
|
|
21
21
|
from logging import WARN
|
|
22
22
|
from textwrap import indent
|
|
23
|
-
from typing import TypeVar,
|
|
23
|
+
from typing import TypeVar, cast
|
|
24
24
|
|
|
25
25
|
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
|
26
26
|
from ..logger import log
|
|
@@ -29,7 +29,7 @@ from .configrecord import ConfigRecord
|
|
|
29
29
|
from .metricrecord import MetricRecord
|
|
30
30
|
from .typeddict import TypedDict
|
|
31
31
|
|
|
32
|
-
RecordType =
|
|
32
|
+
RecordType = ArrayRecord | MetricRecord | ConfigRecord
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class _WarningTracker:
|
|
@@ -59,7 +59,7 @@ def _check_key(key: str) -> None:
|
|
|
59
59
|
|
|
60
60
|
|
|
61
61
|
def _check_value(value: RecordType) -> None:
|
|
62
|
-
if not isinstance(value, (ArrayRecord
|
|
62
|
+
if not isinstance(value, (ArrayRecord | MetricRecord | ConfigRecord)):
|
|
63
63
|
raise TypeError(
|
|
64
64
|
f"Expected `{ArrayRecord.__name__}`, `{MetricRecord.__name__}`, "
|
|
65
65
|
f"or `{ConfigRecord.__name__}` but received "
|
|
@@ -76,7 +76,7 @@ class _SyncedDict(TypedDict[str, T]):
|
|
|
76
76
|
"""
|
|
77
77
|
|
|
78
78
|
def __init__(self, ref_recorddict: RecordDict, allowed_type: type[T]) -> None:
|
|
79
|
-
if not issubclass(allowed_type, (ArrayRecord
|
|
79
|
+
if not issubclass(allowed_type, (ArrayRecord | MetricRecord | ConfigRecord)):
|
|
80
80
|
raise TypeError(f"{allowed_type} is not a valid type.")
|
|
81
81
|
super().__init__(_check_key, self.check_value)
|
|
82
82
|
self.recorddict = ref_recorddict
|
|
@@ -341,7 +341,7 @@ class RecordDict(TypedDict[str, RecordType], InflatableObject):
|
|
|
341
341
|
|
|
342
342
|
# Ensure children are one of the *Record objects exepecte in a RecordDict
|
|
343
343
|
if not all(
|
|
344
|
-
isinstance(ch, (ArrayRecord
|
|
344
|
+
isinstance(ch, (ArrayRecord | ConfigRecord | MetricRecord))
|
|
345
345
|
for ch in children.values()
|
|
346
346
|
):
|
|
347
347
|
raise ValueError(
|
flwr/common/record/typeddict.py
CHANGED
|
@@ -15,8 +15,15 @@
|
|
|
15
15
|
"""Typed dict base class for *Records."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import
|
|
19
|
-
|
|
18
|
+
from collections.abc import (
|
|
19
|
+
Callable,
|
|
20
|
+
ItemsView,
|
|
21
|
+
Iterator,
|
|
22
|
+
KeysView,
|
|
23
|
+
MutableMapping,
|
|
24
|
+
ValuesView,
|
|
25
|
+
)
|
|
26
|
+
from typing import Generic, TypeVar, cast
|
|
20
27
|
|
|
21
28
|
from typing_extensions import Self
|
|
22
29
|
|