flwr 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/__init__.py +1 -1
- flwr/cli/app.py +21 -2
- flwr/cli/build.py +83 -58
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +53 -17
- flwr/cli/example.py +1 -1
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +4 -4
- flwr/cli/login/__init__.py +1 -1
- flwr/cli/login/login.py +15 -8
- flwr/cli/ls.py +16 -37
- flwr/cli/new/__init__.py +1 -1
- flwr/cli/new/new.py +4 -4
- flwr/cli/new/templates/__init__.py +1 -1
- flwr/cli/new/templates/app/__init__.py +1 -1
- flwr/cli/new/templates/app/code/__init__.py +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +4 -4
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -1
- flwr/cli/run/run.py +11 -19
- flwr/cli/stop.py +3 -3
- flwr/cli/utils.py +42 -17
- flwr/client/__init__.py +3 -3
- flwr/client/client.py +1 -1
- flwr/client/client_app.py +140 -138
- flwr/client/clientapp/__init__.py +1 -8
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +1 -1
- flwr/client/grpc_adapter_client/connection.py +5 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +131 -61
- flwr/client/grpc_rere_client/grpc_adapter.py +35 -7
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/centraldp_mods.py +1 -1
- flwr/client/mod/comms_mods.py +39 -20
- flwr/client/mod/localdp_mod.py +6 -6
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +174 -68
- flwr/client/run_info_store.py +1 -1
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +3 -3
- flwr/common/address.py +1 -1
- flwr/common/args.py +1 -1
- flwr/common/auth_plugin/__init__.py +3 -1
- flwr/common/auth_plugin/auth_plugin.py +30 -4
- flwr/common/config.py +1 -1
- flwr/common/constant.py +37 -8
- flwr/common/context.py +1 -1
- flwr/common/date.py +1 -1
- flwr/common/differential_privacy.py +1 -1
- flwr/common/differential_privacy_constants.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit.py +6 -6
- flwr/common/exit_handlers.py +31 -1
- flwr/common/grpc.py +1 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/logger.py +1 -1
- flwr/common/message.py +137 -252
- flwr/common/object_ref.py +1 -1
- flwr/common/parameter.py +1 -1
- flwr/common/pyproject.py +1 -1
- flwr/common/record/__init__.py +3 -2
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +121 -243
- flwr/common/record/configrecord.py +71 -16
- flwr/common/record/conversion_utils.py +2 -2
- flwr/common/record/metricrecord.py +71 -20
- flwr/common/record/recorddict.py +207 -90
- flwr/common/record/typeddict.py +1 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +15 -11
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +52 -30
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +60 -184
- flwr/common/serde_utils.py +175 -0
- flwr/common/telemetry.py +2 -2
- flwr/common/typing.py +6 -4
- flwr/common/version.py +1 -1
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +71 -211
- flwr/{client → compat/client}/grpc_client/__init__.py +1 -1
- flwr/{client → compat/client}/grpc_client/connection.py +13 -13
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/__init__.py +1 -1
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +2 -2
- flwr/server/app.py +69 -187
- flwr/server/client_manager.py +1 -1
- flwr/server/client_proxy.py +1 -1
- flwr/server/compat/__init__.py +1 -1
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +51 -29
- flwr/server/compat/legacy_context.py +1 -1
- flwr/server/criterion.py +1 -1
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/history.py +1 -1
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +1 -1
- flwr/server/server_app.py +65 -58
- flwr/server/server_config.py +1 -1
- flwr/server/serverapp/__init__.py +1 -1
- flwr/server/serverapp/app.py +19 -1
- flwr/server/serverapp_components.py +1 -1
- flwr/server/strategy/__init__.py +1 -1
- flwr/server/strategy/aggregate.py +1 -1
- flwr/server/strategy/bulyan.py +2 -2
- flwr/server/strategy/dp_adaptive_clipping.py +17 -17
- flwr/server/strategy/dp_fixed_clipping.py +17 -17
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedtrimmedavg.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +3 -2
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/strategy/strategy.py +1 -1
- flwr/server/superlink/__init__.py +1 -1
- flwr/server/superlink/ffs/__init__.py +3 -1
- flwr/server/superlink/ffs/disk_ffs.py +1 -1
- flwr/server/superlink/ffs/ffs.py +1 -1
- flwr/server/superlink/ffs/ffs_factory.py +1 -1
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +14 -4
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +136 -19
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -12
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +7 -4
- flwr/server/superlink/linkstate/__init__.py +1 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +139 -44
- flwr/server/superlink/linkstate/linkstate.py +54 -21
- flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +150 -56
- flwr/server/superlink/linkstate/utils.py +34 -30
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/__init__.py +1 -1
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +45 -3
- flwr/server/typing.py +1 -1
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +3 -3
- flwr/server/workflow/__init__.py +1 -1
- flwr/server/workflow/constant.py +1 -1
- flwr/server/workflow/default_workflows.py +1 -1
- flwr/server/workflow/secure_aggregation/__init__.py +1 -1
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +18 -1
- flwr/simulation/legacy_app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/ray_transport/utils.py +1 -1
- flwr/simulation/run_simulation.py +2 -2
- flwr/simulation/simulationio_connection.py +1 -1
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/__init__.py +1 -1
- flwr/superexec/app.py +1 -1
- flwr/superexec/deployment.py +7 -3
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +8 -4
- flwr/superexec/exec_servicer.py +126 -24
- flwr/superexec/exec_user_auth_interceptor.py +38 -9
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -2
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +1 -8
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +8 -15
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +4 -13
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/{client → supernode}/nodestate/__init__.py +1 -1
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/{client → supernode}/nodestate/nodestate_factory.py +1 -1
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +26 -57
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/METADATA +6 -5
- flwr-1.19.0.dist-info/RECORD +365 -0
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- flwr-1.17.0.dist-info/LICENSE +0 -202
- flwr-1.17.0.dist-info/RECORD +0 -333
|
@@ -12,38 +12,31 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""ArrayRecord
|
|
15
|
+
"""ArrayRecord."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
20
|
import gc
|
|
21
|
+
import json
|
|
21
22
|
import sys
|
|
22
23
|
from collections import OrderedDict
|
|
23
|
-
from dataclasses import dataclass
|
|
24
|
-
from io import BytesIO
|
|
25
24
|
from logging import WARN
|
|
26
25
|
from typing import TYPE_CHECKING, Any, cast, overload
|
|
27
26
|
|
|
28
27
|
import numpy as np
|
|
29
28
|
|
|
30
|
-
from ..constant import GC_THRESHOLD
|
|
29
|
+
from ..constant import GC_THRESHOLD
|
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
|
31
31
|
from ..logger import log
|
|
32
32
|
from ..typing import NDArray
|
|
33
|
+
from .array import Array
|
|
33
34
|
from .typeddict import TypedDict
|
|
34
35
|
|
|
35
36
|
if TYPE_CHECKING:
|
|
36
37
|
import torch
|
|
37
38
|
|
|
38
39
|
|
|
39
|
-
def _raise_array_init_error() -> None:
|
|
40
|
-
raise TypeError(
|
|
41
|
-
f"Invalid arguments for {Array.__qualname__}. Expected either a "
|
|
42
|
-
"PyTorch tensor, a NumPy ndarray, or explicit"
|
|
43
|
-
" dtype/shape/stype/data values."
|
|
44
|
-
)
|
|
45
|
-
|
|
46
|
-
|
|
47
40
|
def _raise_array_record_init_error() -> None:
|
|
48
41
|
raise TypeError(
|
|
49
42
|
f"Invalid arguments for {ArrayRecord.__qualname__}. Expected either "
|
|
@@ -52,217 +45,6 @@ def _raise_array_record_init_error() -> None:
|
|
|
52
45
|
)
|
|
53
46
|
|
|
54
47
|
|
|
55
|
-
@dataclass
|
|
56
|
-
class Array:
|
|
57
|
-
"""Array type.
|
|
58
|
-
|
|
59
|
-
A dataclass containing serialized data from an array-like or tensor-like object
|
|
60
|
-
along with metadata about it. The class can be initialized in one of three ways:
|
|
61
|
-
|
|
62
|
-
1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
|
|
63
|
-
2. By providing a NumPy ndarray (via the `ndarray` argument).
|
|
64
|
-
3. By providing a PyTorch tensor (via the `torch_tensor` argument).
|
|
65
|
-
|
|
66
|
-
In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
|
|
67
|
-
derived from the input. In scenario (1), these fields must be specified manually.
|
|
68
|
-
|
|
69
|
-
Parameters
|
|
70
|
-
----------
|
|
71
|
-
dtype : Optional[str] (default: None)
|
|
72
|
-
A string representing the data type of the serialized object (e.g. `"float32"`).
|
|
73
|
-
Only required if you are not passing in a ndarray or a tensor.
|
|
74
|
-
|
|
75
|
-
shape : Optional[list[int]] (default: None)
|
|
76
|
-
A list representing the shape of the unserialized array-like object. Only
|
|
77
|
-
required if you are not passing in a ndarray or a tensor.
|
|
78
|
-
|
|
79
|
-
stype : Optional[str] (default: None)
|
|
80
|
-
A string indicating the serialization mechanism used to generate the bytes in
|
|
81
|
-
`data` from an array-like or tensor-like object. Only required if you are not
|
|
82
|
-
passing in a ndarray or a tensor.
|
|
83
|
-
|
|
84
|
-
data : Optional[bytes] (default: None)
|
|
85
|
-
A buffer of bytes containing the data. Only required if you are not passing in
|
|
86
|
-
a ndarray or a tensor.
|
|
87
|
-
|
|
88
|
-
ndarray : Optional[NDArray] (default: None)
|
|
89
|
-
A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
|
|
90
|
-
fields are derived automatically from it.
|
|
91
|
-
|
|
92
|
-
torch_tensor : Optional[torch.Tensor] (default: None)
|
|
93
|
-
A PyTorch tensor. If provided, it will be **detached and moved to CPU**
|
|
94
|
-
before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
|
|
95
|
-
will be derived automatically from it.
|
|
96
|
-
|
|
97
|
-
Examples
|
|
98
|
-
--------
|
|
99
|
-
Initializing by specifying all fields directly:
|
|
100
|
-
|
|
101
|
-
>>> arr1 = Array(
|
|
102
|
-
>>> dtype="float32",
|
|
103
|
-
>>> shape=[3, 3],
|
|
104
|
-
>>> stype="numpy.ndarray",
|
|
105
|
-
>>> data=b"serialized_data...",
|
|
106
|
-
>>> )
|
|
107
|
-
|
|
108
|
-
Initializing with a NumPy ndarray:
|
|
109
|
-
|
|
110
|
-
>>> import numpy as np
|
|
111
|
-
>>> arr2 = Array(np.random.randn(3, 3))
|
|
112
|
-
|
|
113
|
-
Initializing with a PyTorch tensor:
|
|
114
|
-
|
|
115
|
-
>>> import torch
|
|
116
|
-
>>> arr3 = Array(torch.randn(3, 3))
|
|
117
|
-
"""
|
|
118
|
-
|
|
119
|
-
dtype: str
|
|
120
|
-
shape: list[int]
|
|
121
|
-
stype: str
|
|
122
|
-
data: bytes
|
|
123
|
-
|
|
124
|
-
@overload
|
|
125
|
-
def __init__( # noqa: E704
|
|
126
|
-
self, dtype: str, shape: list[int], stype: str, data: bytes
|
|
127
|
-
) -> None: ...
|
|
128
|
-
|
|
129
|
-
@overload
|
|
130
|
-
def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704
|
|
131
|
-
|
|
132
|
-
@overload
|
|
133
|
-
def __init__(self, torch_tensor: torch.Tensor) -> None: ... # noqa: E704
|
|
134
|
-
|
|
135
|
-
def __init__( # pylint: disable=too-many-arguments, too-many-locals
|
|
136
|
-
self,
|
|
137
|
-
*args: Any,
|
|
138
|
-
dtype: str | None = None,
|
|
139
|
-
shape: list[int] | None = None,
|
|
140
|
-
stype: str | None = None,
|
|
141
|
-
data: bytes | None = None,
|
|
142
|
-
ndarray: NDArray | None = None,
|
|
143
|
-
torch_tensor: torch.Tensor | None = None,
|
|
144
|
-
) -> None:
|
|
145
|
-
# Determine the initialization method and validate input arguments.
|
|
146
|
-
# Support three initialization formats:
|
|
147
|
-
# 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
|
|
148
|
-
# 2. Array(ndarray: NDArray)
|
|
149
|
-
# 3. Array(torch_tensor: torch.Tensor)
|
|
150
|
-
|
|
151
|
-
# Initialize all arguments
|
|
152
|
-
# If more than 4 positional arguments are provided, raise an error.
|
|
153
|
-
if len(args) > 4:
|
|
154
|
-
_raise_array_init_error()
|
|
155
|
-
all_args = [None] * 4
|
|
156
|
-
for i, arg in enumerate(args):
|
|
157
|
-
all_args[i] = arg
|
|
158
|
-
init_method: str | None = None # Track which init method is being used
|
|
159
|
-
|
|
160
|
-
# Try to assign a value to all_args[index] if it's not already set.
|
|
161
|
-
# If an initialization method is provided, update init_method.
|
|
162
|
-
def _try_set_arg(index: int, arg: Any, method: str) -> None:
|
|
163
|
-
# Skip if arg is None
|
|
164
|
-
if arg is None:
|
|
165
|
-
return
|
|
166
|
-
# Raise an error if all_args[index] is already set
|
|
167
|
-
if all_args[index] is not None:
|
|
168
|
-
_raise_array_init_error()
|
|
169
|
-
# Raise an error if a different initialization method is already set
|
|
170
|
-
nonlocal init_method
|
|
171
|
-
if init_method is not None and init_method != method:
|
|
172
|
-
_raise_array_init_error()
|
|
173
|
-
# Set init_method and all_args[index]
|
|
174
|
-
if init_method is None:
|
|
175
|
-
init_method = method
|
|
176
|
-
all_args[index] = arg
|
|
177
|
-
|
|
178
|
-
# Try to set keyword arguments in all_args
|
|
179
|
-
_try_set_arg(0, dtype, "direct")
|
|
180
|
-
_try_set_arg(1, shape, "direct")
|
|
181
|
-
_try_set_arg(2, stype, "direct")
|
|
182
|
-
_try_set_arg(3, data, "direct")
|
|
183
|
-
_try_set_arg(0, ndarray, "ndarray")
|
|
184
|
-
_try_set_arg(0, torch_tensor, "torch_tensor")
|
|
185
|
-
|
|
186
|
-
# Check if all arguments are correctly set
|
|
187
|
-
all_args = [arg for arg in all_args if arg is not None]
|
|
188
|
-
|
|
189
|
-
# Handle direct field initialization
|
|
190
|
-
if not init_method or init_method == "direct":
|
|
191
|
-
if (
|
|
192
|
-
len(all_args) == 4 # pylint: disable=too-many-boolean-expressions
|
|
193
|
-
and isinstance(all_args[0], str)
|
|
194
|
-
and isinstance(all_args[1], list)
|
|
195
|
-
and all(isinstance(i, int) for i in all_args[1])
|
|
196
|
-
and isinstance(all_args[2], str)
|
|
197
|
-
and isinstance(all_args[3], bytes)
|
|
198
|
-
):
|
|
199
|
-
self.dtype, self.shape, self.stype, self.data = all_args
|
|
200
|
-
return
|
|
201
|
-
|
|
202
|
-
# Handle NumPy array
|
|
203
|
-
if not init_method or init_method == "ndarray":
|
|
204
|
-
if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
|
|
205
|
-
self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
|
|
206
|
-
return
|
|
207
|
-
|
|
208
|
-
# Handle PyTorch tensor
|
|
209
|
-
if not init_method or init_method == "torch_tensor":
|
|
210
|
-
if (
|
|
211
|
-
len(all_args) == 1
|
|
212
|
-
and "torch" in sys.modules
|
|
213
|
-
and isinstance(all_args[0], sys.modules["torch"].Tensor)
|
|
214
|
-
):
|
|
215
|
-
self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
|
|
216
|
-
return
|
|
217
|
-
|
|
218
|
-
_raise_array_init_error()
|
|
219
|
-
|
|
220
|
-
@classmethod
|
|
221
|
-
def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
|
|
222
|
-
"""Create Array from NumPy ndarray."""
|
|
223
|
-
assert isinstance(
|
|
224
|
-
ndarray, np.ndarray
|
|
225
|
-
), f"Expected NumPy ndarray, got {type(ndarray)}"
|
|
226
|
-
buffer = BytesIO()
|
|
227
|
-
# WARNING: NEVER set allow_pickle to true.
|
|
228
|
-
# Reason: loading pickled data can execute arbitrary code
|
|
229
|
-
# Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
|
|
230
|
-
np.save(buffer, ndarray, allow_pickle=False)
|
|
231
|
-
data = buffer.getvalue()
|
|
232
|
-
return Array(
|
|
233
|
-
dtype=str(ndarray.dtype),
|
|
234
|
-
shape=list(ndarray.shape),
|
|
235
|
-
stype=SType.NUMPY,
|
|
236
|
-
data=data,
|
|
237
|
-
)
|
|
238
|
-
|
|
239
|
-
@classmethod
|
|
240
|
-
def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
|
|
241
|
-
"""Create Array from PyTorch tensor."""
|
|
242
|
-
if not (torch := sys.modules.get("torch")):
|
|
243
|
-
raise RuntimeError(
|
|
244
|
-
f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
|
|
245
|
-
)
|
|
246
|
-
|
|
247
|
-
assert isinstance(
|
|
248
|
-
tensor, torch.Tensor
|
|
249
|
-
), f"Expected PyTorch Tensor, got {type(tensor)}"
|
|
250
|
-
return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())
|
|
251
|
-
|
|
252
|
-
def numpy(self) -> NDArray:
|
|
253
|
-
"""Return the array as a NumPy array."""
|
|
254
|
-
if self.stype != SType.NUMPY:
|
|
255
|
-
raise TypeError(
|
|
256
|
-
f"Unsupported serialization type for numpy conversion: '{self.stype}'"
|
|
257
|
-
)
|
|
258
|
-
bytes_io = BytesIO(self.data)
|
|
259
|
-
# WARNING: NEVER set allow_pickle to true.
|
|
260
|
-
# Reason: loading pickled data can execute arbitrary code
|
|
261
|
-
# Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
|
|
262
|
-
ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
|
|
263
|
-
return cast(NDArray, ndarray_deserialized)
|
|
264
|
-
|
|
265
|
-
|
|
266
48
|
def _check_key(key: str) -> None:
|
|
267
49
|
"""Check if key is of expected type."""
|
|
268
50
|
if not isinstance(key, str):
|
|
@@ -276,7 +58,7 @@ def _check_value(value: Array) -> None:
|
|
|
276
58
|
)
|
|
277
59
|
|
|
278
60
|
|
|
279
|
-
class ArrayRecord(TypedDict[str, Array]):
|
|
61
|
+
class ArrayRecord(TypedDict[str, Array], InflatableObject):
|
|
280
62
|
"""Array record.
|
|
281
63
|
|
|
282
64
|
A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
|
|
@@ -315,33 +97,33 @@ class ArrayRecord(TypedDict[str, Array]):
|
|
|
315
97
|
|
|
316
98
|
Examples
|
|
317
99
|
--------
|
|
318
|
-
Initializing an empty ArrayRecord
|
|
100
|
+
Initializing an empty ArrayRecord::
|
|
319
101
|
|
|
320
|
-
|
|
102
|
+
record = ArrayRecord()
|
|
321
103
|
|
|
322
|
-
Initializing with a dictionary of :class:`Array
|
|
104
|
+
Initializing with a dictionary of :class:`Array`::
|
|
323
105
|
|
|
324
|
-
|
|
325
|
-
|
|
106
|
+
arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
|
|
107
|
+
record = ArrayRecord({"weight": arr})
|
|
326
108
|
|
|
327
|
-
Initializing with a list of NumPy arrays
|
|
109
|
+
Initializing with a list of NumPy arrays::
|
|
328
110
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
111
|
+
import numpy as np
|
|
112
|
+
arr1 = np.random.randn(3, 3)
|
|
113
|
+
arr2 = np.random.randn(2, 2)
|
|
114
|
+
record = ArrayRecord([arr1, arr2])
|
|
333
115
|
|
|
334
|
-
Initializing with a PyTorch model state_dict
|
|
116
|
+
Initializing with a PyTorch model state_dict::
|
|
335
117
|
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
118
|
+
import torch.nn as nn
|
|
119
|
+
model = nn.Linear(10, 5)
|
|
120
|
+
record = ArrayRecord(model.state_dict())
|
|
339
121
|
|
|
340
|
-
Initializing with a TensorFlow model weights (a list of NumPy arrays)
|
|
122
|
+
Initializing with a TensorFlow model weights (a list of NumPy arrays)::
|
|
341
123
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
124
|
+
import tensorflow as tf
|
|
125
|
+
model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
|
|
126
|
+
record = ArrayRecord(model.get_weights())
|
|
345
127
|
"""
|
|
346
128
|
|
|
347
129
|
@overload
|
|
@@ -470,7 +252,7 @@ class ArrayRecord(TypedDict[str, Array]):
|
|
|
470
252
|
record = ArrayRecord()
|
|
471
253
|
for k, v in array_dict.items():
|
|
472
254
|
record[k] = Array(
|
|
473
|
-
dtype=v.dtype, shape=
|
|
255
|
+
dtype=v.dtype, shape=tuple(v.shape), stype=v.stype, data=v.data
|
|
474
256
|
)
|
|
475
257
|
if not keep_input:
|
|
476
258
|
array_dict.clear()
|
|
@@ -585,6 +367,102 @@ class ArrayRecord(TypedDict[str, Array]):
|
|
|
585
367
|
|
|
586
368
|
return num_bytes
|
|
587
369
|
|
|
370
|
+
@property
|
|
371
|
+
def children(self) -> dict[str, InflatableObject]:
|
|
372
|
+
"""Return a dictionary of Arrays with their Object IDs as keys."""
|
|
373
|
+
return {arr.object_id: arr for arr in self.values()}
|
|
374
|
+
|
|
375
|
+
def deflate(self) -> bytes:
|
|
376
|
+
"""Deflate the ArrayRecord."""
|
|
377
|
+
# array_name: array_object_id mapping
|
|
378
|
+
array_refs: dict[str, str] = {}
|
|
379
|
+
|
|
380
|
+
for array_name, array in self.items():
|
|
381
|
+
array_refs[array_name] = array.object_id
|
|
382
|
+
|
|
383
|
+
# Serialize references dict
|
|
384
|
+
object_body = json.dumps(array_refs).encode("utf-8")
|
|
385
|
+
return add_header_to_object_body(object_body=object_body, obj=self)
|
|
386
|
+
|
|
387
|
+
@classmethod
|
|
388
|
+
def inflate(
|
|
389
|
+
cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
|
|
390
|
+
) -> ArrayRecord:
|
|
391
|
+
"""Inflate an ArrayRecord from bytes.
|
|
392
|
+
|
|
393
|
+
Parameters
|
|
394
|
+
----------
|
|
395
|
+
object_content : bytes
|
|
396
|
+
The deflated object content of the ArrayRecord.
|
|
397
|
+
children : Optional[dict[str, InflatableObject]] (default: None)
|
|
398
|
+
Dictionary of children InflatableObjects mapped to their Object IDs.
|
|
399
|
+
These children enable the full inflation of the ArrayRecord.
|
|
400
|
+
|
|
401
|
+
Returns
|
|
402
|
+
-------
|
|
403
|
+
ArrayRecord
|
|
404
|
+
The inflated ArrayRecord.
|
|
405
|
+
"""
|
|
406
|
+
if children is None:
|
|
407
|
+
children = {}
|
|
408
|
+
|
|
409
|
+
# Inflate mapping of array_names (keys in the ArrayRecord) to Arrays' object IDs
|
|
410
|
+
obj_body = get_object_body(object_content, cls)
|
|
411
|
+
array_refs: dict[str, str] = json.loads(obj_body.decode(encoding="utf-8"))
|
|
412
|
+
|
|
413
|
+
unique_arrays = set(array_refs.values())
|
|
414
|
+
children_obj_ids = set(children.keys())
|
|
415
|
+
if unique_arrays != children_obj_ids:
|
|
416
|
+
raise ValueError(
|
|
417
|
+
"Unexpected set of `children`. "
|
|
418
|
+
f"Expected {unique_arrays} but got {children_obj_ids}."
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
# Ensure children are of type Array
|
|
422
|
+
if not all(isinstance(arr, Array) for arr in children.values()):
|
|
423
|
+
raise ValueError("`Children` are expected to be of type `Array`.")
|
|
424
|
+
|
|
425
|
+
# Instantiate new ArrayRecord
|
|
426
|
+
return ArrayRecord(
|
|
427
|
+
OrderedDict(
|
|
428
|
+
{name: children[object_id] for name, object_id in array_refs.items()}
|
|
429
|
+
)
|
|
430
|
+
)
|
|
431
|
+
|
|
432
|
+
@property
|
|
433
|
+
def object_id(self) -> str:
|
|
434
|
+
"""Get object ID."""
|
|
435
|
+
ret = super().object_id
|
|
436
|
+
self.is_dirty = False # Reset dirty flag
|
|
437
|
+
return ret
|
|
438
|
+
|
|
439
|
+
@property
|
|
440
|
+
def is_dirty(self) -> bool:
|
|
441
|
+
"""Check if the object is dirty after the last deflation."""
|
|
442
|
+
if "_is_dirty" not in self.__dict__:
|
|
443
|
+
self.__dict__["_is_dirty"] = True
|
|
444
|
+
|
|
445
|
+
if not self.__dict__["_is_dirty"]:
|
|
446
|
+
if any(v.is_dirty for v in self.values()):
|
|
447
|
+
# If any Array is dirty, mark the record as dirty
|
|
448
|
+
self.__dict__["_is_dirty"] = True
|
|
449
|
+
return cast(bool, self.__dict__["_is_dirty"])
|
|
450
|
+
|
|
451
|
+
@is_dirty.setter
|
|
452
|
+
def is_dirty(self, value: bool) -> None:
|
|
453
|
+
"""Set the dirty flag."""
|
|
454
|
+
self.__dict__["_is_dirty"] = value
|
|
455
|
+
|
|
456
|
+
def __setitem__(self, key: str, value: Array) -> None:
|
|
457
|
+
"""Set item and mark the record as dirty."""
|
|
458
|
+
self.is_dirty = True # Mark as dirty when setting an item
|
|
459
|
+
super().__setitem__(key, value)
|
|
460
|
+
|
|
461
|
+
def __delitem__(self, key: str) -> None:
|
|
462
|
+
"""Delete item and mark the record as dirty."""
|
|
463
|
+
self.is_dirty = True # Mark as dirty when deleting an item
|
|
464
|
+
super().__delitem__(key)
|
|
465
|
+
|
|
588
466
|
|
|
589
467
|
class ParametersRecord(ArrayRecord):
|
|
590
468
|
"""Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.
|
|
@@ -15,12 +15,21 @@
|
|
|
15
15
|
"""ConfigRecord."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
from logging import WARN
|
|
19
|
-
from typing import
|
|
21
|
+
from typing import cast, get_args
|
|
20
22
|
|
|
21
23
|
from flwr.common.typing import ConfigRecordValues, ConfigScalar
|
|
22
24
|
|
|
25
|
+
# pylint: disable=E0611
|
|
26
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
27
|
+
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
|
|
28
|
+
|
|
29
|
+
# pylint: enable=E0611
|
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
|
23
31
|
from ..logger import log
|
|
32
|
+
from ..serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
|
24
33
|
from .typeddict import TypedDict
|
|
25
34
|
|
|
26
35
|
|
|
@@ -59,8 +68,8 @@ def _check_value(value: ConfigRecordValues) -> None:
|
|
|
59
68
|
is_valid(value)
|
|
60
69
|
|
|
61
70
|
|
|
62
|
-
class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
63
|
-
"""
|
|
71
|
+
class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
|
|
72
|
+
"""Config record.
|
|
64
73
|
|
|
65
74
|
A :code:`ConfigRecord` is a Python dictionary designed to ensure that
|
|
66
75
|
each key-value pair adheres to specified data types. A :code:`ConfigRecord`
|
|
@@ -90,18 +99,18 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
|
90
99
|
encourage you to use a :code:`ArrayRecord` instead if these are of high
|
|
91
100
|
dimensionality.
|
|
92
101
|
|
|
93
|
-
Let's see some examples of how to construct a :code:`ConfigRecord` from scratch
|
|
102
|
+
Let's see some examples of how to construct a :code:`ConfigRecord` from scratch::
|
|
103
|
+
|
|
104
|
+
from flwr.common import ConfigRecord
|
|
94
105
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
>>> # And string values (among other types)
|
|
104
|
-
>>> record["path-to-S3"] = "s3://bucket_name/folder1/fileA.json"
|
|
106
|
+
# A `ConfigRecord` is a specialized Python dictionary
|
|
107
|
+
record = ConfigRecord({"lr": 0.1, "batch-size": 128})
|
|
108
|
+
# You can add more content to an existing record
|
|
109
|
+
record["compute-average"] = True
|
|
110
|
+
# It also supports lists
|
|
111
|
+
record["loss-fn-coefficients"] = [0.4, 0.25, 0.35]
|
|
112
|
+
# And string values (among other types)
|
|
113
|
+
record["path-to-S3"] = "s3://bucket_name/folder1/fileA.json"
|
|
105
114
|
|
|
106
115
|
Just like the other types of records in a :code:`flwr.common.RecordDict`, types are
|
|
107
116
|
enforced. If you need to add a custom data structure or object, we recommend to
|
|
@@ -111,7 +120,7 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
|
111
120
|
|
|
112
121
|
def __init__(
|
|
113
122
|
self,
|
|
114
|
-
config_dict:
|
|
123
|
+
config_dict: dict[str, ConfigRecordValues] | None = None,
|
|
115
124
|
keep_input: bool = True,
|
|
116
125
|
) -> None:
|
|
117
126
|
|
|
@@ -164,6 +173,52 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
|
164
173
|
|
|
165
174
|
return num_bytes
|
|
166
175
|
|
|
176
|
+
def deflate(self) -> bytes:
|
|
177
|
+
"""Deflate object."""
|
|
178
|
+
protos = record_value_dict_to_proto(
|
|
179
|
+
self,
|
|
180
|
+
[bool, int, float, str, bytes],
|
|
181
|
+
ProtoConfigRecordValue,
|
|
182
|
+
)
|
|
183
|
+
obj_body = ProtoConfigRecord(
|
|
184
|
+
items=[ProtoConfigRecord.Item(key=k, value=v) for k, v in protos.items()]
|
|
185
|
+
).SerializeToString()
|
|
186
|
+
return add_header_to_object_body(object_body=obj_body, obj=self)
|
|
187
|
+
|
|
188
|
+
@classmethod
|
|
189
|
+
def inflate(
|
|
190
|
+
cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
|
|
191
|
+
) -> ConfigRecord:
|
|
192
|
+
"""Inflate a ConfigRecord from bytes.
|
|
193
|
+
|
|
194
|
+
Parameters
|
|
195
|
+
----------
|
|
196
|
+
object_content : bytes
|
|
197
|
+
The deflated object content of the ConfigRecord.
|
|
198
|
+
|
|
199
|
+
children : Optional[dict[str, InflatableObject]] (default: None)
|
|
200
|
+
Must be ``None``. ``ConfigRecord`` does not support child objects.
|
|
201
|
+
Providing any children will raise a ``ValueError``.
|
|
202
|
+
|
|
203
|
+
Returns
|
|
204
|
+
-------
|
|
205
|
+
ConfigRecord
|
|
206
|
+
The inflated ConfigRecord.
|
|
207
|
+
"""
|
|
208
|
+
if children:
|
|
209
|
+
raise ValueError("`ConfigRecord` objects do not have children.")
|
|
210
|
+
|
|
211
|
+
obj_body = get_object_body(object_content, cls)
|
|
212
|
+
config_record_proto = ProtoConfigRecord.FromString(obj_body)
|
|
213
|
+
protos = {item.key: item.value for item in config_record_proto.items}
|
|
214
|
+
return ConfigRecord(
|
|
215
|
+
config_dict=cast(
|
|
216
|
+
dict[str, ConfigRecordValues],
|
|
217
|
+
record_value_dict_from_proto(protos),
|
|
218
|
+
),
|
|
219
|
+
keep_input=False,
|
|
220
|
+
)
|
|
221
|
+
|
|
167
222
|
|
|
168
223
|
class ConfigsRecord(ConfigRecord):
|
|
169
224
|
"""Deprecated class ``ConfigsRecord``, use ``ConfigRecord`` instead.
|
|
@@ -195,7 +250,7 @@ class ConfigsRecord(ConfigRecord):
|
|
|
195
250
|
|
|
196
251
|
def __init__(
|
|
197
252
|
self,
|
|
198
|
-
config_dict:
|
|
253
|
+
config_dict: dict[str, ConfigRecordValues] | None = None,
|
|
199
254
|
keep_input: bool = True,
|
|
200
255
|
):
|
|
201
256
|
if not ConfigsRecord._warning_logged:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
from ..logger import warn_deprecated_feature
|
|
19
19
|
from ..typing import NDArray
|
|
20
|
-
from .
|
|
20
|
+
from .array import Array
|
|
21
21
|
|
|
22
22
|
WARN_DEPRECATED_MESSAGE = (
|
|
23
23
|
"`array_from_numpy` is deprecated. Instead, use the `Array(ndarray)` class "
|