flwr 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286):
  1. flwr/__init__.py +1 -1
  2. flwr/app/__init__.py +15 -0
  3. flwr/app/error.py +68 -0
  4. flwr/app/metadata.py +223 -0
  5. flwr/cli/__init__.py +1 -1
  6. flwr/cli/app.py +21 -2
  7. flwr/cli/build.py +83 -58
  8. flwr/cli/cli_user_auth_interceptor.py +1 -1
  9. flwr/cli/config_utils.py +53 -17
  10. flwr/cli/example.py +1 -1
  11. flwr/cli/install.py +1 -1
  12. flwr/cli/log.py +4 -4
  13. flwr/cli/login/__init__.py +1 -1
  14. flwr/cli/login/login.py +15 -8
  15. flwr/cli/ls.py +16 -37
  16. flwr/cli/new/__init__.py +1 -1
  17. flwr/cli/new/new.py +4 -4
  18. flwr/cli/new/templates/__init__.py +1 -1
  19. flwr/cli/new/templates/app/__init__.py +1 -1
  20. flwr/cli/new/templates/app/code/__init__.py +1 -1
  21. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  22. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +4 -4
  24. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  25. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  26. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
  28. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  29. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  30. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  33. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  34. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  35. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  36. flwr/cli/run/__init__.py +1 -1
  37. flwr/cli/run/run.py +11 -19
  38. flwr/cli/stop.py +3 -3
  39. flwr/cli/utils.py +42 -17
  40. flwr/client/__init__.py +3 -3
  41. flwr/client/client.py +1 -1
  42. flwr/client/client_app.py +140 -138
  43. flwr/client/clientapp/__init__.py +1 -8
  44. flwr/client/clientapp/utils.py +1 -1
  45. flwr/client/dpfedavg_numpy_client.py +1 -1
  46. flwr/client/grpc_adapter_client/__init__.py +1 -1
  47. flwr/client/grpc_adapter_client/connection.py +5 -5
  48. flwr/client/grpc_rere_client/__init__.py +1 -1
  49. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  50. flwr/client/grpc_rere_client/connection.py +131 -61
  51. flwr/client/grpc_rere_client/grpc_adapter.py +35 -7
  52. flwr/client/message_handler/__init__.py +1 -1
  53. flwr/client/message_handler/message_handler.py +2 -2
  54. flwr/client/mod/__init__.py +1 -1
  55. flwr/client/mod/centraldp_mods.py +1 -1
  56. flwr/client/mod/comms_mods.py +39 -20
  57. flwr/client/mod/localdp_mod.py +6 -6
  58. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  59. flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
  60. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  61. flwr/client/mod/utils.py +1 -1
  62. flwr/client/numpy_client.py +1 -1
  63. flwr/client/rest_client/__init__.py +1 -1
  64. flwr/client/rest_client/connection.py +174 -68
  65. flwr/client/run_info_store.py +1 -1
  66. flwr/client/typing.py +1 -1
  67. flwr/clientapp/__init__.py +15 -0
  68. flwr/common/__init__.py +3 -3
  69. flwr/common/address.py +1 -1
  70. flwr/common/args.py +1 -1
  71. flwr/common/auth_plugin/__init__.py +3 -1
  72. flwr/common/auth_plugin/auth_plugin.py +30 -4
  73. flwr/common/config.py +1 -1
  74. flwr/common/constant.py +37 -8
  75. flwr/common/context.py +1 -1
  76. flwr/common/date.py +1 -1
  77. flwr/common/differential_privacy.py +1 -1
  78. flwr/common/differential_privacy_constants.py +1 -1
  79. flwr/common/dp.py +1 -1
  80. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  81. flwr/common/exit/exit.py +6 -6
  82. flwr/common/exit_handlers.py +31 -1
  83. flwr/common/grpc.py +1 -1
  84. flwr/common/heartbeat.py +165 -0
  85. flwr/common/inflatable.py +290 -0
  86. flwr/common/inflatable_grpc_utils.py +99 -0
  87. flwr/common/inflatable_rest_utils.py +99 -0
  88. flwr/common/inflatable_utils.py +341 -0
  89. flwr/common/logger.py +1 -1
  90. flwr/common/message.py +137 -252
  91. flwr/common/object_ref.py +1 -1
  92. flwr/common/parameter.py +1 -1
  93. flwr/common/pyproject.py +1 -1
  94. flwr/common/record/__init__.py +3 -2
  95. flwr/common/record/array.py +323 -0
  96. flwr/common/record/arrayrecord.py +121 -243
  97. flwr/common/record/configrecord.py +71 -16
  98. flwr/common/record/conversion_utils.py +2 -2
  99. flwr/common/record/metricrecord.py +71 -20
  100. flwr/common/record/recorddict.py +207 -90
  101. flwr/common/record/typeddict.py +1 -1
  102. flwr/common/recorddict_compat.py +2 -2
  103. flwr/common/retry_invoker.py +15 -11
  104. flwr/common/secure_aggregation/__init__.py +1 -1
  105. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  106. flwr/common/secure_aggregation/crypto/shamir.py +52 -30
  107. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  108. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  109. flwr/common/secure_aggregation/quantization.py +1 -1
  110. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  111. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  112. flwr/common/serde.py +60 -184
  113. flwr/common/serde_utils.py +175 -0
  114. flwr/common/telemetry.py +2 -2
  115. flwr/common/typing.py +6 -4
  116. flwr/common/version.py +1 -1
  117. flwr/compat/__init__.py +15 -0
  118. flwr/compat/client/__init__.py +15 -0
  119. flwr/{client → compat/client}/app.py +71 -211
  120. flwr/{client → compat/client}/grpc_client/__init__.py +1 -1
  121. flwr/{client → compat/client}/grpc_client/connection.py +13 -13
  122. flwr/compat/common/__init__.py +15 -0
  123. flwr/compat/server/__init__.py +15 -0
  124. flwr/compat/server/app.py +174 -0
  125. flwr/compat/simulation/__init__.py +15 -0
  126. flwr/proto/__init__.py +1 -1
  127. flwr/proto/fleet_pb2.py +32 -27
  128. flwr/proto/fleet_pb2.pyi +49 -35
  129. flwr/proto/fleet_pb2_grpc.py +117 -13
  130. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  131. flwr/proto/heartbeat_pb2.py +33 -0
  132. flwr/proto/heartbeat_pb2.pyi +66 -0
  133. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  134. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  135. flwr/proto/message_pb2.py +28 -11
  136. flwr/proto/message_pb2.pyi +125 -0
  137. flwr/proto/recorddict_pb2.py +16 -28
  138. flwr/proto/recorddict_pb2.pyi +46 -64
  139. flwr/proto/run_pb2.py +24 -32
  140. flwr/proto/run_pb2.pyi +4 -52
  141. flwr/proto/serverappio_pb2.py +32 -23
  142. flwr/proto/serverappio_pb2.pyi +45 -3
  143. flwr/proto/serverappio_pb2_grpc.py +138 -34
  144. flwr/proto/serverappio_pb2_grpc.pyi +54 -13
  145. flwr/proto/simulationio_pb2.py +12 -11
  146. flwr/proto/simulationio_pb2_grpc.py +35 -0
  147. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  148. flwr/server/__init__.py +2 -2
  149. flwr/server/app.py +69 -187
  150. flwr/server/client_manager.py +1 -1
  151. flwr/server/client_proxy.py +1 -1
  152. flwr/server/compat/__init__.py +1 -1
  153. flwr/server/compat/app.py +1 -1
  154. flwr/server/compat/app_utils.py +51 -29
  155. flwr/server/compat/legacy_context.py +1 -1
  156. flwr/server/criterion.py +1 -1
  157. flwr/server/fleet_event_log_interceptor.py +2 -2
  158. flwr/server/grid/grid.py +3 -3
  159. flwr/server/grid/grpc_grid.py +104 -34
  160. flwr/server/grid/inmemory_grid.py +5 -4
  161. flwr/server/history.py +1 -1
  162. flwr/server/run_serverapp.py +1 -1
  163. flwr/server/server.py +1 -1
  164. flwr/server/server_app.py +65 -58
  165. flwr/server/server_config.py +1 -1
  166. flwr/server/serverapp/__init__.py +1 -1
  167. flwr/server/serverapp/app.py +19 -1
  168. flwr/server/serverapp_components.py +1 -1
  169. flwr/server/strategy/__init__.py +1 -1
  170. flwr/server/strategy/aggregate.py +1 -1
  171. flwr/server/strategy/bulyan.py +2 -2
  172. flwr/server/strategy/dp_adaptive_clipping.py +17 -17
  173. flwr/server/strategy/dp_fixed_clipping.py +17 -17
  174. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  175. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  176. flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
  177. flwr/server/strategy/fedadagrad.py +1 -1
  178. flwr/server/strategy/fedadam.py +1 -1
  179. flwr/server/strategy/fedavg.py +1 -1
  180. flwr/server/strategy/fedavg_android.py +1 -1
  181. flwr/server/strategy/fedavgm.py +1 -1
  182. flwr/server/strategy/fedmedian.py +1 -1
  183. flwr/server/strategy/fedopt.py +1 -1
  184. flwr/server/strategy/fedprox.py +1 -1
  185. flwr/server/strategy/fedtrimmedavg.py +1 -1
  186. flwr/server/strategy/fedxgb_bagging.py +1 -1
  187. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  188. flwr/server/strategy/fedxgb_nn_avg.py +3 -2
  189. flwr/server/strategy/fedyogi.py +1 -1
  190. flwr/server/strategy/krum.py +1 -1
  191. flwr/server/strategy/qfedavg.py +1 -1
  192. flwr/server/strategy/strategy.py +1 -1
  193. flwr/server/superlink/__init__.py +1 -1
  194. flwr/server/superlink/ffs/__init__.py +3 -1
  195. flwr/server/superlink/ffs/disk_ffs.py +1 -1
  196. flwr/server/superlink/ffs/ffs.py +1 -1
  197. flwr/server/superlink/ffs/ffs_factory.py +1 -1
  198. flwr/server/superlink/fleet/__init__.py +1 -1
  199. flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
  200. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +14 -4
  201. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  202. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  203. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  204. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  205. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
  206. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  207. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  208. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
  209. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  210. flwr/server/superlink/fleet/message_handler/message_handler.py +136 -19
  211. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  212. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -12
  213. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  214. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  215. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  216. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  217. flwr/server/superlink/fleet/vce/vce_api.py +7 -4
  218. flwr/server/superlink/linkstate/__init__.py +1 -1
  219. flwr/server/superlink/linkstate/in_memory_linkstate.py +139 -44
  220. flwr/server/superlink/linkstate/linkstate.py +54 -21
  221. flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
  222. flwr/server/superlink/linkstate/sqlite_linkstate.py +150 -56
  223. flwr/server/superlink/linkstate/utils.py +34 -30
  224. flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
  225. flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
  226. flwr/server/superlink/simulation/__init__.py +1 -1
  227. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  228. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  229. flwr/server/superlink/utils.py +45 -3
  230. flwr/server/typing.py +1 -1
  231. flwr/server/utils/__init__.py +1 -1
  232. flwr/server/utils/tensorboard.py +1 -1
  233. flwr/server/utils/validator.py +3 -3
  234. flwr/server/workflow/__init__.py +1 -1
  235. flwr/server/workflow/constant.py +1 -1
  236. flwr/server/workflow/default_workflows.py +1 -1
  237. flwr/server/workflow/secure_aggregation/__init__.py +1 -1
  238. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
  239. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
  240. flwr/serverapp/__init__.py +15 -0
  241. flwr/simulation/__init__.py +1 -1
  242. flwr/simulation/app.py +18 -1
  243. flwr/simulation/legacy_app.py +1 -1
  244. flwr/simulation/ray_transport/__init__.py +1 -1
  245. flwr/simulation/ray_transport/ray_actor.py +1 -1
  246. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  247. flwr/simulation/ray_transport/utils.py +1 -1
  248. flwr/simulation/run_simulation.py +2 -2
  249. flwr/simulation/simulationio_connection.py +1 -1
  250. flwr/supercore/__init__.py +15 -0
  251. flwr/supercore/object_store/__init__.py +24 -0
  252. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  253. flwr/supercore/object_store/object_store.py +192 -0
  254. flwr/supercore/object_store/object_store_factory.py +44 -0
  255. flwr/superexec/__init__.py +1 -1
  256. flwr/superexec/app.py +1 -1
  257. flwr/superexec/deployment.py +7 -3
  258. flwr/superexec/exec_event_log_interceptor.py +4 -4
  259. flwr/superexec/exec_grpc.py +8 -4
  260. flwr/superexec/exec_servicer.py +126 -24
  261. flwr/superexec/exec_user_auth_interceptor.py +38 -9
  262. flwr/superexec/executor.py +5 -1
  263. flwr/superexec/simulation.py +8 -2
  264. flwr/superlink/__init__.py +15 -0
  265. flwr/{client/supernode → supernode}/__init__.py +1 -8
  266. flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +8 -15
  267. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +4 -13
  268. flwr/supernode/cli/flwr_clientapp.py +81 -0
  269. flwr/{client → supernode}/nodestate/__init__.py +1 -1
  270. flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
  271. flwr/supernode/nodestate/nodestate.py +212 -0
  272. flwr/{client → supernode}/nodestate/nodestate_factory.py +1 -1
  273. flwr/supernode/runtime/__init__.py +15 -0
  274. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +26 -57
  275. flwr/supernode/servicer/__init__.py +15 -0
  276. flwr/supernode/servicer/clientappio/__init__.py +24 -0
  277. flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +1 -1
  278. flwr/supernode/start_client_internal.py +491 -0
  279. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/METADATA +6 -5
  280. flwr-1.19.0.dist-info/RECORD +365 -0
  281. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
  282. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
  283. flwr/client/heartbeat.py +0 -74
  284. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  285. flwr-1.17.0.dist-info/LICENSE +0 -202
  286. flwr-1.17.0.dist-info/RECORD +0 -333
@@ -0,0 +1,323 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Array."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ import sys
21
+ from dataclasses import dataclass
22
+ from io import BytesIO
23
+ from typing import TYPE_CHECKING, Any, cast, overload
24
+
25
+ import numpy as np
26
+
27
+ from flwr.proto.recorddict_pb2 import Array as ArrayProto # pylint: disable=E0611
28
+
29
+ from ..constant import SType
30
+ from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
31
+ from ..typing import NDArray
32
+
33
+ if TYPE_CHECKING:
34
+ import torch
35
+
36
+
37
def _raise_array_init_error() -> None:
    """Raise the ``TypeError`` reported for an invalid ``Array`` constructor call.

    Raises
    ------
    TypeError
        Always. The message names the class and lists the accepted inputs.
    """
    class_name = Array.__qualname__
    message = (
        f"Invalid arguments for {class_name}. "
        "Expected either a PyTorch tensor, a NumPy ndarray, "
        "or explicit dtype/shape/stype/data values."
    )
    raise TypeError(message)
43
+
44
+
45
@dataclass
class Array(InflatableObject):
    """Array type.

    A dataclass containing serialized data from an array-like or tensor-like object
    along with metadata about it. The class can be initialized in one of three ways:

    1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
    2. By providing a NumPy ndarray (via the `ndarray` argument).
    3. By providing a PyTorch tensor (via the `torch_tensor` argument).

    In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
    derived from the input. In scenario (1), these fields must be specified manually.

    Parameters
    ----------
    dtype : Optional[str] (default: None)
        A string representing the data type of the serialized object (e.g. `"float32"`).
        Only required if you are not passing in a ndarray or a tensor.

    shape : Optional[Union[tuple[int, ...], list[int]]] (default: None)
        The shape of the unserialized array-like object. A list of ints is
        accepted as well and is stored as a tuple. Only required if you are not
        passing in a ndarray or a tensor.

    stype : Optional[str] (default: None)
        A string indicating the serialization mechanism used to generate the bytes in
        `data` from an array-like or tensor-like object. Only required if you are not
        passing in a ndarray or a tensor.

    data : Optional[bytes] (default: None)
        A buffer of bytes containing the data. Only required if you are not passing in
        a ndarray or a tensor.

    ndarray : Optional[NDArray] (default: None)
        A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
        fields are derived automatically from it.

    torch_tensor : Optional[torch.Tensor] (default: None)
        A PyTorch tensor. If provided, it will be **detached and moved to CPU**
        before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
        will be derived automatically from it.

    Raises
    ------
    TypeError
        If the arguments do not match exactly one of the three supported
        initialization formats.

    Examples
    --------
    Initializing by specifying all fields directly::

        arr1 = Array(
            dtype="float32",
            shape=(3, 3),
            stype="numpy.ndarray",
            data=b"serialized_data...",
        )

    Initializing with a NumPy ndarray::

        import numpy as np
        arr2 = Array(np.random.randn(3, 3))

    Initializing with a PyTorch tensor::

        import torch
        arr3 = Array(torch.randn(3, 3))
    """

    dtype: str
    shape: tuple[int, ...]
    stype: str
    data: bytes

    @overload
    def __init__(  # noqa: E704
        self,
        dtype: str,
        shape: tuple[int, ...] | list[int],
        stype: str,
        data: bytes,
    ) -> None: ...

    @overload
    def __init__(self, ndarray: NDArray) -> None: ...  # noqa: E704

    @overload
    def __init__(self, torch_tensor: torch.Tensor) -> None: ...  # noqa: E704

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        *args: Any,
        dtype: str | None = None,
        shape: tuple[int, ...] | list[int] | None = None,
        stype: str | None = None,
        data: bytes | None = None,
        ndarray: NDArray | None = None,
        torch_tensor: torch.Tensor | None = None,
    ) -> None:
        # Determine the initialization method and validate input arguments.
        # Support three initialization formats:
        # 1. Array(dtype: str, shape: tuple[int, ...], stype: str, data: bytes)
        # 2. Array(ndarray: NDArray)
        # 3. Array(torch_tensor: torch.Tensor)

        # Initialize all arguments
        # If more than 4 positional arguments are provided, raise an error.
        if len(args) > 4:
            _raise_array_init_error()
        all_args: list[Any] = [None] * 4
        for i, arg in enumerate(args):
            all_args[i] = arg
        init_method: str | None = None  # Track which init method is being used

        # Try to assign a value to all_args[index] if it's not already set.
        # If an initialization method is provided, update init_method.
        def _try_set_arg(index: int, arg: Any, method: str) -> None:
            # Skip if arg is None
            if arg is None:
                return
            # Raise an error if all_args[index] is already set
            if all_args[index] is not None:
                _raise_array_init_error()
            # Raise an error if a different initialization method is already set
            nonlocal init_method
            if init_method is not None and init_method != method:
                _raise_array_init_error()
            # Set init_method and all_args[index]
            if init_method is None:
                init_method = method
            all_args[index] = arg

        # Try to set keyword arguments in all_args
        _try_set_arg(0, dtype, "direct")
        _try_set_arg(1, shape, "direct")
        _try_set_arg(2, stype, "direct")
        _try_set_arg(3, data, "direct")
        _try_set_arg(0, ndarray, "ndarray")
        _try_set_arg(0, torch_tensor, "torch_tensor")

        # Keep only the arguments that were actually provided
        all_args = [arg for arg in all_args if arg is not None]

        # Handle direct field initialization
        if not init_method or init_method == "direct":
            if (
                len(all_args) == 4  # pylint: disable=too-many-boolean-expressions
                and isinstance(all_args[0], str)
                # Lists are accepted too, so that `shape=[3, 3]` works
                and isinstance(all_args[1], (tuple, list))
                and all(isinstance(i, int) for i in all_args[1])
                and isinstance(all_args[2], str)
                and isinstance(all_args[3], bytes)
            ):
                self.dtype, shape_value, self.stype, self.data = all_args
                # Store the shape as a tuple, as declared by the `shape` field;
                # tuples (including subclasses) are stored unchanged.
                self.shape = (
                    shape_value
                    if isinstance(shape_value, tuple)
                    else tuple(shape_value)
                )
                return

        # Handle NumPy array
        if not init_method or init_method == "ndarray":
            if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
                # Copy the fields of a freshly built Array into this instance
                self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
                return

        # Handle PyTorch tensor
        if not init_method or init_method == "torch_tensor":
            # `torch` is never imported here at runtime; a tensor can only be
            # recognized if the caller has already imported PyTorch.
            if (
                len(all_args) == 1
                and "torch" in sys.modules
                and isinstance(all_args[0], sys.modules["torch"].Tensor)
            ):
                self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
                return

        _raise_array_init_error()

    @classmethod
    def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
        """Create Array from NumPy ndarray.

        Parameters
        ----------
        ndarray : NDArray
            The NumPy ndarray to serialize with ``np.save``.

        Returns
        -------
        Array
            An Array holding the serialized bytes together with the dtype and
            shape of `ndarray`, and `stype` set to ``SType.NUMPY``.
        """
        assert isinstance(
            ndarray, np.ndarray
        ), f"Expected NumPy ndarray, got {type(ndarray)}"
        buffer = BytesIO()
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(buffer, ndarray, allow_pickle=False)
        data = buffer.getvalue()
        return Array(
            dtype=str(ndarray.dtype),
            shape=tuple(ndarray.shape),
            stype=SType.NUMPY,
            data=data,
        )

    @classmethod
    def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
        """Create Array from PyTorch tensor.

        The tensor is detached, moved to CPU and converted to a NumPy ndarray
        before being serialized.

        Parameters
        ----------
        tensor : torch.Tensor
            The PyTorch tensor to convert.

        Returns
        -------
        Array
            The Array created from the tensor's NumPy representation.

        Raises
        ------
        RuntimeError
            If the `torch` module has not been imported yet.
        """
        if not (torch := sys.modules.get("torch")):
            raise RuntimeError(
                f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
            )

        assert isinstance(
            tensor, torch.Tensor
        ), f"Expected PyTorch Tensor, got {type(tensor)}"
        return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())

    def numpy(self) -> NDArray:
        """Return the array as a NumPy array.

        Raises
        ------
        TypeError
            If `stype` is not ``SType.NUMPY``.
        """
        if self.stype != SType.NUMPY:
            raise TypeError(
                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
            )
        bytes_io = BytesIO(self.data)
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
        ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
        return cast(NDArray, ndarray_deserialized)

    def deflate(self) -> bytes:
        """Deflate the Array.

        Returns
        -------
        bytes
            The result of ``add_header_to_object_body`` applied to the
            deterministically serialized ``ArrayProto`` of this Array.
        """
        array_proto = ArrayProto(
            dtype=self.dtype,
            shape=self.shape,
            stype=self.stype,
            data=self.data,
        )

        obj_body = array_proto.SerializeToString(deterministic=True)
        return add_header_to_object_body(object_body=obj_body, obj=self)

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> Array:
        """Inflate an Array from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content of the Array.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Must be ``None`` or empty. ``Array`` does not support child objects.
            Providing any children will raise a ``ValueError``.

        Returns
        -------
        Array
            The inflated Array.

        Raises
        ------
        ValueError
            If `children` is non-empty.
        """
        if children:
            raise ValueError("`Array` objects do not have children.")

        obj_body = get_object_body(object_content, cls)
        proto_array = ArrayProto.FromString(obj_body)
        return cls(
            dtype=proto_array.dtype,
            shape=tuple(proto_array.shape),
            stype=proto_array.stype,
            data=proto_array.data,
        )

    @property
    def object_id(self) -> str:
        """Get object ID.

        Reading the ID also clears the dirty flag.
        """
        ret = super().object_id
        self.is_dirty = False  # Reset dirty flag
        return ret

    @property
    def is_dirty(self) -> bool:
        """Check if the object is dirty after the last deflation."""
        # The flag lives directly in `__dict__` and defaults to True on first read
        if "_is_dirty" not in self.__dict__:
            self.__dict__["_is_dirty"] = True
        return cast(bool, self.__dict__["_is_dirty"])

    @is_dirty.setter
    def is_dirty(self, value: bool) -> None:
        """Set the dirty flag."""
        self.__dict__["_is_dirty"] = value

    def __setattr__(self, name: str, value: Any) -> None:
        """Set attribute with special handling for dirty state."""
        if name in ("dtype", "shape", "stype", "data"):
            # Mark as dirty if any of the main attributes are set
            self.is_dirty = True
        super().__setattr__(name, value)