flwr 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +82 -57
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  9. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  10. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  13. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  14. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  20. flwr/cli/run/run.py +10 -18
  21. flwr/cli/stop.py +2 -2
  22. flwr/cli/utils.py +31 -5
  23. flwr/client/__init__.py +2 -2
  24. flwr/client/client_app.py +1 -1
  25. flwr/client/clientapp/__init__.py +0 -7
  26. flwr/client/grpc_adapter_client/connection.py +4 -4
  27. flwr/client/grpc_rere_client/connection.py +130 -60
  28. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  29. flwr/client/message_handler/message_handler.py +1 -1
  30. flwr/client/mod/comms_mods.py +36 -17
  31. flwr/client/rest_client/connection.py +173 -67
  32. flwr/clientapp/__init__.py +15 -0
  33. flwr/common/__init__.py +2 -2
  34. flwr/common/auth_plugin/__init__.py +2 -0
  35. flwr/common/auth_plugin/auth_plugin.py +29 -3
  36. flwr/common/constant.py +36 -7
  37. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  38. flwr/common/exit_handlers.py +30 -0
  39. flwr/common/heartbeat.py +165 -0
  40. flwr/common/inflatable.py +290 -0
  41. flwr/common/inflatable_grpc_utils.py +99 -0
  42. flwr/common/inflatable_rest_utils.py +99 -0
  43. flwr/common/inflatable_utils.py +341 -0
  44. flwr/common/message.py +110 -242
  45. flwr/common/record/__init__.py +2 -1
  46. flwr/common/record/array.py +323 -0
  47. flwr/common/record/arrayrecord.py +103 -225
  48. flwr/common/record/configrecord.py +59 -4
  49. flwr/common/record/conversion_utils.py +1 -1
  50. flwr/common/record/metricrecord.py +55 -4
  51. flwr/common/record/recorddict.py +69 -1
  52. flwr/common/recorddict_compat.py +2 -2
  53. flwr/common/retry_invoker.py +5 -1
  54. flwr/common/serde.py +59 -183
  55. flwr/common/serde_utils.py +175 -0
  56. flwr/common/typing.py +5 -3
  57. flwr/compat/__init__.py +15 -0
  58. flwr/compat/client/__init__.py +15 -0
  59. flwr/{client → compat/client}/app.py +19 -159
  60. flwr/compat/common/__init__.py +15 -0
  61. flwr/compat/server/__init__.py +15 -0
  62. flwr/compat/server/app.py +174 -0
  63. flwr/compat/simulation/__init__.py +15 -0
  64. flwr/proto/fleet_pb2.py +32 -27
  65. flwr/proto/fleet_pb2.pyi +49 -35
  66. flwr/proto/fleet_pb2_grpc.py +117 -13
  67. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  68. flwr/proto/heartbeat_pb2.py +33 -0
  69. flwr/proto/heartbeat_pb2.pyi +66 -0
  70. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  71. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  72. flwr/proto/message_pb2.py +28 -11
  73. flwr/proto/message_pb2.pyi +125 -0
  74. flwr/proto/recorddict_pb2.py +16 -28
  75. flwr/proto/recorddict_pb2.pyi +46 -64
  76. flwr/proto/run_pb2.py +24 -32
  77. flwr/proto/run_pb2.pyi +4 -52
  78. flwr/proto/serverappio_pb2.py +32 -23
  79. flwr/proto/serverappio_pb2.pyi +45 -3
  80. flwr/proto/serverappio_pb2_grpc.py +138 -34
  81. flwr/proto/serverappio_pb2_grpc.pyi +54 -13
  82. flwr/proto/simulationio_pb2.py +12 -11
  83. flwr/proto/simulationio_pb2_grpc.py +35 -0
  84. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  85. flwr/server/__init__.py +1 -1
  86. flwr/server/app.py +68 -186
  87. flwr/server/compat/app_utils.py +50 -28
  88. flwr/server/fleet_event_log_interceptor.py +2 -2
  89. flwr/server/grid/grpc_grid.py +104 -34
  90. flwr/server/grid/inmemory_grid.py +5 -4
  91. flwr/server/serverapp/app.py +18 -0
  92. flwr/server/superlink/ffs/__init__.py +2 -0
  93. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
  94. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
  95. flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
  96. flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
  97. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  98. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  99. flwr/server/superlink/linkstate/linkstate.py +53 -20
  100. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  101. flwr/server/superlink/linkstate/utils.py +33 -29
  102. flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
  103. flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
  104. flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
  105. flwr/server/superlink/utils.py +44 -2
  106. flwr/server/utils/validator.py +2 -2
  107. flwr/serverapp/__init__.py +15 -0
  108. flwr/simulation/app.py +17 -0
  109. flwr/supercore/__init__.py +15 -0
  110. flwr/supercore/object_store/__init__.py +24 -0
  111. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  112. flwr/supercore/object_store/object_store.py +192 -0
  113. flwr/supercore/object_store/object_store_factory.py +44 -0
  114. flwr/superexec/deployment.py +6 -2
  115. flwr/superexec/exec_event_log_interceptor.py +4 -4
  116. flwr/superexec/exec_grpc.py +7 -3
  117. flwr/superexec/exec_servicer.py +125 -23
  118. flwr/superexec/exec_user_auth_interceptor.py +37 -8
  119. flwr/superexec/executor.py +4 -0
  120. flwr/superexec/simulation.py +7 -1
  121. flwr/superlink/__init__.py +15 -0
  122. flwr/{client/supernode → supernode}/__init__.py +0 -7
  123. flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
  124. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
  125. flwr/supernode/cli/flwr_clientapp.py +81 -0
  126. flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
  127. flwr/supernode/nodestate/nodestate.py +212 -0
  128. flwr/supernode/runtime/__init__.py +15 -0
  129. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
  130. flwr/supernode/servicer/__init__.py +15 -0
  131. flwr/supernode/servicer/clientappio/__init__.py +24 -0
  132. flwr/supernode/start_client_internal.py +491 -0
  133. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
  134. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
  135. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
  136. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
  137. flwr/client/heartbeat.py +0 -74
  138. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  139. /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  140. /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  141. /flwr/{client → supernode}/nodestate/__init__.py +0 -0
  142. /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
  143. /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
@@ -0,0 +1,323 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Array."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ import sys
21
+ from dataclasses import dataclass
22
+ from io import BytesIO
23
+ from typing import TYPE_CHECKING, Any, cast, overload
24
+
25
+ import numpy as np
26
+
27
+ from flwr.proto.recorddict_pb2 import Array as ArrayProto # pylint: disable=E0611
28
+
29
+ from ..constant import SType
30
+ from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
31
+ from ..typing import NDArray
32
+
33
+ if TYPE_CHECKING:
34
+ import torch
35
+
36
+
37
def _raise_array_init_error() -> None:
    """Raise a ``TypeError`` listing the accepted ``Array`` constructor forms.

    This helper never returns normally; it always raises.
    """
    message = (
        f"Invalid arguments for {Array.__qualname__}. Expected either a "
        "PyTorch tensor, a NumPy ndarray, or explicit"
        " dtype/shape/stype/data values."
    )
    raise TypeError(message)
43
+
44
+
45
@dataclass
class Array(InflatableObject):
    """Array type.

    A dataclass containing serialized data from an array-like or tensor-like object
    along with metadata about it. The class can be initialized in one of three ways:

    1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
    2. By providing a NumPy ndarray (via the `ndarray` argument).
    3. By providing a PyTorch tensor (via the `torch_tensor` argument).

    In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
    derived from the input. In scenario (1), these fields must be specified manually.

    Parameters
    ----------
    dtype : Optional[str] (default: None)
        A string representing the data type of the serialized object (e.g. `"float32"`).
        Only required if you are not passing in a ndarray or a tensor.

    shape : Optional[tuple[int, ...]] (default: None)
        A tuple representing the shape of the unserialized array-like object. A list
        of ints is also accepted (for backward compatibility) and is stored as a
        tuple. Only required if you are not passing in a ndarray or a tensor.

    stype : Optional[str] (default: None)
        A string indicating the serialization mechanism used to generate the bytes in
        `data` from an array-like or tensor-like object. Only required if you are not
        passing in a ndarray or a tensor.

    data : Optional[bytes] (default: None)
        A buffer of bytes containing the data. Only required if you are not passing in
        a ndarray or a tensor.

    ndarray : Optional[NDArray] (default: None)
        A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
        fields are derived automatically from it.

    torch_tensor : Optional[torch.Tensor] (default: None)
        A PyTorch tensor. If provided, it will be **detached and moved to CPU**
        before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
        will be derived automatically from it.

    Raises
    ------
    TypeError
        If the arguments do not match exactly one of the three forms above.

    Examples
    --------
    Initializing by specifying all fields directly::

        arr1 = Array(
            dtype="float32",
            shape=(3, 3),
            stype="numpy.ndarray",
            data=b"serialized_data...",
        )

    Initializing with a NumPy ndarray::

        import numpy as np
        arr2 = Array(np.random.randn(3, 3))

    Initializing with a PyTorch tensor::

        import torch
        arr3 = Array(torch.randn(3, 3))
    """

    dtype: str
    shape: tuple[int, ...]
    stype: str
    data: bytes

    @overload
    def __init__(  # noqa: E704
        self,
        dtype: str,
        shape: tuple[int, ...] | list[int],
        stype: str,
        data: bytes,
    ) -> None: ...

    @overload
    def __init__(self, ndarray: NDArray) -> None: ...  # noqa: E704

    @overload
    def __init__(self, torch_tensor: torch.Tensor) -> None: ...  # noqa: E704

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        *args: Any,
        dtype: str | None = None,
        shape: tuple[int, ...] | list[int] | None = None,
        stype: str | None = None,
        data: bytes | None = None,
        ndarray: NDArray | None = None,
        torch_tensor: torch.Tensor | None = None,
    ) -> None:
        # Supported call forms (positional or keyword):
        #   1. Array(dtype, shape, stype, data)
        #   2. Array(ndarray)
        #   3. Array(torch_tensor)
        # Positional and keyword arguments are merged into four slots. Filling a
        # slot twice, or mixing keywords from different forms, is rejected.
        if len(args) > 4:
            _raise_array_init_error()
        all_args: list[Any] = [None] * 4
        for i, arg in enumerate(args):
            all_args[i] = arg
        init_method: str | None = None  # Call form selected by keyword arguments

        def _try_set_arg(index: int, arg: Any, method: str) -> None:
            """Put keyword ``arg`` into slot ``index`` and record its call form."""
            nonlocal init_method
            if arg is None:
                return
            # Slot already filled positionally or by another keyword
            if all_args[index] is not None:
                _raise_array_init_error()
            # Keywords belonging to two different call forms were mixed
            if init_method is not None and init_method != method:
                _raise_array_init_error()
            if init_method is None:
                init_method = method
            all_args[index] = arg

        _try_set_arg(0, dtype, "direct")
        _try_set_arg(1, shape, "direct")
        _try_set_arg(2, stype, "direct")
        _try_set_arg(3, data, "direct")
        _try_set_arg(0, ndarray, "ndarray")
        _try_set_arg(0, torch_tensor, "torch_tensor")

        # Drop unfilled slots; the number of remaining values selects the form
        all_args = [arg for arg in all_args if arg is not None]

        # Form 1: explicit dtype/shape/stype/data
        if not init_method or init_method == "direct":
            if (
                len(all_args) == 4  # pylint: disable=too-many-boolean-expressions
                and isinstance(all_args[0], str)
                and isinstance(all_args[1], (tuple, list))
                and all(isinstance(i, int) for i in all_args[1])
                and isinstance(all_args[2], str)
                and isinstance(all_args[3], bytes)
            ):
                # Earlier releases took `shape` as a list; keep accepting it, but
                # always store a tuple so equality and the annotation stay consistent.
                self.dtype = all_args[0]
                self.shape = tuple(all_args[1])
                self.stype = all_args[2]
                self.data = all_args[3]
                return

        # Form 2: NumPy ndarray
        if not init_method or init_method == "ndarray":
            if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
                self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
                return

        # Form 3: PyTorch tensor. Torch is looked up in sys.modules and is never
        # imported here, so a tensor can only be recognized if torch is loaded.
        if not init_method or init_method == "torch_tensor":
            if (
                len(all_args) == 1
                and "torch" in sys.modules
                and isinstance(all_args[0], sys.modules["torch"].Tensor)
            ):
                self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
                return

        _raise_array_init_error()

    @classmethod
    def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
        """Create an Array from a NumPy ndarray.

        The ndarray is serialized with ``np.save`` (pickling disabled), and
        ``stype`` is set to ``SType.NUMPY``.

        Parameters
        ----------
        ndarray : NDArray
            The array to serialize.

        Returns
        -------
        Array
            A new instance of ``cls`` holding the serialized bytes.
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is so the
        # raised exception type (AssertionError) does not change for callers.
        assert isinstance(
            ndarray, np.ndarray
        ), f"Expected NumPy ndarray, got {type(ndarray)}"
        buffer = BytesIO()
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(buffer, ndarray, allow_pickle=False)
        return cls(
            dtype=str(ndarray.dtype),
            shape=tuple(ndarray.shape),
            stype=SType.NUMPY,
            data=buffer.getvalue(),
        )

    @classmethod
    def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
        """Create an Array from a PyTorch tensor.

        The tensor is detached and moved to CPU, converted to NumPy, then
        serialized via ``from_numpy_ndarray``.

        Raises
        ------
        RuntimeError
            If PyTorch has not been imported in the current process.
        """
        if not (torch := sys.modules.get("torch")):
            raise RuntimeError(
                f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
            )

        assert isinstance(
            tensor, torch.Tensor
        ), f"Expected PyTorch Tensor, got {type(tensor)}"
        return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())

    def numpy(self) -> NDArray:
        """Return the array as a NumPy array.

        Raises
        ------
        TypeError
            If ``stype`` is not ``SType.NUMPY``.
        """
        if self.stype != SType.NUMPY:
            raise TypeError(
                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
            )
        bytes_io = BytesIO(self.data)
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
        ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
        return cast(NDArray, ndarray_deserialized)

    def deflate(self) -> bytes:
        """Deflate the Array into a header-prefixed, deterministic protobuf body."""
        array_proto = ArrayProto(
            dtype=self.dtype,
            shape=self.shape,
            stype=self.stype,
            data=self.data,
        )

        obj_body = array_proto.SerializeToString(deterministic=True)
        return add_header_to_object_body(object_body=obj_body, obj=self)

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> Array:
        """Inflate an Array from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content of the Array.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Must be ``None``. ``Array`` does not support child objects.
            Providing any children will raise a ``ValueError``.

        Returns
        -------
        Array
            The inflated Array.

        Raises
        ------
        ValueError
            If a non-empty ``children`` mapping is provided.
        """
        if children:
            raise ValueError("`Array` objects do not have children.")

        obj_body = get_object_body(object_content, cls)
        proto_array = ArrayProto.FromString(obj_body)
        return cls(
            dtype=proto_array.dtype,
            shape=tuple(proto_array.shape),
            stype=proto_array.stype,
            data=proto_array.data,
        )

    @property
    def object_id(self) -> str:
        """Return the ID from ``InflatableObject.object_id`` and clear the dirty flag."""
        ret = super().object_id
        self.is_dirty = False  # Reset dirty flag
        return ret

    @property
    def is_dirty(self) -> bool:
        """Check if the object is dirty after the last deflation.

        A freshly created Array is dirty until its ``object_id`` is read.
        """
        if "_is_dirty" not in self.__dict__:
            self.__dict__["_is_dirty"] = True
        return cast(bool, self.__dict__["_is_dirty"])

    @is_dirty.setter
    def is_dirty(self, value: bool) -> None:
        """Set the dirty flag."""
        self.__dict__["_is_dirty"] = value

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, marking the Array dirty when a data field changes."""
        if name in ("dtype", "shape", "stype", "data"):
            self.is_dirty = True
        super().__setattr__(name, value)
@@ -12,38 +12,31 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """ArrayRecord and Array."""
15
+ """ArrayRecord."""
16
16
 
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
20
  import gc
21
+ import json
21
22
  import sys
22
23
  from collections import OrderedDict
23
- from dataclasses import dataclass
24
- from io import BytesIO
25
24
  from logging import WARN
26
25
  from typing import TYPE_CHECKING, Any, cast, overload
27
26
 
28
27
  import numpy as np
29
28
 
30
- from ..constant import GC_THRESHOLD, SType
29
+ from ..constant import GC_THRESHOLD
30
+ from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
31
31
  from ..logger import log
32
32
  from ..typing import NDArray
33
+ from .array import Array
33
34
  from .typeddict import TypedDict
34
35
 
35
36
  if TYPE_CHECKING:
36
37
  import torch
37
38
 
38
39
 
39
- def _raise_array_init_error() -> None:
40
- raise TypeError(
41
- f"Invalid arguments for {Array.__qualname__}. Expected either a "
42
- "PyTorch tensor, a NumPy ndarray, or explicit"
43
- " dtype/shape/stype/data values."
44
- )
45
-
46
-
47
40
  def _raise_array_record_init_error() -> None:
48
41
  raise TypeError(
49
42
  f"Invalid arguments for {ArrayRecord.__qualname__}. Expected either "
@@ -52,217 +45,6 @@ def _raise_array_record_init_error() -> None:
52
45
  )
53
46
 
54
47
 
55
- @dataclass
56
- class Array:
57
- """Array type.
58
-
59
- A dataclass containing serialized data from an array-like or tensor-like object
60
- along with metadata about it. The class can be initialized in one of three ways:
61
-
62
- 1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
63
- 2. By providing a NumPy ndarray (via the `ndarray` argument).
64
- 3. By providing a PyTorch tensor (via the `torch_tensor` argument).
65
-
66
- In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
67
- derived from the input. In scenario (1), these fields must be specified manually.
68
-
69
- Parameters
70
- ----------
71
- dtype : Optional[str] (default: None)
72
- A string representing the data type of the serialized object (e.g. `"float32"`).
73
- Only required if you are not passing in a ndarray or a tensor.
74
-
75
- shape : Optional[list[int]] (default: None)
76
- A list representing the shape of the unserialized array-like object. Only
77
- required if you are not passing in a ndarray or a tensor.
78
-
79
- stype : Optional[str] (default: None)
80
- A string indicating the serialization mechanism used to generate the bytes in
81
- `data` from an array-like or tensor-like object. Only required if you are not
82
- passing in a ndarray or a tensor.
83
-
84
- data : Optional[bytes] (default: None)
85
- A buffer of bytes containing the data. Only required if you are not passing in
86
- a ndarray or a tensor.
87
-
88
- ndarray : Optional[NDArray] (default: None)
89
- A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
90
- fields are derived automatically from it.
91
-
92
- torch_tensor : Optional[torch.Tensor] (default: None)
93
- A PyTorch tensor. If provided, it will be **detached and moved to CPU**
94
- before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
95
- will be derived automatically from it.
96
-
97
- Examples
98
- --------
99
- Initializing by specifying all fields directly::
100
-
101
- arr1 = Array(
102
- dtype="float32",
103
- shape=[3, 3],
104
- stype="numpy.ndarray",
105
- data=b"serialized_data...",
106
- )
107
-
108
- Initializing with a NumPy ndarray::
109
-
110
- import numpy as np
111
- arr2 = Array(np.random.randn(3, 3))
112
-
113
- Initializing with a PyTorch tensor::
114
-
115
- import torch
116
- arr3 = Array(torch.randn(3, 3))
117
- """
118
-
119
- dtype: str
120
- shape: list[int]
121
- stype: str
122
- data: bytes
123
-
124
- @overload
125
- def __init__( # noqa: E704
126
- self, dtype: str, shape: list[int], stype: str, data: bytes
127
- ) -> None: ...
128
-
129
- @overload
130
- def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704
131
-
132
- @overload
133
- def __init__(self, torch_tensor: torch.Tensor) -> None: ... # noqa: E704
134
-
135
- def __init__( # pylint: disable=too-many-arguments, too-many-locals
136
- self,
137
- *args: Any,
138
- dtype: str | None = None,
139
- shape: list[int] | None = None,
140
- stype: str | None = None,
141
- data: bytes | None = None,
142
- ndarray: NDArray | None = None,
143
- torch_tensor: torch.Tensor | None = None,
144
- ) -> None:
145
- # Determine the initialization method and validate input arguments.
146
- # Support three initialization formats:
147
- # 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
148
- # 2. Array(ndarray: NDArray)
149
- # 3. Array(torch_tensor: torch.Tensor)
150
-
151
- # Initialize all arguments
152
- # If more than 4 positional arguments are provided, raise an error.
153
- if len(args) > 4:
154
- _raise_array_init_error()
155
- all_args = [None] * 4
156
- for i, arg in enumerate(args):
157
- all_args[i] = arg
158
- init_method: str | None = None # Track which init method is being used
159
-
160
- # Try to assign a value to all_args[index] if it's not already set.
161
- # If an initialization method is provided, update init_method.
162
- def _try_set_arg(index: int, arg: Any, method: str) -> None:
163
- # Skip if arg is None
164
- if arg is None:
165
- return
166
- # Raise an error if all_args[index] is already set
167
- if all_args[index] is not None:
168
- _raise_array_init_error()
169
- # Raise an error if a different initialization method is already set
170
- nonlocal init_method
171
- if init_method is not None and init_method != method:
172
- _raise_array_init_error()
173
- # Set init_method and all_args[index]
174
- if init_method is None:
175
- init_method = method
176
- all_args[index] = arg
177
-
178
- # Try to set keyword arguments in all_args
179
- _try_set_arg(0, dtype, "direct")
180
- _try_set_arg(1, shape, "direct")
181
- _try_set_arg(2, stype, "direct")
182
- _try_set_arg(3, data, "direct")
183
- _try_set_arg(0, ndarray, "ndarray")
184
- _try_set_arg(0, torch_tensor, "torch_tensor")
185
-
186
- # Check if all arguments are correctly set
187
- all_args = [arg for arg in all_args if arg is not None]
188
-
189
- # Handle direct field initialization
190
- if not init_method or init_method == "direct":
191
- if (
192
- len(all_args) == 4 # pylint: disable=too-many-boolean-expressions
193
- and isinstance(all_args[0], str)
194
- and isinstance(all_args[1], list)
195
- and all(isinstance(i, int) for i in all_args[1])
196
- and isinstance(all_args[2], str)
197
- and isinstance(all_args[3], bytes)
198
- ):
199
- self.dtype, self.shape, self.stype, self.data = all_args
200
- return
201
-
202
- # Handle NumPy array
203
- if not init_method or init_method == "ndarray":
204
- if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
205
- self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
206
- return
207
-
208
- # Handle PyTorch tensor
209
- if not init_method or init_method == "torch_tensor":
210
- if (
211
- len(all_args) == 1
212
- and "torch" in sys.modules
213
- and isinstance(all_args[0], sys.modules["torch"].Tensor)
214
- ):
215
- self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
216
- return
217
-
218
- _raise_array_init_error()
219
-
220
- @classmethod
221
- def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
222
- """Create Array from NumPy ndarray."""
223
- assert isinstance(
224
- ndarray, np.ndarray
225
- ), f"Expected NumPy ndarray, got {type(ndarray)}"
226
- buffer = BytesIO()
227
- # WARNING: NEVER set allow_pickle to true.
228
- # Reason: loading pickled data can execute arbitrary code
229
- # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
230
- np.save(buffer, ndarray, allow_pickle=False)
231
- data = buffer.getvalue()
232
- return Array(
233
- dtype=str(ndarray.dtype),
234
- shape=list(ndarray.shape),
235
- stype=SType.NUMPY,
236
- data=data,
237
- )
238
-
239
- @classmethod
240
- def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
241
- """Create Array from PyTorch tensor."""
242
- if not (torch := sys.modules.get("torch")):
243
- raise RuntimeError(
244
- f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
245
- )
246
-
247
- assert isinstance(
248
- tensor, torch.Tensor
249
- ), f"Expected PyTorch Tensor, got {type(tensor)}"
250
- return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())
251
-
252
- def numpy(self) -> NDArray:
253
- """Return the array as a NumPy array."""
254
- if self.stype != SType.NUMPY:
255
- raise TypeError(
256
- f"Unsupported serialization type for numpy conversion: '{self.stype}'"
257
- )
258
- bytes_io = BytesIO(self.data)
259
- # WARNING: NEVER set allow_pickle to true.
260
- # Reason: loading pickled data can execute arbitrary code
261
- # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
262
- ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
263
- return cast(NDArray, ndarray_deserialized)
264
-
265
-
266
48
  def _check_key(key: str) -> None:
267
49
  """Check if key is of expected type."""
268
50
  if not isinstance(key, str):
@@ -276,7 +58,7 @@ def _check_value(value: Array) -> None:
276
58
  )
277
59
 
278
60
 
279
- class ArrayRecord(TypedDict[str, Array]):
61
+ class ArrayRecord(TypedDict[str, Array], InflatableObject):
280
62
  """Array record.
281
63
 
282
64
  A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
@@ -470,7 +252,7 @@ class ArrayRecord(TypedDict[str, Array]):
470
252
  record = ArrayRecord()
471
253
  for k, v in array_dict.items():
472
254
  record[k] = Array(
473
- dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
255
+ dtype=v.dtype, shape=tuple(v.shape), stype=v.stype, data=v.data
474
256
  )
475
257
  if not keep_input:
476
258
  array_dict.clear()
@@ -585,6 +367,102 @@ class ArrayRecord(TypedDict[str, Array]):
585
367
 
586
368
  return num_bytes
587
369
 
370
+ @property
371
+ def children(self) -> dict[str, InflatableObject]:
372
+ """Return a dictionary of Arrays with their Object IDs as keys."""
373
+ return {arr.object_id: arr for arr in self.values()}
374
+
375
+ def deflate(self) -> bytes:
376
+ """Deflate the ArrayRecord."""
377
+ # array_name: array_object_id mapping
378
+ array_refs: dict[str, str] = {}
379
+
380
+ for array_name, array in self.items():
381
+ array_refs[array_name] = array.object_id
382
+
383
+ # Serialize references dict
384
+ object_body = json.dumps(array_refs).encode("utf-8")
385
+ return add_header_to_object_body(object_body=object_body, obj=self)
386
+
387
+ @classmethod
388
+ def inflate(
389
+ cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
390
+ ) -> ArrayRecord:
391
+ """Inflate an ArrayRecord from bytes.
392
+
393
+ Parameters
394
+ ----------
395
+ object_content : bytes
396
+ The deflated object content of the ArrayRecord.
397
+ children : Optional[dict[str, InflatableObject]] (default: None)
398
+ Dictionary of children InflatableObjects mapped to their Object IDs.
399
+ These children enable the full inflation of the ArrayRecord.
400
+
401
+ Returns
402
+ -------
403
+ ArrayRecord
404
+ The inflated ArrayRecord.
405
+ """
406
+ if children is None:
407
+ children = {}
408
+
409
+ # Inflate mapping of array_names (keys in the ArrayRecord) to Arrays' object IDs
410
+ obj_body = get_object_body(object_content, cls)
411
+ array_refs: dict[str, str] = json.loads(obj_body.decode(encoding="utf-8"))
412
+
413
+ unique_arrays = set(array_refs.values())
414
+ children_obj_ids = set(children.keys())
415
+ if unique_arrays != children_obj_ids:
416
+ raise ValueError(
417
+ "Unexpected set of `children`. "
418
+ f"Expected {unique_arrays} but got {children_obj_ids}."
419
+ )
420
+
421
+ # Ensure children are of type Array
422
+ if not all(isinstance(arr, Array) for arr in children.values()):
423
+ raise ValueError("`Children` are expected to be of type `Array`.")
424
+
425
+ # Instantiate new ArrayRecord
426
+ return ArrayRecord(
427
+ OrderedDict(
428
+ {name: children[object_id] for name, object_id in array_refs.items()}
429
+ )
430
+ )
431
+
432
+ @property
433
+ def object_id(self) -> str:
434
+ """Get object ID."""
435
+ ret = super().object_id
436
+ self.is_dirty = False # Reset dirty flag
437
+ return ret
438
+
439
+ @property
440
+ def is_dirty(self) -> bool:
441
+ """Check if the object is dirty after the last deflation."""
442
+ if "_is_dirty" not in self.__dict__:
443
+ self.__dict__["_is_dirty"] = True
444
+
445
+ if not self.__dict__["_is_dirty"]:
446
+ if any(v.is_dirty for v in self.values()):
447
+ # If any Array is dirty, mark the record as dirty
448
+ self.__dict__["_is_dirty"] = True
449
+ return cast(bool, self.__dict__["_is_dirty"])
450
+
451
+ @is_dirty.setter
452
+ def is_dirty(self, value: bool) -> None:
453
+ """Set the dirty flag."""
454
+ self.__dict__["_is_dirty"] = value
455
+
456
+ def __setitem__(self, key: str, value: Array) -> None:
457
+ """Set item and mark the record as dirty."""
458
+ self.is_dirty = True # Mark as dirty when setting an item
459
+ super().__setitem__(key, value)
460
+
461
+ def __delitem__(self, key: str) -> None:
462
+ """Delete item and mark the record as dirty."""
463
+ self.is_dirty = True # Mark as dirty when deleting an item
464
+ super().__delitem__(key)
465
+
588
466
 
589
467
  class ParametersRecord(ArrayRecord):
590
468
  """Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.