flwr 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (248)
  1. flwr/__init__.py +1 -1
  2. flwr/cli/__init__.py +1 -1
  3. flwr/cli/app.py +21 -2
  4. flwr/cli/build.py +1 -1
  5. flwr/cli/cli_user_auth_interceptor.py +1 -1
  6. flwr/cli/config_utils.py +53 -17
  7. flwr/cli/example.py +1 -1
  8. flwr/cli/install.py +1 -1
  9. flwr/cli/log.py +1 -1
  10. flwr/cli/login/__init__.py +1 -1
  11. flwr/cli/login/login.py +12 -1
  12. flwr/cli/ls.py +1 -1
  13. flwr/cli/new/__init__.py +1 -1
  14. flwr/cli/new/new.py +4 -4
  15. flwr/cli/new/templates/__init__.py +1 -1
  16. flwr/cli/new/templates/app/__init__.py +1 -1
  17. flwr/cli/new/templates/app/code/__init__.py +1 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
  19. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +5 -5
  20. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  23. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  30. flwr/cli/run/__init__.py +1 -1
  31. flwr/cli/run/run.py +6 -10
  32. flwr/cli/stop.py +1 -1
  33. flwr/cli/utils.py +11 -12
  34. flwr/client/__init__.py +1 -1
  35. flwr/client/app.py +58 -56
  36. flwr/client/client.py +1 -1
  37. flwr/client/client_app.py +231 -166
  38. flwr/client/clientapp/__init__.py +1 -1
  39. flwr/client/clientapp/app.py +3 -3
  40. flwr/client/clientapp/clientappio_servicer.py +1 -1
  41. flwr/client/clientapp/utils.py +1 -1
  42. flwr/client/dpfedavg_numpy_client.py +1 -1
  43. flwr/client/grpc_adapter_client/__init__.py +1 -1
  44. flwr/client/grpc_adapter_client/connection.py +1 -1
  45. flwr/client/grpc_client/__init__.py +1 -1
  46. flwr/client/grpc_client/connection.py +37 -34
  47. flwr/client/grpc_rere_client/__init__.py +1 -1
  48. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  49. flwr/client/grpc_rere_client/connection.py +1 -1
  50. flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
  51. flwr/client/heartbeat.py +1 -1
  52. flwr/client/message_handler/__init__.py +1 -1
  53. flwr/client/message_handler/message_handler.py +28 -28
  54. flwr/client/mod/__init__.py +3 -3
  55. flwr/client/mod/centraldp_mods.py +8 -8
  56. flwr/client/mod/comms_mods.py +17 -23
  57. flwr/client/mod/localdp_mod.py +10 -10
  58. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  59. flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
  60. flwr/client/mod/secure_aggregation/secaggplus_mod.py +32 -32
  61. flwr/client/mod/utils.py +1 -1
  62. flwr/client/nodestate/__init__.py +1 -1
  63. flwr/client/nodestate/in_memory_nodestate.py +1 -1
  64. flwr/client/nodestate/nodestate.py +1 -1
  65. flwr/client/nodestate/nodestate_factory.py +1 -1
  66. flwr/client/numpy_client.py +1 -1
  67. flwr/client/rest_client/__init__.py +1 -1
  68. flwr/client/rest_client/connection.py +1 -1
  69. flwr/client/run_info_store.py +3 -3
  70. flwr/client/supernode/__init__.py +1 -1
  71. flwr/client/supernode/app.py +1 -1
  72. flwr/client/typing.py +1 -1
  73. flwr/common/__init__.py +13 -5
  74. flwr/common/address.py +1 -1
  75. flwr/common/args.py +1 -1
  76. flwr/common/auth_plugin/__init__.py +1 -1
  77. flwr/common/auth_plugin/auth_plugin.py +1 -1
  78. flwr/common/config.py +5 -5
  79. flwr/common/constant.py +7 -7
  80. flwr/common/context.py +5 -5
  81. flwr/common/date.py +1 -1
  82. flwr/common/differential_privacy.py +1 -1
  83. flwr/common/differential_privacy_constants.py +1 -1
  84. flwr/common/dp.py +1 -1
  85. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  86. flwr/common/exit/exit.py +6 -6
  87. flwr/common/exit_handlers.py +1 -1
  88. flwr/common/grpc.py +1 -1
  89. flwr/common/logger.py +3 -3
  90. flwr/common/message.py +344 -102
  91. flwr/common/object_ref.py +1 -1
  92. flwr/common/parameter.py +1 -1
  93. flwr/common/pyproject.py +1 -1
  94. flwr/common/record/__init__.py +9 -5
  95. flwr/common/record/arrayrecord.py +626 -0
  96. flwr/common/record/{configsrecord.py → configrecord.py} +83 -37
  97. flwr/common/record/conversion_utils.py +2 -2
  98. flwr/common/record/{metricsrecord.py → metricrecord.py} +90 -44
  99. flwr/common/record/recorddict.py +337 -0
  100. flwr/common/record/typeddict.py +1 -1
  101. flwr/common/recorddict_compat.py +410 -0
  102. flwr/common/retry_invoker.py +10 -10
  103. flwr/common/secure_aggregation/__init__.py +1 -1
  104. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  105. flwr/common/secure_aggregation/crypto/shamir.py +52 -30
  106. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  107. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  108. flwr/common/secure_aggregation/quantization.py +1 -1
  109. flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
  110. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  111. flwr/common/serde.py +67 -72
  112. flwr/common/telemetry.py +2 -2
  113. flwr/common/typing.py +9 -9
  114. flwr/common/version.py +1 -1
  115. flwr/proto/__init__.py +1 -1
  116. flwr/proto/exec_pb2.py +3 -3
  117. flwr/proto/exec_pb2.pyi +3 -3
  118. flwr/proto/message_pb2.py +12 -12
  119. flwr/proto/message_pb2.pyi +9 -9
  120. flwr/proto/recorddict_pb2.py +70 -0
  121. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  122. flwr/proto/run_pb2.py +31 -31
  123. flwr/proto/run_pb2.pyi +3 -3
  124. flwr/server/__init__.py +4 -2
  125. flwr/server/app.py +67 -12
  126. flwr/server/client_manager.py +1 -1
  127. flwr/server/client_proxy.py +1 -1
  128. flwr/server/compat/__init__.py +3 -3
  129. flwr/server/compat/app.py +12 -12
  130. flwr/server/compat/app_utils.py +17 -17
  131. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
  132. flwr/server/compat/legacy_context.py +1 -1
  133. flwr/server/criterion.py +1 -1
  134. flwr/server/fleet_event_log_interceptor.py +94 -0
  135. flwr/server/{driver → grid}/__init__.py +8 -7
  136. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  137. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
  138. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
  139. flwr/server/history.py +1 -1
  140. flwr/server/run_serverapp.py +5 -5
  141. flwr/server/server.py +1 -1
  142. flwr/server/server_app.py +98 -71
  143. flwr/server/server_config.py +1 -1
  144. flwr/server/serverapp/__init__.py +1 -1
  145. flwr/server/serverapp/app.py +11 -11
  146. flwr/server/serverapp_components.py +1 -1
  147. flwr/server/strategy/__init__.py +1 -1
  148. flwr/server/strategy/aggregate.py +1 -1
  149. flwr/server/strategy/bulyan.py +2 -2
  150. flwr/server/strategy/dp_adaptive_clipping.py +17 -17
  151. flwr/server/strategy/dp_fixed_clipping.py +17 -17
  152. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  153. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  154. flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
  155. flwr/server/strategy/fedadagrad.py +1 -1
  156. flwr/server/strategy/fedadam.py +1 -1
  157. flwr/server/strategy/fedavg.py +1 -1
  158. flwr/server/strategy/fedavg_android.py +1 -1
  159. flwr/server/strategy/fedavgm.py +1 -1
  160. flwr/server/strategy/fedmedian.py +1 -1
  161. flwr/server/strategy/fedopt.py +1 -1
  162. flwr/server/strategy/fedprox.py +1 -1
  163. flwr/server/strategy/fedtrimmedavg.py +1 -1
  164. flwr/server/strategy/fedxgb_bagging.py +1 -1
  165. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  166. flwr/server/strategy/fedxgb_nn_avg.py +3 -2
  167. flwr/server/strategy/fedyogi.py +1 -1
  168. flwr/server/strategy/krum.py +1 -1
  169. flwr/server/strategy/qfedavg.py +1 -1
  170. flwr/server/strategy/strategy.py +1 -1
  171. flwr/server/superlink/__init__.py +1 -1
  172. flwr/server/superlink/ffs/__init__.py +1 -1
  173. flwr/server/superlink/ffs/disk_ffs.py +1 -1
  174. flwr/server/superlink/ffs/ffs.py +1 -1
  175. flwr/server/superlink/ffs/ffs_factory.py +1 -1
  176. flwr/server/superlink/fleet/__init__.py +1 -1
  177. flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
  178. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
  179. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  180. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  181. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  182. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  183. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
  184. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  185. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  186. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
  187. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  188. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  189. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  190. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  191. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  192. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  193. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  194. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -3
  195. flwr/server/superlink/fleet/vce/vce_api.py +2 -4
  196. flwr/server/superlink/linkstate/__init__.py +1 -1
  197. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -9
  198. flwr/server/superlink/linkstate/linkstate.py +5 -5
  199. flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
  200. flwr/server/superlink/linkstate/sqlite_linkstate.py +62 -28
  201. flwr/server/superlink/linkstate/utils.py +94 -28
  202. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  203. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  204. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
  205. flwr/server/superlink/simulation/__init__.py +1 -1
  206. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  207. flwr/server/superlink/simulation/simulationio_servicer.py +3 -3
  208. flwr/server/superlink/utils.py +1 -1
  209. flwr/server/typing.py +4 -4
  210. flwr/server/utils/__init__.py +1 -1
  211. flwr/server/utils/tensorboard.py +1 -1
  212. flwr/server/utils/validator.py +5 -5
  213. flwr/server/workflow/__init__.py +1 -1
  214. flwr/server/workflow/constant.py +1 -1
  215. flwr/server/workflow/default_workflows.py +49 -58
  216. flwr/server/workflow/secure_aggregation/__init__.py +1 -1
  217. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
  218. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +49 -51
  219. flwr/simulation/__init__.py +1 -1
  220. flwr/simulation/app.py +3 -3
  221. flwr/simulation/legacy_app.py +1 -1
  222. flwr/simulation/ray_transport/__init__.py +1 -1
  223. flwr/simulation/ray_transport/ray_actor.py +5 -3
  224. flwr/simulation/ray_transport/ray_client_proxy.py +35 -33
  225. flwr/simulation/ray_transport/utils.py +1 -1
  226. flwr/simulation/run_simulation.py +17 -17
  227. flwr/simulation/simulationio_connection.py +1 -1
  228. flwr/superexec/__init__.py +1 -1
  229. flwr/superexec/app.py +1 -1
  230. flwr/superexec/deployment.py +5 -5
  231. flwr/superexec/exec_event_log_interceptor.py +135 -0
  232. flwr/superexec/exec_grpc.py +11 -5
  233. flwr/superexec/exec_servicer.py +3 -3
  234. flwr/superexec/exec_user_auth_interceptor.py +19 -3
  235. flwr/superexec/executor.py +4 -4
  236. flwr/superexec/simulation.py +4 -4
  237. {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/METADATA +3 -3
  238. flwr-1.18.0.dist-info/RECORD +332 -0
  239. flwr/common/record/parametersrecord.py +0 -339
  240. flwr/common/record/recordset.py +0 -209
  241. flwr/common/recordset_compat.py +0 -418
  242. flwr/proto/recordset_pb2.py +0 -70
  243. flwr-1.16.0.dist-info/LICENSE +0 -202
  244. flwr-1.16.0.dist-info/RECORD +0 -331
  245. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  246. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  247. {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/WHEEL +0 -0
  248. {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,626 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ArrayRecord and Array."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ import gc
21
+ import sys
22
+ from collections import OrderedDict
23
+ from dataclasses import dataclass
24
+ from io import BytesIO
25
+ from logging import WARN
26
+ from typing import TYPE_CHECKING, Any, cast, overload
27
+
28
+ import numpy as np
29
+
30
+ from ..constant import GC_THRESHOLD, SType
31
+ from ..logger import log
32
+ from ..typing import NDArray
33
+ from .typeddict import TypedDict
34
+
35
+ if TYPE_CHECKING:
36
+ import torch
37
+
38
+
39
def _raise_array_init_error() -> None:
    """Raise the ``TypeError`` reported when ``Array`` gets unusable arguments.

    This function never returns normally; callers use it as a terminal statement.
    """
    message = (
        f"Invalid arguments for {Array.__qualname__}. Expected either a "
        "PyTorch tensor, a NumPy ndarray, or explicit"
        " dtype/shape/stype/data values."
    )
    raise TypeError(message)
45
+
46
+
47
def _raise_array_record_init_error() -> None:
    """Raise the ``TypeError`` reported when ``ArrayRecord`` gets unusable arguments.

    This function never returns normally; callers use it as a terminal statement.
    """
    accepted_forms = (
        "a list of NumPy ndarrays, a PyTorch state_dict, or a dictionary of Arrays. "
        "The `keep_input` argument is keyword-only."
    )
    raise TypeError(
        f"Invalid arguments for {ArrayRecord.__qualname__}. Expected either "
        + accepted_forms
    )
53
+
54
+
55
@dataclass
class Array:
    """Serialized array plus the metadata needed to describe it.

    An ``Array`` stores the raw bytes produced from an array-like or tensor-like
    object together with a description of those bytes. It can be built in one of
    three ways:

    1. Passing explicit values for `dtype`, `shape`, `stype`, and `data`.
    2. Passing a NumPy ndarray (the `ndarray` argument).
    3. Passing a PyTorch tensor (the `torch_tensor` argument).

    For (2) and (3) all four fields are derived from the input. For (1) the caller
    must supply every field.

    Parameters
    ----------
    dtype : Optional[str] (default: None)
        Name of the element data type of the serialized object (e.g. `"float32"`).
        Only needed when neither an ndarray nor a tensor is given.

    shape : Optional[list[int]] (default: None)
        Shape of the unserialized array-like object. Only needed when neither an
        ndarray nor a tensor is given.

    stype : Optional[str] (default: None)
        Identifier of the serialization mechanism that turned the original object
        into the bytes in `data`. Only needed when neither an ndarray nor a tensor
        is given.

    data : Optional[bytes] (default: None)
        The serialized bytes. Only needed when neither an ndarray nor a tensor is
        given.

    ndarray : Optional[NDArray] (default: None)
        A NumPy ndarray from which `dtype`, `shape`, `stype`, and `data` are
        derived.

    torch_tensor : Optional[torch.Tensor] (default: None)
        A PyTorch tensor. It is **detached and moved to CPU**, converted to a NumPy
        ndarray, and then handled exactly like case (2).

    Examples
    --------
    Initializing by specifying all fields directly::

        arr1 = Array(
            dtype="float32",
            shape=[3, 3],
            stype="numpy.ndarray",
            data=b"serialized_data...",
        )

    Initializing with a NumPy ndarray::

        import numpy as np
        arr2 = Array(np.random.randn(3, 3))

    Initializing with a PyTorch tensor::

        import torch
        arr3 = Array(torch.randn(3, 3))
    """

    dtype: str
    shape: list[int]
    stype: str
    data: bytes

    @overload
    def __init__(  # noqa: E704
        self, dtype: str, shape: list[int], stype: str, data: bytes
    ) -> None: ...

    @overload
    def __init__(self, ndarray: NDArray) -> None: ...  # noqa: E704

    @overload
    def __init__(self, torch_tensor: torch.Tensor) -> None: ...  # noqa: E704

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        *args: Any,
        dtype: str | None = None,
        shape: list[int] | None = None,
        stype: str | None = None,
        data: bytes | None = None,
        ndarray: NDArray | None = None,
        torch_tensor: torch.Tensor | None = None,
    ) -> None:
        # Accepted call forms (positional and keyword arguments may be mixed
        # within one form):
        #   Array(dtype, shape, stype, data)
        #   Array(ndarray)
        #   Array(torch_tensor)
        if len(args) > 4:
            _raise_array_init_error()

        # Four slots pre-filled from the positional arguments; unused ones are None.
        slots: list[Any] = [*args, *([None] * (4 - len(args)))]
        chosen_method: str | None = None  # fixed by the first keyword argument seen

        # (slot index, keyword value, form it belongs to), checked in this order.
        keyword_inputs: tuple[tuple[int, Any, str], ...] = (
            (0, dtype, "direct"),
            (1, shape, "direct"),
            (2, stype, "direct"),
            (3, data, "direct"),
            (0, ndarray, "ndarray"),
            (0, torch_tensor, "torch_tensor"),
        )
        for slot, value, method in keyword_inputs:
            if value is None:
                continue
            # A slot may be filled only once (positionally OR by keyword), and
            # keyword arguments from different forms must not be combined.
            if slots[slot] is not None:
                _raise_array_init_error()
            if chosen_method not in (None, method):
                _raise_array_init_error()
            chosen_method = method
            slots[slot] = value

        # `None` placeholders are dropped before deciding which form matches.
        given = [value for value in slots if value is not None]

        # Form 1: explicit dtype / shape / stype / data.
        if chosen_method in (None, "direct") and len(given) == 4:
            d_type, d_shape, s_type, payload = given
            if (  # pylint: disable=too-many-boolean-expressions
                isinstance(d_type, str)
                and isinstance(d_shape, list)
                and all(isinstance(dim, int) for dim in d_shape)
                and isinstance(s_type, str)
                and isinstance(payload, bytes)
            ):
                self.dtype = d_type
                self.shape = d_shape
                self.stype = s_type
                self.data = payload
                return

        # Form 2: a single NumPy ndarray.
        if chosen_method in (None, "ndarray") and len(given) == 1:
            if isinstance(given[0], np.ndarray):
                self.__dict__.update(self.from_numpy_ndarray(given[0]).__dict__)
                return

        # Form 3: a single PyTorch tensor (only recognized if torch is already
        # imported somewhere in the process).
        if chosen_method in (None, "torch_tensor") and len(given) == 1:
            if "torch" in sys.modules and isinstance(
                given[0], sys.modules["torch"].Tensor
            ):
                self.__dict__.update(self.from_torch_tensor(given[0]).__dict__)
                return

        _raise_array_init_error()

    @classmethod
    def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
        """Serialize a NumPy ndarray into a new ``Array`` with ``SType.NUMPY``."""
        assert isinstance(
            ndarray, np.ndarray
        ), f"Expected NumPy ndarray, got {type(ndarray)}"
        # WARNING: pickling must stay disabled. Loading pickled data can execute
        # arbitrary code.
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        with BytesIO() as sink:
            np.save(sink, ndarray, allow_pickle=False)
            payload = sink.getvalue()
        return Array(
            dtype=str(ndarray.dtype),
            shape=list(ndarray.shape),
            stype=SType.NUMPY,
            data=payload,
        )

    @classmethod
    def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
        """Serialize a PyTorch tensor (detached, on CPU) into a new ``Array``."""
        torch_module = sys.modules.get("torch")
        if not torch_module:
            raise RuntimeError(
                f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
            )

        assert isinstance(
            tensor, torch_module.Tensor
        ), f"Expected PyTorch Tensor, got {type(tensor)}"
        as_numpy = tensor.detach().cpu().numpy()
        return cls.from_numpy_ndarray(as_numpy)

    def numpy(self) -> NDArray:
        """Deserialize the stored bytes and return them as a NumPy ndarray.

        Raises
        ------
        TypeError
            If `stype` is not the NumPy serialization type.
        """
        if self.stype != SType.NUMPY:
            raise TypeError(
                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
            )
        # WARNING: pickling must stay disabled. Loading pickled data can execute
        # arbitrary code.
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
        restored = np.load(BytesIO(self.data), allow_pickle=False)
        return cast(NDArray, restored)
264
+
265
+
266
+ def _check_key(key: str) -> None:
267
+ """Check if key is of expected type."""
268
+ if not isinstance(key, str):
269
+ raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")
270
+
271
+
272
def _check_value(value: Array) -> None:
    """Validate a record value.

    Raises
    ------
    TypeError
        If `value` is not an :class:`Array` instance.
    """
    if isinstance(value, Array):
        return
    raise TypeError(
        f"Value must be of type `{Array}` but `{type(value)}` was passed."
    )
277
+
278
+
279
class ArrayRecord(TypedDict[str, Array]):
    """Array record.

    A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
    including model parameters, gradients, embeddings or non-parameter arrays.
    Internally, this behaves similarly to an ``OrderedDict[str, Array]``.
    An ``ArrayRecord`` can be viewed as an equivalent to PyTorch's ``state_dict``,
    but it holds arrays in a serialized form.

    This object is one of the record types supported by :class:`RecordDict` and can
    therefore be stored in the ``content`` of a :class:`Message` or the ``state``
    of a :class:`Context`.

    This class can be instantiated in multiple ways:

    1. By providing nothing (empty container).
    2. By providing a dictionary of :class:`Array` (via the ``array_dict`` argument).
    3. By providing a list of NumPy ``ndarray`` (via the ``numpy_ndarrays`` argument).
    4. By providing a PyTorch ``state_dict`` (via the ``torch_state_dict`` argument).

    Parameters
    ----------
    array_dict : Optional[OrderedDict[str, Array]] (default: None)
        An existing dictionary containing named :class:`Array` instances. If
        provided, these entries will be used directly to populate the record.
    numpy_ndarrays : Optional[list[NDArray]] (default: None)
        A list of NumPy arrays. Each array will be automatically converted
        into an :class:`Array` and stored in this record with generated keys
        (``"0"``, ``"1"``, ...).
    torch_state_dict : Optional[OrderedDict[str, torch.Tensor]] (default: None)
        A PyTorch ``state_dict`` (``str`` keys to ``torch.Tensor`` values). Each
        tensor will be converted into an :class:`Array` and stored in this record.
    keep_input : bool (default: True)
        If ``False``, entries from the input are removed after being added to
        this record to free up memory. If ``True``, the input remains unchanged.
        Regardless of this value, no duplicate memory is used if the input is a
        dictionary of :class:`Array`, i.e., ``array_dict``.

    Examples
    --------
    Initializing an empty ArrayRecord::

        record = ArrayRecord()

    Initializing with a dictionary of :class:`Array`::

        arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
        record = ArrayRecord({"weight": arr})

    Initializing with a list of NumPy arrays::

        import numpy as np
        arr1 = np.random.randn(3, 3)
        arr2 = np.random.randn(2, 2)
        record = ArrayRecord([arr1, arr2])

    Initializing with a PyTorch model state_dict::

        import torch.nn as nn
        model = nn.Linear(10, 5)
        record = ArrayRecord(model.state_dict())

    Initializing with a TensorFlow model weights (a list of NumPy arrays)::

        import tensorflow as tf
        model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
        record = ArrayRecord(model.get_weights())
    """

    @overload
    def __init__(self) -> None: ...  # noqa: E704

    @overload
    def __init__(  # noqa: E704
        self, array_dict: OrderedDict[str, Array], *, keep_input: bool = True
    ) -> None: ...

    @overload
    def __init__(  # noqa: E704
        self, numpy_ndarrays: list[NDArray], *, keep_input: bool = True
    ) -> None: ...

    @overload
    def __init__(  # noqa: E704
        self,
        torch_state_dict: OrderedDict[str, torch.Tensor],
        *,
        keep_input: bool = True,
    ) -> None: ...

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *args: Any,
        numpy_ndarrays: list[NDArray] | None = None,
        torch_state_dict: OrderedDict[str, torch.Tensor] | None = None,
        array_dict: OrderedDict[str, Array] | None = None,
        keep_input: bool = True,
    ) -> None:
        super().__init__(_check_key, _check_value)

        # Determine the initialization method and validate input arguments.
        # Supported initialization formats:
        # 1. cls(array_dict: OrderedDict[str, Array], keep_input: bool)
        # 2. cls(numpy_ndarrays: list[NDArray], keep_input: bool)
        # 3. cls(torch_state_dict: dict[str, torch.Tensor], keep_input: bool)

        # At most one positional argument is accepted (`keep_input` is keyword-only)
        if len(args) > 1:
            _raise_array_record_init_error()
        arg = args[0] if args else None
        init_method: str | None = None  # Track which init method is being used

        # Try to assign a value to arg if it's not already set.
        # If an initialization method is provided, update init_method.
        def _try_set_arg(_arg: Any, method: str) -> None:
            # Skip if _arg is None
            if _arg is None:
                return
            nonlocal arg, init_method
            # Raise an error if arg is already set (positionally or by keyword)
            if arg is not None:
                _raise_array_record_init_error()
            # Raise an error if a different initialization method is already set
            if init_method is not None:
                _raise_array_record_init_error()
            # Set init_method and arg
            if init_method is None:
                init_method = method
            arg = _arg

        # Try to set keyword arguments
        _try_set_arg(array_dict, "array_dict")
        _try_set_arg(numpy_ndarrays, "numpy_ndarrays")
        _try_set_arg(torch_state_dict, "state_dict")

        # If no arguments are provided, return and keep self empty
        if arg is None:
            return

        # Handle dictionary of Arrays
        if not init_method or init_method == "array_dict":
            # Type check the input
            if (
                isinstance(arg, dict)
                and all(isinstance(k, str) for k in arg.keys())
                and all(isinstance(v, Array) for v in arg.values())
            ):
                array_dict = cast(OrderedDict[str, Array], arg)
                converted = self.from_array_dict(array_dict, keep_input=keep_input)
                self.__dict__.update(converted.__dict__)
                return

        # Handle NumPy ndarrays
        if not init_method or init_method == "numpy_ndarrays":
            # Type check the input
            # pylint: disable-next=not-an-iterable
            if isinstance(arg, list) and all(isinstance(v, np.ndarray) for v in arg):
                numpy_ndarrays = cast(list[NDArray], arg)
                converted = self.from_numpy_ndarrays(
                    numpy_ndarrays, keep_input=keep_input
                )
                self.__dict__.update(converted.__dict__)
                return

        # Handle PyTorch state_dict (only recognized if torch is already imported)
        if not init_method or init_method == "state_dict":
            # Type check the input
            if (
                (torch := sys.modules.get("torch")) is not None
                and isinstance(arg, dict)
                and all(isinstance(k, str) for k in arg.keys())
                and all(isinstance(v, torch.Tensor) for v in arg.values())
            ):
                torch_state_dict = cast(
                    OrderedDict[str, torch.Tensor], arg  # type: ignore
                )
                converted = self.from_torch_state_dict(
                    torch_state_dict, keep_input=keep_input
                )
                self.__dict__.update(converted.__dict__)
                return

        _raise_array_record_init_error()

    @classmethod
    def from_array_dict(
        cls,
        array_dict: OrderedDict[str, Array],
        *,
        keep_input: bool = True,
    ) -> ArrayRecord:
        """Create ArrayRecord from a dictionary of :class:`Array`.

        Each entry is copied into a new :class:`Array` that shares the same
        ``data`` bytes object and gets its own copy of the ``shape`` list.

        Parameters
        ----------
        array_dict : OrderedDict[str, Array]
            Named arrays to store in the new record.
        keep_input : bool (default: True)
            If ``False``, `array_dict` is cleared after the record is built.

        Returns
        -------
        ArrayRecord
            A new record containing the same keys as `array_dict`.
        """
        record = ArrayRecord()
        for k, v in array_dict.items():
            record[k] = Array(
                dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
            )
        if not keep_input:
            array_dict.clear()
        return record

    @classmethod
    def from_numpy_ndarrays(
        cls,
        ndarrays: list[NDArray],
        *,
        keep_input: bool = True,
    ) -> ArrayRecord:
        """Create ArrayRecord from a list of NumPy ``ndarray``.

        Parameters
        ----------
        ndarrays : list[NDArray]
            Arrays to serialize. The i-th array is stored under the key ``str(i)``.
        keep_input : bool (default: True)
            If ``False``, each list slot is set to ``None`` right after its array is
            serialized, garbage collection is triggered whenever more than
            ``GC_THRESHOLD`` serialized bytes have accumulated, and the list is
            cleared at the end.

        Returns
        -------
        ArrayRecord
            A new record holding one :class:`Array` per input array.
        """
        record = ArrayRecord()
        total_serialized_bytes = 0

        # An index loop is used on purpose: a loop variable (e.g. from
        # `enumerate`) would keep the current ndarray alive across `gc.collect()`
        # and defeat the memory release when `keep_input` is False.
        for i in range(len(ndarrays)):  # pylint: disable=C0200
            record[str(i)] = Array.from_numpy_ndarray(ndarrays[i])

            if not keep_input:
                # Remove the list's reference so the ndarray can be freed
                ndarrays[i] = None  # type: ignore
                total_serialized_bytes += len(record[str(i)].data)

                # If total serialized data exceeds the threshold, trigger GC
                if total_serialized_bytes > GC_THRESHOLD:
                    total_serialized_bytes = 0
                    gc.collect()

        if not keep_input:
            # Clear the entire list to remove all references and force GC
            ndarrays.clear()
            gc.collect()
        return record

    @classmethod
    def from_torch_state_dict(
        cls,
        state_dict: OrderedDict[str, torch.Tensor],
        *,
        keep_input: bool = True,
    ) -> ArrayRecord:
        """Create ArrayRecord from PyTorch ``state_dict``.

        Each tensor is detached, moved to CPU, converted to NumPy and serialized.

        Parameters
        ----------
        state_dict : OrderedDict[str, torch.Tensor]
            Named tensors to store in the new record.
        keep_input : bool (default: True)
            If ``False``, each entry is popped from `state_dict` as it is converted.

        Returns
        -------
        ArrayRecord
            A new record with the same keys as `state_dict`.

        Raises
        ------
        RuntimeError
            If PyTorch has not been imported.
        """
        if "torch" not in sys.modules:
            raise RuntimeError(
                f"PyTorch is required to use {cls.from_torch_state_dict.__name__}"
            )

        record = ArrayRecord()

        # Iterate over a snapshot of the keys because entries may be popped
        for k in list(state_dict.keys()):
            v = state_dict[k] if keep_input else state_dict.pop(k)
            record[k] = Array.from_numpy_ndarray(v.detach().cpu().numpy())

        return record

    def to_numpy_ndarrays(self, *, keep_input: bool = True) -> list[NDArray]:
        """Return the ArrayRecord as a list of NumPy ``ndarray``.

        Parameters
        ----------
        keep_input : bool (default: True)
            If ``False``, each entry is removed from this record as it is
            converted, so the record is empty afterwards. Garbage collection is
            triggered whenever more than ``GC_THRESHOLD`` serialized bytes have been
            released, and once more at the end.

        Returns
        -------
        list[NDArray]
            The deserialized arrays, in the record's iteration order.
        """
        if keep_input:
            return [v.numpy() for v in self.values()]

        # Drain the record while converting, so that serialized bytes can be freed
        ret: list[NDArray] = []
        total_serialized_bytes = 0
        for k in list(self.keys()):
            arr = self.pop(k)
            ret.append(arr.numpy())
            total_serialized_bytes += len(arr.data)
            del arr

            # If total serialized data exceeds the threshold, trigger GC
            if total_serialized_bytes > GC_THRESHOLD:
                total_serialized_bytes = 0
                gc.collect()

        # keep_input is always False here (the True case returned above): force GC
        gc.collect()
        return ret

    def to_torch_state_dict(
        self, *, keep_input: bool = True
    ) -> OrderedDict[str, torch.Tensor]:
        """Return the ArrayRecord as a PyTorch ``state_dict``.

        Parameters
        ----------
        keep_input : bool (default: True)
            If ``False``, each entry is removed from this record as it is converted.

        Returns
        -------
        OrderedDict[str, torch.Tensor]
            Tensors built with ``torch.from_numpy`` from each deserialized array,
            keyed like this record.

        Raises
        ------
        RuntimeError
            If PyTorch has not been imported.
        """
        if not (torch := sys.modules.get("torch")):
            raise RuntimeError(
                f"PyTorch is required to use {self.to_torch_state_dict.__name__}"
            )

        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()

        # Iterate over a snapshot of the keys because entries may be popped
        for k in list(self.keys()):
            arr = self[k] if keep_input else self.pop(k)
            state_dict[k] = torch.from_numpy(arr.numpy())

        return state_dict

    def count_bytes(self) -> int:
        """Return number of Bytes stored in this object.

        The count is the sum of the lengths of every entry's serialized ``data``
        plus the length of every key.

        Note that a small amount of Bytes might also be included in this counting that
        correspond to metadata of the serialized object (e.g. of NumPy array) needed for
        deserialization.
        """
        num_bytes = 0

        for k, v in self.items():
            num_bytes += len(v.data)

            # We also count the bytes footprint of the keys
            num_bytes += len(k)

        return num_bytes
587
+
588
+
589
class ParametersRecord(ArrayRecord):
    """Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.

    This class exists solely for backward compatibility with legacy
    code that previously used ``ParametersRecord``. It has been renamed
    to ``ArrayRecord``. All constructor arguments are forwarded unchanged to
    ``ArrayRecord``.

    .. warning::
        ``ParametersRecord`` is deprecated and will be removed in a future release.
        Use ``ArrayRecord`` instead.

    Examples
    --------
    Legacy (deprecated) usage::

        from flwr.common import ParametersRecord

        record = ParametersRecord()

    Updated usage::

        from flwr.common import ArrayRecord

        record = ArrayRecord()
    """

    # Class-level flag so the deprecation warning is logged only once, no matter
    # how many instances are created.
    _warning_logged = False

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # `**kwargs: Any` types each keyword value as Any (the previous
        # `dict[str, Any]` wrongly claimed every keyword value is a dict).
        if not ParametersRecord._warning_logged:
            ParametersRecord._warning_logged = True
            log(
                WARN,
                "The `ParametersRecord` class has been renamed to `ArrayRecord`. "
                "Support for `ParametersRecord` will be removed in a future release. "
                "Please update your code accordingly.",
            )
        super().__init__(*args, **kwargs)