flwr 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. flwr/__init__.py +1 -1
  2. flwr/app/__init__.py +15 -0
  3. flwr/app/error.py +68 -0
  4. flwr/app/metadata.py +223 -0
  5. flwr/cli/__init__.py +1 -1
  6. flwr/cli/app.py +21 -2
  7. flwr/cli/build.py +83 -58
  8. flwr/cli/cli_user_auth_interceptor.py +1 -1
  9. flwr/cli/config_utils.py +53 -17
  10. flwr/cli/example.py +1 -1
  11. flwr/cli/install.py +1 -1
  12. flwr/cli/log.py +4 -4
  13. flwr/cli/login/__init__.py +1 -1
  14. flwr/cli/login/login.py +15 -8
  15. flwr/cli/ls.py +16 -37
  16. flwr/cli/new/__init__.py +1 -1
  17. flwr/cli/new/new.py +4 -4
  18. flwr/cli/new/templates/__init__.py +1 -1
  19. flwr/cli/new/templates/app/__init__.py +1 -1
  20. flwr/cli/new/templates/app/code/__init__.py +1 -1
  21. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  22. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +4 -4
  24. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  25. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  26. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
  28. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  29. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  30. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  33. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  34. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  35. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  36. flwr/cli/run/__init__.py +1 -1
  37. flwr/cli/run/run.py +11 -19
  38. flwr/cli/stop.py +3 -3
  39. flwr/cli/utils.py +42 -17
  40. flwr/client/__init__.py +3 -3
  41. flwr/client/client.py +1 -1
  42. flwr/client/client_app.py +140 -138
  43. flwr/client/clientapp/__init__.py +1 -8
  44. flwr/client/clientapp/utils.py +1 -1
  45. flwr/client/dpfedavg_numpy_client.py +1 -1
  46. flwr/client/grpc_adapter_client/__init__.py +1 -1
  47. flwr/client/grpc_adapter_client/connection.py +5 -5
  48. flwr/client/grpc_rere_client/__init__.py +1 -1
  49. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  50. flwr/client/grpc_rere_client/connection.py +131 -61
  51. flwr/client/grpc_rere_client/grpc_adapter.py +35 -7
  52. flwr/client/message_handler/__init__.py +1 -1
  53. flwr/client/message_handler/message_handler.py +2 -2
  54. flwr/client/mod/__init__.py +1 -1
  55. flwr/client/mod/centraldp_mods.py +1 -1
  56. flwr/client/mod/comms_mods.py +39 -20
  57. flwr/client/mod/localdp_mod.py +6 -6
  58. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  59. flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
  60. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  61. flwr/client/mod/utils.py +1 -1
  62. flwr/client/numpy_client.py +1 -1
  63. flwr/client/rest_client/__init__.py +1 -1
  64. flwr/client/rest_client/connection.py +174 -68
  65. flwr/client/run_info_store.py +1 -1
  66. flwr/client/typing.py +1 -1
  67. flwr/clientapp/__init__.py +15 -0
  68. flwr/common/__init__.py +3 -3
  69. flwr/common/address.py +1 -1
  70. flwr/common/args.py +1 -1
  71. flwr/common/auth_plugin/__init__.py +3 -1
  72. flwr/common/auth_plugin/auth_plugin.py +30 -4
  73. flwr/common/config.py +1 -1
  74. flwr/common/constant.py +37 -8
  75. flwr/common/context.py +1 -1
  76. flwr/common/date.py +1 -1
  77. flwr/common/differential_privacy.py +1 -1
  78. flwr/common/differential_privacy_constants.py +1 -1
  79. flwr/common/dp.py +1 -1
  80. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  81. flwr/common/exit/exit.py +6 -6
  82. flwr/common/exit_handlers.py +31 -1
  83. flwr/common/grpc.py +1 -1
  84. flwr/common/heartbeat.py +165 -0
  85. flwr/common/inflatable.py +290 -0
  86. flwr/common/inflatable_grpc_utils.py +99 -0
  87. flwr/common/inflatable_rest_utils.py +99 -0
  88. flwr/common/inflatable_utils.py +341 -0
  89. flwr/common/logger.py +1 -1
  90. flwr/common/message.py +137 -252
  91. flwr/common/object_ref.py +1 -1
  92. flwr/common/parameter.py +1 -1
  93. flwr/common/pyproject.py +1 -1
  94. flwr/common/record/__init__.py +3 -2
  95. flwr/common/record/array.py +323 -0
  96. flwr/common/record/arrayrecord.py +121 -243
  97. flwr/common/record/configrecord.py +71 -16
  98. flwr/common/record/conversion_utils.py +2 -2
  99. flwr/common/record/metricrecord.py +71 -20
  100. flwr/common/record/recorddict.py +207 -90
  101. flwr/common/record/typeddict.py +1 -1
  102. flwr/common/recorddict_compat.py +2 -2
  103. flwr/common/retry_invoker.py +15 -11
  104. flwr/common/secure_aggregation/__init__.py +1 -1
  105. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  106. flwr/common/secure_aggregation/crypto/shamir.py +52 -30
  107. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  108. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  109. flwr/common/secure_aggregation/quantization.py +1 -1
  110. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  111. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  112. flwr/common/serde.py +60 -184
  113. flwr/common/serde_utils.py +175 -0
  114. flwr/common/telemetry.py +2 -2
  115. flwr/common/typing.py +6 -4
  116. flwr/common/version.py +1 -1
  117. flwr/compat/__init__.py +15 -0
  118. flwr/compat/client/__init__.py +15 -0
  119. flwr/{client → compat/client}/app.py +71 -211
  120. flwr/{client → compat/client}/grpc_client/__init__.py +1 -1
  121. flwr/{client → compat/client}/grpc_client/connection.py +13 -13
  122. flwr/compat/common/__init__.py +15 -0
  123. flwr/compat/server/__init__.py +15 -0
  124. flwr/compat/server/app.py +174 -0
  125. flwr/compat/simulation/__init__.py +15 -0
  126. flwr/proto/__init__.py +1 -1
  127. flwr/proto/fleet_pb2.py +32 -27
  128. flwr/proto/fleet_pb2.pyi +49 -35
  129. flwr/proto/fleet_pb2_grpc.py +117 -13
  130. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  131. flwr/proto/heartbeat_pb2.py +33 -0
  132. flwr/proto/heartbeat_pb2.pyi +66 -0
  133. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  134. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  135. flwr/proto/message_pb2.py +28 -11
  136. flwr/proto/message_pb2.pyi +125 -0
  137. flwr/proto/recorddict_pb2.py +16 -28
  138. flwr/proto/recorddict_pb2.pyi +46 -64
  139. flwr/proto/run_pb2.py +24 -32
  140. flwr/proto/run_pb2.pyi +4 -52
  141. flwr/proto/serverappio_pb2.py +32 -23
  142. flwr/proto/serverappio_pb2.pyi +45 -3
  143. flwr/proto/serverappio_pb2_grpc.py +138 -34
  144. flwr/proto/serverappio_pb2_grpc.pyi +54 -13
  145. flwr/proto/simulationio_pb2.py +12 -11
  146. flwr/proto/simulationio_pb2_grpc.py +35 -0
  147. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  148. flwr/server/__init__.py +2 -2
  149. flwr/server/app.py +69 -187
  150. flwr/server/client_manager.py +1 -1
  151. flwr/server/client_proxy.py +1 -1
  152. flwr/server/compat/__init__.py +1 -1
  153. flwr/server/compat/app.py +1 -1
  154. flwr/server/compat/app_utils.py +51 -29
  155. flwr/server/compat/legacy_context.py +1 -1
  156. flwr/server/criterion.py +1 -1
  157. flwr/server/fleet_event_log_interceptor.py +2 -2
  158. flwr/server/grid/grid.py +3 -3
  159. flwr/server/grid/grpc_grid.py +104 -34
  160. flwr/server/grid/inmemory_grid.py +5 -4
  161. flwr/server/history.py +1 -1
  162. flwr/server/run_serverapp.py +1 -1
  163. flwr/server/server.py +1 -1
  164. flwr/server/server_app.py +65 -58
  165. flwr/server/server_config.py +1 -1
  166. flwr/server/serverapp/__init__.py +1 -1
  167. flwr/server/serverapp/app.py +19 -1
  168. flwr/server/serverapp_components.py +1 -1
  169. flwr/server/strategy/__init__.py +1 -1
  170. flwr/server/strategy/aggregate.py +1 -1
  171. flwr/server/strategy/bulyan.py +2 -2
  172. flwr/server/strategy/dp_adaptive_clipping.py +17 -17
  173. flwr/server/strategy/dp_fixed_clipping.py +17 -17
  174. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  175. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  176. flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
  177. flwr/server/strategy/fedadagrad.py +1 -1
  178. flwr/server/strategy/fedadam.py +1 -1
  179. flwr/server/strategy/fedavg.py +1 -1
  180. flwr/server/strategy/fedavg_android.py +1 -1
  181. flwr/server/strategy/fedavgm.py +1 -1
  182. flwr/server/strategy/fedmedian.py +1 -1
  183. flwr/server/strategy/fedopt.py +1 -1
  184. flwr/server/strategy/fedprox.py +1 -1
  185. flwr/server/strategy/fedtrimmedavg.py +1 -1
  186. flwr/server/strategy/fedxgb_bagging.py +1 -1
  187. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  188. flwr/server/strategy/fedxgb_nn_avg.py +3 -2
  189. flwr/server/strategy/fedyogi.py +1 -1
  190. flwr/server/strategy/krum.py +1 -1
  191. flwr/server/strategy/qfedavg.py +1 -1
  192. flwr/server/strategy/strategy.py +1 -1
  193. flwr/server/superlink/__init__.py +1 -1
  194. flwr/server/superlink/ffs/__init__.py +3 -1
  195. flwr/server/superlink/ffs/disk_ffs.py +1 -1
  196. flwr/server/superlink/ffs/ffs.py +1 -1
  197. flwr/server/superlink/ffs/ffs_factory.py +1 -1
  198. flwr/server/superlink/fleet/__init__.py +1 -1
  199. flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
  200. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +14 -4
  201. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  202. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  203. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  204. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  205. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
  206. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  207. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  208. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
  209. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  210. flwr/server/superlink/fleet/message_handler/message_handler.py +136 -19
  211. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  212. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -12
  213. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  214. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  215. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  216. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  217. flwr/server/superlink/fleet/vce/vce_api.py +7 -4
  218. flwr/server/superlink/linkstate/__init__.py +1 -1
  219. flwr/server/superlink/linkstate/in_memory_linkstate.py +139 -44
  220. flwr/server/superlink/linkstate/linkstate.py +54 -21
  221. flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
  222. flwr/server/superlink/linkstate/sqlite_linkstate.py +150 -56
  223. flwr/server/superlink/linkstate/utils.py +34 -30
  224. flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
  225. flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
  226. flwr/server/superlink/simulation/__init__.py +1 -1
  227. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  228. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  229. flwr/server/superlink/utils.py +45 -3
  230. flwr/server/typing.py +1 -1
  231. flwr/server/utils/__init__.py +1 -1
  232. flwr/server/utils/tensorboard.py +1 -1
  233. flwr/server/utils/validator.py +3 -3
  234. flwr/server/workflow/__init__.py +1 -1
  235. flwr/server/workflow/constant.py +1 -1
  236. flwr/server/workflow/default_workflows.py +1 -1
  237. flwr/server/workflow/secure_aggregation/__init__.py +1 -1
  238. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
  239. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
  240. flwr/serverapp/__init__.py +15 -0
  241. flwr/simulation/__init__.py +1 -1
  242. flwr/simulation/app.py +18 -1
  243. flwr/simulation/legacy_app.py +1 -1
  244. flwr/simulation/ray_transport/__init__.py +1 -1
  245. flwr/simulation/ray_transport/ray_actor.py +1 -1
  246. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  247. flwr/simulation/ray_transport/utils.py +1 -1
  248. flwr/simulation/run_simulation.py +2 -2
  249. flwr/simulation/simulationio_connection.py +1 -1
  250. flwr/supercore/__init__.py +15 -0
  251. flwr/supercore/object_store/__init__.py +24 -0
  252. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  253. flwr/supercore/object_store/object_store.py +192 -0
  254. flwr/supercore/object_store/object_store_factory.py +44 -0
  255. flwr/superexec/__init__.py +1 -1
  256. flwr/superexec/app.py +1 -1
  257. flwr/superexec/deployment.py +7 -3
  258. flwr/superexec/exec_event_log_interceptor.py +4 -4
  259. flwr/superexec/exec_grpc.py +8 -4
  260. flwr/superexec/exec_servicer.py +126 -24
  261. flwr/superexec/exec_user_auth_interceptor.py +38 -9
  262. flwr/superexec/executor.py +5 -1
  263. flwr/superexec/simulation.py +8 -2
  264. flwr/superlink/__init__.py +15 -0
  265. flwr/{client/supernode → supernode}/__init__.py +1 -8
  266. flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +8 -15
  267. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +4 -13
  268. flwr/supernode/cli/flwr_clientapp.py +81 -0
  269. flwr/{client → supernode}/nodestate/__init__.py +1 -1
  270. flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
  271. flwr/supernode/nodestate/nodestate.py +212 -0
  272. flwr/{client → supernode}/nodestate/nodestate_factory.py +1 -1
  273. flwr/supernode/runtime/__init__.py +15 -0
  274. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +26 -57
  275. flwr/supernode/servicer/__init__.py +15 -0
  276. flwr/supernode/servicer/clientappio/__init__.py +24 -0
  277. flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +1 -1
  278. flwr/supernode/start_client_internal.py +491 -0
  279. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/METADATA +6 -5
  280. flwr-1.19.0.dist-info/RECORD +365 -0
  281. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
  282. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
  283. flwr/client/heartbeat.py +0 -74
  284. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  285. flwr-1.17.0.dist-info/LICENSE +0 -202
  286. flwr-1.17.0.dist-info/RECORD +0 -333
@@ -12,38 +12,31 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """ArrayRecord and Array."""
15
+ """ArrayRecord."""
16
16
 
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
20
  import gc
21
+ import json
21
22
  import sys
22
23
  from collections import OrderedDict
23
- from dataclasses import dataclass
24
- from io import BytesIO
25
24
  from logging import WARN
26
25
  from typing import TYPE_CHECKING, Any, cast, overload
27
26
 
28
27
  import numpy as np
29
28
 
30
- from ..constant import GC_THRESHOLD, SType
29
+ from ..constant import GC_THRESHOLD
30
+ from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
31
31
  from ..logger import log
32
32
  from ..typing import NDArray
33
+ from .array import Array
33
34
  from .typeddict import TypedDict
34
35
 
35
36
  if TYPE_CHECKING:
36
37
  import torch
37
38
 
38
39
 
39
- def _raise_array_init_error() -> None:
40
- raise TypeError(
41
- f"Invalid arguments for {Array.__qualname__}. Expected either a "
42
- "PyTorch tensor, a NumPy ndarray, or explicit"
43
- " dtype/shape/stype/data values."
44
- )
45
-
46
-
47
40
  def _raise_array_record_init_error() -> None:
48
41
  raise TypeError(
49
42
  f"Invalid arguments for {ArrayRecord.__qualname__}. Expected either "
@@ -52,217 +45,6 @@ def _raise_array_record_init_error() -> None:
52
45
  )
53
46
 
54
47
 
55
- @dataclass
56
- class Array:
57
- """Array type.
58
-
59
- A dataclass containing serialized data from an array-like or tensor-like object
60
- along with metadata about it. The class can be initialized in one of three ways:
61
-
62
- 1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
63
- 2. By providing a NumPy ndarray (via the `ndarray` argument).
64
- 3. By providing a PyTorch tensor (via the `torch_tensor` argument).
65
-
66
- In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
67
- derived from the input. In scenario (1), these fields must be specified manually.
68
-
69
- Parameters
70
- ----------
71
- dtype : Optional[str] (default: None)
72
- A string representing the data type of the serialized object (e.g. `"float32"`).
73
- Only required if you are not passing in a ndarray or a tensor.
74
-
75
- shape : Optional[list[int]] (default: None)
76
- A list representing the shape of the unserialized array-like object. Only
77
- required if you are not passing in a ndarray or a tensor.
78
-
79
- stype : Optional[str] (default: None)
80
- A string indicating the serialization mechanism used to generate the bytes in
81
- `data` from an array-like or tensor-like object. Only required if you are not
82
- passing in a ndarray or a tensor.
83
-
84
- data : Optional[bytes] (default: None)
85
- A buffer of bytes containing the data. Only required if you are not passing in
86
- a ndarray or a tensor.
87
-
88
- ndarray : Optional[NDArray] (default: None)
89
- A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
90
- fields are derived automatically from it.
91
-
92
- torch_tensor : Optional[torch.Tensor] (default: None)
93
- A PyTorch tensor. If provided, it will be **detached and moved to CPU**
94
- before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
95
- will be derived automatically from it.
96
-
97
- Examples
98
- --------
99
- Initializing by specifying all fields directly:
100
-
101
- >>> arr1 = Array(
102
- >>> dtype="float32",
103
- >>> shape=[3, 3],
104
- >>> stype="numpy.ndarray",
105
- >>> data=b"serialized_data...",
106
- >>> )
107
-
108
- Initializing with a NumPy ndarray:
109
-
110
- >>> import numpy as np
111
- >>> arr2 = Array(np.random.randn(3, 3))
112
-
113
- Initializing with a PyTorch tensor:
114
-
115
- >>> import torch
116
- >>> arr3 = Array(torch.randn(3, 3))
117
- """
118
-
119
- dtype: str
120
- shape: list[int]
121
- stype: str
122
- data: bytes
123
-
124
- @overload
125
- def __init__( # noqa: E704
126
- self, dtype: str, shape: list[int], stype: str, data: bytes
127
- ) -> None: ...
128
-
129
- @overload
130
- def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704
131
-
132
- @overload
133
- def __init__(self, torch_tensor: torch.Tensor) -> None: ... # noqa: E704
134
-
135
- def __init__( # pylint: disable=too-many-arguments, too-many-locals
136
- self,
137
- *args: Any,
138
- dtype: str | None = None,
139
- shape: list[int] | None = None,
140
- stype: str | None = None,
141
- data: bytes | None = None,
142
- ndarray: NDArray | None = None,
143
- torch_tensor: torch.Tensor | None = None,
144
- ) -> None:
145
- # Determine the initialization method and validate input arguments.
146
- # Support three initialization formats:
147
- # 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
148
- # 2. Array(ndarray: NDArray)
149
- # 3. Array(torch_tensor: torch.Tensor)
150
-
151
- # Initialize all arguments
152
- # If more than 4 positional arguments are provided, raise an error.
153
- if len(args) > 4:
154
- _raise_array_init_error()
155
- all_args = [None] * 4
156
- for i, arg in enumerate(args):
157
- all_args[i] = arg
158
- init_method: str | None = None # Track which init method is being used
159
-
160
- # Try to assign a value to all_args[index] if it's not already set.
161
- # If an initialization method is provided, update init_method.
162
- def _try_set_arg(index: int, arg: Any, method: str) -> None:
163
- # Skip if arg is None
164
- if arg is None:
165
- return
166
- # Raise an error if all_args[index] is already set
167
- if all_args[index] is not None:
168
- _raise_array_init_error()
169
- # Raise an error if a different initialization method is already set
170
- nonlocal init_method
171
- if init_method is not None and init_method != method:
172
- _raise_array_init_error()
173
- # Set init_method and all_args[index]
174
- if init_method is None:
175
- init_method = method
176
- all_args[index] = arg
177
-
178
- # Try to set keyword arguments in all_args
179
- _try_set_arg(0, dtype, "direct")
180
- _try_set_arg(1, shape, "direct")
181
- _try_set_arg(2, stype, "direct")
182
- _try_set_arg(3, data, "direct")
183
- _try_set_arg(0, ndarray, "ndarray")
184
- _try_set_arg(0, torch_tensor, "torch_tensor")
185
-
186
- # Check if all arguments are correctly set
187
- all_args = [arg for arg in all_args if arg is not None]
188
-
189
- # Handle direct field initialization
190
- if not init_method or init_method == "direct":
191
- if (
192
- len(all_args) == 4 # pylint: disable=too-many-boolean-expressions
193
- and isinstance(all_args[0], str)
194
- and isinstance(all_args[1], list)
195
- and all(isinstance(i, int) for i in all_args[1])
196
- and isinstance(all_args[2], str)
197
- and isinstance(all_args[3], bytes)
198
- ):
199
- self.dtype, self.shape, self.stype, self.data = all_args
200
- return
201
-
202
- # Handle NumPy array
203
- if not init_method or init_method == "ndarray":
204
- if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
205
- self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
206
- return
207
-
208
- # Handle PyTorch tensor
209
- if not init_method or init_method == "torch_tensor":
210
- if (
211
- len(all_args) == 1
212
- and "torch" in sys.modules
213
- and isinstance(all_args[0], sys.modules["torch"].Tensor)
214
- ):
215
- self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
216
- return
217
-
218
- _raise_array_init_error()
219
-
220
- @classmethod
221
- def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
222
- """Create Array from NumPy ndarray."""
223
- assert isinstance(
224
- ndarray, np.ndarray
225
- ), f"Expected NumPy ndarray, got {type(ndarray)}"
226
- buffer = BytesIO()
227
- # WARNING: NEVER set allow_pickle to true.
228
- # Reason: loading pickled data can execute arbitrary code
229
- # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
230
- np.save(buffer, ndarray, allow_pickle=False)
231
- data = buffer.getvalue()
232
- return Array(
233
- dtype=str(ndarray.dtype),
234
- shape=list(ndarray.shape),
235
- stype=SType.NUMPY,
236
- data=data,
237
- )
238
-
239
- @classmethod
240
- def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
241
- """Create Array from PyTorch tensor."""
242
- if not (torch := sys.modules.get("torch")):
243
- raise RuntimeError(
244
- f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
245
- )
246
-
247
- assert isinstance(
248
- tensor, torch.Tensor
249
- ), f"Expected PyTorch Tensor, got {type(tensor)}"
250
- return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())
251
-
252
- def numpy(self) -> NDArray:
253
- """Return the array as a NumPy array."""
254
- if self.stype != SType.NUMPY:
255
- raise TypeError(
256
- f"Unsupported serialization type for numpy conversion: '{self.stype}'"
257
- )
258
- bytes_io = BytesIO(self.data)
259
- # WARNING: NEVER set allow_pickle to true.
260
- # Reason: loading pickled data can execute arbitrary code
261
- # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
262
- ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
263
- return cast(NDArray, ndarray_deserialized)
264
-
265
-
266
48
  def _check_key(key: str) -> None:
267
49
  """Check if key is of expected type."""
268
50
  if not isinstance(key, str):
@@ -276,7 +58,7 @@ def _check_value(value: Array) -> None:
276
58
  )
277
59
 
278
60
 
279
- class ArrayRecord(TypedDict[str, Array]):
61
+ class ArrayRecord(TypedDict[str, Array], InflatableObject):
280
62
  """Array record.
281
63
 
282
64
  A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
@@ -315,33 +97,33 @@ class ArrayRecord(TypedDict[str, Array]):
315
97
 
316
98
  Examples
317
99
  --------
318
- Initializing an empty ArrayRecord:
100
+ Initializing an empty ArrayRecord::
319
101
 
320
- >>> record = ArrayRecord()
102
+ record = ArrayRecord()
321
103
 
322
- Initializing with a dictionary of :class:`Array`:
104
+ Initializing with a dictionary of :class:`Array`::
323
105
 
324
- >>> arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
325
- >>> record = ArrayRecord({"weight": arr})
106
+ arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
107
+ record = ArrayRecord({"weight": arr})
326
108
 
327
- Initializing with a list of NumPy arrays:
109
+ Initializing with a list of NumPy arrays::
328
110
 
329
- >>> import numpy as np
330
- >>> arr1 = np.random.randn(3, 3)
331
- >>> arr2 = np.random.randn(2, 2)
332
- >>> record = ArrayRecord([arr1, arr2])
111
+ import numpy as np
112
+ arr1 = np.random.randn(3, 3)
113
+ arr2 = np.random.randn(2, 2)
114
+ record = ArrayRecord([arr1, arr2])
333
115
 
334
- Initializing with a PyTorch model state_dict:
116
+ Initializing with a PyTorch model state_dict::
335
117
 
336
- >>> import torch.nn as nn
337
- >>> model = nn.Linear(10, 5)
338
- >>> record = ArrayRecord(model.state_dict())
118
+ import torch.nn as nn
119
+ model = nn.Linear(10, 5)
120
+ record = ArrayRecord(model.state_dict())
339
121
 
340
- Initializing with a TensorFlow model weights (a list of NumPy arrays):
122
+ Initializing with a TensorFlow model weights (a list of NumPy arrays)::
341
123
 
342
- >>> import tensorflow as tf
343
- >>> model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
344
- >>> record = ArrayRecord(model.get_weights())
124
+ import tensorflow as tf
125
+ model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
126
+ record = ArrayRecord(model.get_weights())
345
127
  """
346
128
 
347
129
  @overload
@@ -470,7 +252,7 @@ class ArrayRecord(TypedDict[str, Array]):
470
252
  record = ArrayRecord()
471
253
  for k, v in array_dict.items():
472
254
  record[k] = Array(
473
- dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
255
+ dtype=v.dtype, shape=tuple(v.shape), stype=v.stype, data=v.data
474
256
  )
475
257
  if not keep_input:
476
258
  array_dict.clear()
@@ -585,6 +367,102 @@ class ArrayRecord(TypedDict[str, Array]):
585
367
 
586
368
  return num_bytes
587
369
 
370
+ @property
371
+ def children(self) -> dict[str, InflatableObject]:
372
+ """Return a dictionary of Arrays with their Object IDs as keys."""
373
+ return {arr.object_id: arr for arr in self.values()}
374
+
375
+ def deflate(self) -> bytes:
376
+ """Deflate the ArrayRecord."""
377
+ # array_name: array_object_id mapping
378
+ array_refs: dict[str, str] = {}
379
+
380
+ for array_name, array in self.items():
381
+ array_refs[array_name] = array.object_id
382
+
383
+ # Serialize references dict
384
+ object_body = json.dumps(array_refs).encode("utf-8")
385
+ return add_header_to_object_body(object_body=object_body, obj=self)
386
+
387
+ @classmethod
388
+ def inflate(
389
+ cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
390
+ ) -> ArrayRecord:
391
+ """Inflate an ArrayRecord from bytes.
392
+
393
+ Parameters
394
+ ----------
395
+ object_content : bytes
396
+ The deflated object content of the ArrayRecord.
397
+ children : Optional[dict[str, InflatableObject]] (default: None)
398
+ Dictionary of children InflatableObjects mapped to their Object IDs.
399
+ These children enable the full inflation of the ArrayRecord.
400
+
401
+ Returns
402
+ -------
403
+ ArrayRecord
404
+ The inflated ArrayRecord.
405
+ """
406
+ if children is None:
407
+ children = {}
408
+
409
+ # Inflate mapping of array_names (keys in the ArrayRecord) to Arrays' object IDs
410
+ obj_body = get_object_body(object_content, cls)
411
+ array_refs: dict[str, str] = json.loads(obj_body.decode(encoding="utf-8"))
412
+
413
+ unique_arrays = set(array_refs.values())
414
+ children_obj_ids = set(children.keys())
415
+ if unique_arrays != children_obj_ids:
416
+ raise ValueError(
417
+ "Unexpected set of `children`. "
418
+ f"Expected {unique_arrays} but got {children_obj_ids}."
419
+ )
420
+
421
+ # Ensure children are of type Array
422
+ if not all(isinstance(arr, Array) for arr in children.values()):
423
+ raise ValueError("`Children` are expected to be of type `Array`.")
424
+
425
+ # Instantiate new ArrayRecord
426
+ return ArrayRecord(
427
+ OrderedDict(
428
+ {name: children[object_id] for name, object_id in array_refs.items()}
429
+ )
430
+ )
431
+
432
+ @property
433
+ def object_id(self) -> str:
434
+ """Get object ID."""
435
+ ret = super().object_id
436
+ self.is_dirty = False # Reset dirty flag
437
+ return ret
438
+
439
+ @property
440
+ def is_dirty(self) -> bool:
441
+ """Check if the object is dirty after the last deflation."""
442
+ if "_is_dirty" not in self.__dict__:
443
+ self.__dict__["_is_dirty"] = True
444
+
445
+ if not self.__dict__["_is_dirty"]:
446
+ if any(v.is_dirty for v in self.values()):
447
+ # If any Array is dirty, mark the record as dirty
448
+ self.__dict__["_is_dirty"] = True
449
+ return cast(bool, self.__dict__["_is_dirty"])
450
+
451
+ @is_dirty.setter
452
+ def is_dirty(self, value: bool) -> None:
453
+ """Set the dirty flag."""
454
+ self.__dict__["_is_dirty"] = value
455
+
456
+ def __setitem__(self, key: str, value: Array) -> None:
457
+ """Set item and mark the record as dirty."""
458
+ self.is_dirty = True # Mark as dirty when setting an item
459
+ super().__setitem__(key, value)
460
+
461
+ def __delitem__(self, key: str) -> None:
462
+ """Delete item and mark the record as dirty."""
463
+ self.is_dirty = True # Mark as dirty when deleting an item
464
+ super().__delitem__(key)
465
+
588
466
 
589
467
  class ParametersRecord(ArrayRecord):
590
468
  """Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.
@@ -15,12 +15,21 @@
15
15
  """ConfigRecord."""
16
16
 
17
17
 
18
+ from __future__ import annotations
19
+
18
20
  from logging import WARN
19
- from typing import Optional, get_args
21
+ from typing import cast, get_args
20
22
 
21
23
  from flwr.common.typing import ConfigRecordValues, ConfigScalar
22
24
 
25
+ # pylint: disable=E0611
26
+ from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
27
+ from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
28
+
29
+ # pylint: enable=E0611
30
+ from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
23
31
  from ..logger import log
32
+ from ..serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
24
33
  from .typeddict import TypedDict
25
34
 
26
35
 
@@ -59,8 +68,8 @@ def _check_value(value: ConfigRecordValues) -> None:
59
68
  is_valid(value)
60
69
 
61
70
 
62
- class ConfigRecord(TypedDict[str, ConfigRecordValues]):
63
- """Configs record.
71
+ class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
72
+ """Config record.
64
73
 
65
74
  A :code:`ConfigRecord` is a Python dictionary designed to ensure that
66
75
  each key-value pair adheres to specified data types. A :code:`ConfigRecord`
@@ -90,18 +99,18 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
90
99
  encourage you to use a :code:`ArrayRecord` instead if these are of high
91
100
  dimensionality.
92
101
 
93
- Let's see some examples of how to construct a :code:`ConfigRecord` from scratch:
102
+ Let's see some examples of how to construct a :code:`ConfigRecord` from scratch::
103
+
104
+ from flwr.common import ConfigRecord
94
105
 
95
- >>> from flwr.common import ConfigRecord
96
- >>>
97
- >>> # A `ConfigRecord` is a specialized Python dictionary
98
- >>> record = ConfigRecord({"lr": 0.1, "batch-size": 128})
99
- >>> # You can add more content to an existing record
100
- >>> record["compute-average"] = True
101
- >>> # It also supports lists
102
- >>> record["loss-fn-coefficients"] = [0.4, 0.25, 0.35]
103
- >>> # And string values (among other types)
104
- >>> record["path-to-S3"] = "s3://bucket_name/folder1/fileA.json"
106
+ # A `ConfigRecord` is a specialized Python dictionary
107
+ record = ConfigRecord({"lr": 0.1, "batch-size": 128})
108
+ # You can add more content to an existing record
109
+ record["compute-average"] = True
110
+ # It also supports lists
111
+ record["loss-fn-coefficients"] = [0.4, 0.25, 0.35]
112
+ # And string values (among other types)
113
+ record["path-to-S3"] = "s3://bucket_name/folder1/fileA.json"
105
114
 
106
115
  Just like the other types of records in a :code:`flwr.common.RecordDict`, types are
107
116
  enforced. If you need to add a custom data structure or object, we recommend to
@@ -111,7 +120,7 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
111
120
 
112
121
  def __init__(
113
122
  self,
114
- config_dict: Optional[dict[str, ConfigRecordValues]] = None,
123
+ config_dict: dict[str, ConfigRecordValues] | None = None,
115
124
  keep_input: bool = True,
116
125
  ) -> None:
117
126
 
@@ -164,6 +173,52 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
164
173
 
165
174
  return num_bytes
166
175
 
176
+ def deflate(self) -> bytes:
177
+ """Deflate object."""
178
+ protos = record_value_dict_to_proto(
179
+ self,
180
+ [bool, int, float, str, bytes],
181
+ ProtoConfigRecordValue,
182
+ )
183
+ obj_body = ProtoConfigRecord(
184
+ items=[ProtoConfigRecord.Item(key=k, value=v) for k, v in protos.items()]
185
+ ).SerializeToString()
186
+ return add_header_to_object_body(object_body=obj_body, obj=self)
187
+
188
+ @classmethod
189
+ def inflate(
190
+ cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
191
+ ) -> ConfigRecord:
192
+ """Inflate a ConfigRecord from bytes.
193
+
194
+ Parameters
195
+ ----------
196
+ object_content : bytes
197
+ The deflated object content of the ConfigRecord.
198
+
199
+ children : Optional[dict[str, InflatableObject]] (default: None)
200
+ Must be ``None``. ``ConfigRecord`` does not support child objects.
201
+ Providing any children will raise a ``ValueError``.
202
+
203
+ Returns
204
+ -------
205
+ ConfigRecord
206
+ The inflated ConfigRecord.
207
+ """
208
+ if children:
209
+ raise ValueError("`ConfigRecord` objects do not have children.")
210
+
211
+ obj_body = get_object_body(object_content, cls)
212
+ config_record_proto = ProtoConfigRecord.FromString(obj_body)
213
+ protos = {item.key: item.value for item in config_record_proto.items}
214
+ return ConfigRecord(
215
+ config_dict=cast(
216
+ dict[str, ConfigRecordValues],
217
+ record_value_dict_from_proto(protos),
218
+ ),
219
+ keep_input=False,
220
+ )
221
+
167
222
 
168
223
  class ConfigsRecord(ConfigRecord):
169
224
  """Deprecated class ``ConfigsRecord``, use ``ConfigRecord`` instead.
@@ -195,7 +250,7 @@ class ConfigsRecord(ConfigRecord):
195
250
 
196
251
  def __init__(
197
252
  self,
198
- config_dict: Optional[dict[str, ConfigRecordValues]] = None,
253
+ config_dict: dict[str, ConfigRecordValues] | None = None,
199
254
  keep_input: bool = True,
200
255
  ):
201
256
  if not ConfigsRecord._warning_logged:
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
17
17
 
18
18
  from ..logger import warn_deprecated_feature
19
19
  from ..typing import NDArray
20
- from .arrayrecord import Array
20
+ from .array import Array
21
21
 
22
22
  WARN_DEPRECATED_MESSAGE = (
23
23
  "`array_from_numpy` is deprecated. Instead, use the `Array(ndarray)` class "