flwr 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/__init__.py +1 -1
- flwr/cli/app.py +21 -2
- flwr/cli/build.py +83 -58
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +53 -17
- flwr/cli/example.py +1 -1
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +4 -4
- flwr/cli/login/__init__.py +1 -1
- flwr/cli/login/login.py +15 -8
- flwr/cli/ls.py +16 -37
- flwr/cli/new/__init__.py +1 -1
- flwr/cli/new/new.py +4 -4
- flwr/cli/new/templates/__init__.py +1 -1
- flwr/cli/new/templates/app/__init__.py +1 -1
- flwr/cli/new/templates/app/code/__init__.py +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +4 -4
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -1
- flwr/cli/run/run.py +11 -19
- flwr/cli/stop.py +3 -3
- flwr/cli/utils.py +42 -17
- flwr/client/__init__.py +3 -3
- flwr/client/client.py +1 -1
- flwr/client/client_app.py +140 -138
- flwr/client/clientapp/__init__.py +1 -8
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +1 -1
- flwr/client/grpc_adapter_client/connection.py +5 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +131 -61
- flwr/client/grpc_rere_client/grpc_adapter.py +35 -7
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/centraldp_mods.py +1 -1
- flwr/client/mod/comms_mods.py +39 -20
- flwr/client/mod/localdp_mod.py +6 -6
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +174 -68
- flwr/client/run_info_store.py +1 -1
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +3 -3
- flwr/common/address.py +1 -1
- flwr/common/args.py +1 -1
- flwr/common/auth_plugin/__init__.py +3 -1
- flwr/common/auth_plugin/auth_plugin.py +30 -4
- flwr/common/config.py +1 -1
- flwr/common/constant.py +37 -8
- flwr/common/context.py +1 -1
- flwr/common/date.py +1 -1
- flwr/common/differential_privacy.py +1 -1
- flwr/common/differential_privacy_constants.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit.py +6 -6
- flwr/common/exit_handlers.py +31 -1
- flwr/common/grpc.py +1 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/logger.py +1 -1
- flwr/common/message.py +137 -252
- flwr/common/object_ref.py +1 -1
- flwr/common/parameter.py +1 -1
- flwr/common/pyproject.py +1 -1
- flwr/common/record/__init__.py +3 -2
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +121 -243
- flwr/common/record/configrecord.py +71 -16
- flwr/common/record/conversion_utils.py +2 -2
- flwr/common/record/metricrecord.py +71 -20
- flwr/common/record/recorddict.py +207 -90
- flwr/common/record/typeddict.py +1 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +15 -11
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +52 -30
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +60 -184
- flwr/common/serde_utils.py +175 -0
- flwr/common/telemetry.py +2 -2
- flwr/common/typing.py +6 -4
- flwr/common/version.py +1 -1
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +71 -211
- flwr/{client → compat/client}/grpc_client/__init__.py +1 -1
- flwr/{client → compat/client}/grpc_client/connection.py +13 -13
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/__init__.py +1 -1
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +2 -2
- flwr/server/app.py +69 -187
- flwr/server/client_manager.py +1 -1
- flwr/server/client_proxy.py +1 -1
- flwr/server/compat/__init__.py +1 -1
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +51 -29
- flwr/server/compat/legacy_context.py +1 -1
- flwr/server/criterion.py +1 -1
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/history.py +1 -1
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +1 -1
- flwr/server/server_app.py +65 -58
- flwr/server/server_config.py +1 -1
- flwr/server/serverapp/__init__.py +1 -1
- flwr/server/serverapp/app.py +19 -1
- flwr/server/serverapp_components.py +1 -1
- flwr/server/strategy/__init__.py +1 -1
- flwr/server/strategy/aggregate.py +1 -1
- flwr/server/strategy/bulyan.py +2 -2
- flwr/server/strategy/dp_adaptive_clipping.py +17 -17
- flwr/server/strategy/dp_fixed_clipping.py +17 -17
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedtrimmedavg.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +3 -2
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/strategy/strategy.py +1 -1
- flwr/server/superlink/__init__.py +1 -1
- flwr/server/superlink/ffs/__init__.py +3 -1
- flwr/server/superlink/ffs/disk_ffs.py +1 -1
- flwr/server/superlink/ffs/ffs.py +1 -1
- flwr/server/superlink/ffs/ffs_factory.py +1 -1
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +14 -4
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +136 -19
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -12
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +7 -4
- flwr/server/superlink/linkstate/__init__.py +1 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +139 -44
- flwr/server/superlink/linkstate/linkstate.py +54 -21
- flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +150 -56
- flwr/server/superlink/linkstate/utils.py +34 -30
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/__init__.py +1 -1
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +45 -3
- flwr/server/typing.py +1 -1
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +3 -3
- flwr/server/workflow/__init__.py +1 -1
- flwr/server/workflow/constant.py +1 -1
- flwr/server/workflow/default_workflows.py +1 -1
- flwr/server/workflow/secure_aggregation/__init__.py +1 -1
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +18 -1
- flwr/simulation/legacy_app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/ray_transport/utils.py +1 -1
- flwr/simulation/run_simulation.py +2 -2
- flwr/simulation/simulationio_connection.py +1 -1
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/__init__.py +1 -1
- flwr/superexec/app.py +1 -1
- flwr/superexec/deployment.py +7 -3
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +8 -4
- flwr/superexec/exec_servicer.py +126 -24
- flwr/superexec/exec_user_auth_interceptor.py +38 -9
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -2
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +1 -8
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +8 -15
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +4 -13
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/{client → supernode}/nodestate/__init__.py +1 -1
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/{client → supernode}/nodestate/nodestate_factory.py +1 -1
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +26 -57
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/METADATA +6 -5
- flwr-1.19.0.dist-info/RECORD +365 -0
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- flwr-1.17.0.dist-info/LICENSE +0 -202
- flwr-1.17.0.dist-info/RECORD +0 -333
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Array."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import sys
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from io import BytesIO
|
|
23
|
+
from typing import TYPE_CHECKING, Any, cast, overload
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
from flwr.proto.recorddict_pb2 import Array as ArrayProto # pylint: disable=E0611
|
|
28
|
+
|
|
29
|
+
from ..constant import SType
|
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
|
31
|
+
from ..typing import NDArray
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
import torch
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _raise_array_init_error() -> None:
    """Raise the ``TypeError`` shared by every invalid ``Array`` construction.

    Raises
    ------
    TypeError
        Always. The message lists the accepted initialization forms.
    """
    message = (
        f"Invalid arguments for {Array.__qualname__}. "
        "Expected either a PyTorch tensor, a NumPy ndarray, "
        "or explicit dtype/shape/stype/data values."
    )
    raise TypeError(message)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class Array(InflatableObject):
    """Array type.

    A dataclass containing serialized data from an array-like or tensor-like object
    along with metadata about it. The class can be initialized in one of three ways:

    1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
    2. By providing a NumPy ndarray (via the `ndarray` argument).
    3. By providing a PyTorch tensor (via the `torch_tensor` argument).

    In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
    derived from the input. In scenario (1), these fields must be specified manually.

    Parameters
    ----------
    dtype : Optional[str] (default: None)
        A string representing the data type of the serialized object (e.g. `"float32"`).
        Only required if you are not passing in a ndarray or a tensor.

    shape : Optional[tuple[int, ...]] (default: None)
        A tuple representing the shape of the unserialized array-like object. It must
        be a ``tuple`` of ``int`` values; a ``list`` is rejected. Only required if you
        are not passing in a ndarray or a tensor.

    stype : Optional[str] (default: None)
        A string indicating the serialization mechanism used to generate the bytes in
        `data` from an array-like or tensor-like object. Only required if you are not
        passing in a ndarray or a tensor.

    data : Optional[bytes] (default: None)
        A buffer of bytes containing the data. Only required if you are not passing in
        a ndarray or a tensor.

    ndarray : Optional[NDArray] (default: None)
        A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
        fields are derived automatically from it.

    torch_tensor : Optional[torch.Tensor] (default: None)
        A PyTorch tensor. If provided, it will be **detached and moved to CPU**
        before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
        will be derived automatically from it.

    Raises
    ------
    TypeError
        If the arguments do not match exactly one of the three forms above. This
        includes:

        - passing more than four positional arguments;
        - giving the same slot both positionally and by keyword;
        - mixing keywords from different forms;
        - passing a field of the wrong type.

    Examples
    --------
    Initializing by specifying all fields directly::

        arr1 = Array(
            dtype="float32",
            shape=(3, 3),
            stype="numpy.ndarray",
            data=b"serialized_data...",
        )

    Initializing with a NumPy ndarray::

        import numpy as np
        arr2 = Array(np.random.randn(3, 3))

    Initializing with a PyTorch tensor::

        import torch
        arr3 = Array(torch.randn(3, 3))
    """

    dtype: str
    shape: tuple[int, ...]
    stype: str
    data: bytes

    @overload
    def __init__(  # noqa: E704
        self, dtype: str, shape: tuple[int, ...], stype: str, data: bytes
    ) -> None: ...

    @overload
    def __init__(self, ndarray: NDArray) -> None: ...  # noqa: E704

    @overload
    def __init__(self, torch_tensor: torch.Tensor) -> None: ...  # noqa: E704

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        *args: Any,
        dtype: str | None = None,
        shape: tuple[int, ...] | None = None,
        stype: str | None = None,
        data: bytes | None = None,
        ndarray: NDArray | None = None,
        torch_tensor: torch.Tensor | None = None,
    ) -> None:
        # Supported initialization formats (positional or keyword):
        #   1. Array(dtype: str, shape: tuple[int, ...], stype: str, data: bytes)
        #   2. Array(ndarray: NDArray)
        #   3. Array(torch_tensor: torch.Tensor)
        # Positional arguments fill slots 0-3 in order. Keyword arguments are then
        # merged into the same slots, and duplicates or mixed formats are rejected.

        # Only four slots exist (the "direct" format), so more positionals are invalid.
        if len(args) > 4:
            _raise_array_init_error()
        all_args: list[Any] = [None] * 4
        for i, arg in enumerate(args):
            all_args[i] = arg
        # Format implied by the keyword arguments. Positional-only calls leave it as
        # None, so their format is inferred from the argument types further below.
        init_method: str | None = None

        def _try_set_arg(index: int, arg: Any, method: str) -> None:
            """Store keyword ``arg`` in slot ``index`` as part of format ``method``."""
            nonlocal init_method
            # A keyword left at its default (None) was not provided.
            if arg is None:
                return
            # The slot is already occupied by a positional or an earlier keyword.
            if all_args[index] is not None:
                _raise_array_init_error()
            # Keywords belonging to two different formats cannot be combined.
            if init_method is not None and init_method != method:
                _raise_array_init_error()
            if init_method is None:
                init_method = method
            all_args[index] = arg

        # Merge keyword arguments into the slots. ndarray/torch_tensor share slot 0.
        _try_set_arg(0, dtype, "direct")
        _try_set_arg(1, shape, "direct")
        _try_set_arg(2, stype, "direct")
        _try_set_arg(3, data, "direct")
        _try_set_arg(0, ndarray, "ndarray")
        _try_set_arg(0, torch_tensor, "torch_tensor")

        # Drop unfilled slots; the remaining count disambiguates the format.
        all_args = [arg for arg in all_args if arg is not None]

        # Handle direct field initialization
        if not init_method or init_method == "direct":
            if (
                len(all_args) == 4  # pylint: disable=too-many-boolean-expressions
                and isinstance(all_args[0], str)
                and isinstance(all_args[1], tuple)
                and all(isinstance(i, int) for i in all_args[1])
                and isinstance(all_args[2], str)
                and isinstance(all_args[3], bytes)
            ):
                # Assignments go through __setattr__, which marks the object dirty.
                self.dtype, self.shape, self.stype, self.data = all_args
                return

        # Handle NumPy array
        if not init_method or init_method == "ndarray":
            if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
                # Copy the fields of a freshly built instance into this one.
                self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
                return

        # Handle PyTorch tensor
        if not init_method or init_method == "torch_tensor":
            # Only recognise tensors if torch was already imported by the caller;
            # this module never imports torch at runtime.
            if (
                len(all_args) == 1
                and "torch" in sys.modules
                and isinstance(all_args[0], sys.modules["torch"].Tensor)
            ):
                self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
                return

        _raise_array_init_error()

    @classmethod
    def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
        """Create an Array from a NumPy ndarray.

        The array is serialized with ``np.save`` (pickling disabled) and tagged with
        ``stype=SType.NUMPY``.

        Parameters
        ----------
        ndarray : NDArray
            The NumPy array to serialize.

        Returns
        -------
        Array
            A new instance of ``cls`` holding the serialized array.
        """
        assert isinstance(
            ndarray, np.ndarray
        ), f"Expected NumPy ndarray, got {type(ndarray)}"
        buffer = BytesIO()
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(buffer, ndarray, allow_pickle=False)
        data = buffer.getvalue()
        # Build via `cls` rather than a hard-coded `Array`, so that subclasses using
        # this alternate constructor get an instance of their own type.
        return cls(
            dtype=str(ndarray.dtype),
            shape=tuple(ndarray.shape),
            stype=SType.NUMPY,
            data=data,
        )

    @classmethod
    def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
        """Create an Array from a PyTorch tensor.

        The tensor is detached, moved to CPU, converted to NumPy, and then
        serialized through ``from_numpy_ndarray``.

        Parameters
        ----------
        tensor : torch.Tensor
            The tensor to serialize.

        Returns
        -------
        Array
            A new instance of ``cls`` holding the serialized tensor data.

        Raises
        ------
        RuntimeError
            If ``torch`` has not been imported in the current process.
        """
        # Look up torch lazily so this module has no hard dependency on it.
        if not (torch := sys.modules.get("torch")):
            raise RuntimeError(
                f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
            )

        assert isinstance(
            tensor, torch.Tensor
        ), f"Expected PyTorch Tensor, got {type(tensor)}"
        return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())

    def numpy(self) -> NDArray:
        """Return the array as a NumPy array.

        Raises
        ------
        TypeError
            If ``stype`` is not ``SType.NUMPY``.
        """
        if self.stype != SType.NUMPY:
            raise TypeError(
                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
            )
        bytes_io = BytesIO(self.data)
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
        ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
        return cast(NDArray, ndarray_deserialized)

    def deflate(self) -> bytes:
        """Deflate the Array into bytes.

        The four fields are packed into an ``ArrayProto``, serialized
        deterministically, and prefixed with the object header.
        """
        array_proto = ArrayProto(
            dtype=self.dtype,
            shape=self.shape,
            stype=self.stype,
            data=self.data,
        )

        # Deterministic serialization keeps the byte output stable for equal fields.
        obj_body = array_proto.SerializeToString(deterministic=True)
        return add_header_to_object_body(object_body=obj_body, obj=self)

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> Array:
        """Inflate an Array from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content of the Array.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Must be ``None`` or empty. ``Array`` does not support child objects.
            Providing any children will raise a ``ValueError``.

        Returns
        -------
        Array
            The inflated Array.
        """
        if children:
            raise ValueError("`Array` objects do not have children.")

        obj_body = get_object_body(object_content, cls)
        proto_array = ArrayProto.FromString(obj_body)
        return cls(
            dtype=proto_array.dtype,
            shape=tuple(proto_array.shape),
            stype=proto_array.stype,
            data=proto_array.data,
        )

    @property
    def object_id(self) -> str:
        """Get object ID.

        Delegates to ``InflatableObject.object_id`` and then clears the dirty flag.
        """
        ret = super().object_id
        self.is_dirty = False  # Reset dirty flag
        return ret

    @property
    def is_dirty(self) -> bool:
        """Check whether a field changed since ``object_id`` was last computed.

        Defaults to ``True`` when the flag has never been set.
        """
        if "_is_dirty" not in self.__dict__:
            self.__dict__["_is_dirty"] = True
        return cast(bool, self.__dict__["_is_dirty"])

    @is_dirty.setter
    def is_dirty(self, value: bool) -> None:
        """Set the dirty flag."""
        # Write to __dict__ directly to avoid recursing through __setattr__.
        self.__dict__["_is_dirty"] = value

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, marking the object dirty when a data field changes."""
        if name in ("dtype", "shape", "stype", "data"):
            # Mark as dirty if any of the main attributes are set
            self.is_dirty = True
        super().__setattr__(name, value)
|