flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250321__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +10 -12
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +1 -1
- flwr/common/context.py +4 -4
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/{parametersrecord.py → arrayrecord.py} +75 -32
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/compat/grid_client_proxy.py +31 -31
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +15 -23
- flwr/server/grid/inmemory_grid.py +14 -20
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +5 -5
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +21 -25
- flwr/server/superlink/linkstate/utils.py +18 -15
- flwr/server/superlink/serverappio/serverappio_servicer.py +3 -3
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +34 -41
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +37 -39
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +5 -5
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/RECORD +66 -66
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,288 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""RecordDict."""
|
16
|
+
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
from logging import WARN
|
21
|
+
from textwrap import indent
|
22
|
+
from typing import TypeVar, Union, cast
|
23
|
+
|
24
|
+
from ..logger import log
|
25
|
+
from .arrayrecord import ArrayRecord
|
26
|
+
from .configrecord import ConfigRecord
|
27
|
+
from .metricrecord import MetricRecord
|
28
|
+
from .typeddict import TypedDict
|
29
|
+
|
30
|
+
# Union of the three record classes that a RecordDict accepts as values.
RecordType = Union[ArrayRecord, MetricRecord, ConfigRecord]
|
31
|
+
|
32
|
+
# Type variable for the record type held by a `_SyncedDict` view.
T = TypeVar("T")
|
33
|
+
|
34
|
+
|
35
|
+
def _check_key(key: str) -> None:
|
36
|
+
if not isinstance(key, str):
|
37
|
+
raise TypeError(
|
38
|
+
f"Expected `{str.__name__}`, but "
|
39
|
+
f"received `{type(key).__name__}` for the key."
|
40
|
+
)
|
41
|
+
|
42
|
+
|
43
|
+
def _check_value(value: RecordType) -> None:
    """Raise ``TypeError`` unless ``value`` is one of the supported record types."""
    if isinstance(value, (ArrayRecord, MetricRecord, ConfigRecord)):
        return
    received = type(value).__name__
    raise TypeError(
        f"Expected `{ArrayRecord.__name__}`, `{MetricRecord.__name__}`, "
        f"or `{ConfigRecord.__name__}` but received "
        f"`{received}` for the value."
    )
|
50
|
+
|
51
|
+
|
52
|
+
class _SyncedDict(TypedDict[str, T]):
    """Typed view whose writes and deletes are forwarded to a ``RecordDict``.

    Setting or deleting an item on this dictionary performs the same operation
    on the ``RecordDict`` it was created from. Values must be instances of the
    ``allowed_type`` given at construction time.
    """

    def __init__(self, ref_recorddict: RecordDict, allowed_type: type[T]) -> None:
        supported_types = (ArrayRecord, MetricRecord, ConfigRecord)
        if not issubclass(allowed_type, supported_types):
            raise TypeError(f"{allowed_type} is not a valid type.")
        super().__init__(_check_key, self.check_value)
        self.recorddict = ref_recorddict
        self.allowed_type = allowed_type

    def __setitem__(self, key: str, value: T) -> None:
        # Store locally first (this runs the key/value checks), then mirror
        # the assignment to the backing RecordDict.
        super().__setitem__(key, value)
        self.recorddict[key] = cast(RecordType, value)

    def __delitem__(self, key: str) -> None:
        # Remove locally first, then mirror the deletion to the backing
        # RecordDict.
        super().__delitem__(key)
        del self.recorddict[key]

    def check_value(self, value: T) -> None:
        """Raise ``TypeError`` unless ``value`` is an ``allowed_type`` instance."""
        if isinstance(value, self.allowed_type):
            return
        received = type(value).__name__
        raise TypeError(
            f"Expected `{self.allowed_type.__name__}`, but "
            f"received `{received}` for the value."
        )
|
82
|
+
|
83
|
+
|
84
|
+
class RecordDict(TypedDict[str, RecordType]):
    """RecordDict stores groups of arrays, metrics and configs.

    A :class:`RecordDict` is the unified mechanism by which arrays,
    metrics and configs can be either stored as part of a :class:`Context`
    in your apps or communicated as part of a :class:`Message` between
    your apps.

    Parameters
    ----------
    records : Optional[dict[str, RecordType]]
        A dictionary mapping string keys to record instances, where each value
        is either an :class:`ArrayRecord`, :class:`MetricRecord`,
        or :class:`ConfigRecord`.

    Examples
    --------
    A :class:`RecordDict` can hold three types of records, each designed
    with a specific purpose. What is common to all of them is that they
    are Python dictionaries designed to ensure that each key-value pair
    adheres to specified data types.

    Let's see an example.

    >>> from flwr.common import RecordDict
    >>> from flwr.common import ArrayRecord, ConfigRecord, MetricRecord
    >>>
    >>> # Let's begin with an empty record
    >>> my_records = RecordDict()
    >>>
    >>> # We can create a ConfigRecord
    >>> c_record = ConfigRecord({"lr": 0.1, "batch-size": 128})
    >>> # Adding it to the RecordDict would look like this
    >>> my_records["my_config"] = c_record
    >>>
    >>> # We can create a MetricRecord following a similar process
    >>> m_record = MetricRecord({"accuracy": 0.93, "losses": [0.23, 0.1]})
    >>> # Adding it to the RecordDict would look like this
    >>> my_records["my_metrics"] = m_record

    Adding an :code:`ArrayRecord` follows the same steps as above but first,
    the array needs to be serialized and represented as a :code:`flwr.common.Array`.
    For example:

    >>> import numpy as np
    >>> from flwr.common import Array
    >>> # Creating an ArrayRecord would look like this
    >>> arr_np = np.random.randn(3, 3)
    >>>
    >>> # You can use the built-in tool to serialize the array
    >>> arr = Array(arr_np)
    >>>
    >>> # Finally, create the record
    >>> arr_record = ArrayRecord({"my_array": arr})
    >>>
    >>> # Adding it to the RecordDict would look like this
    >>> my_records["my_parameters"] = arr_record

    For additional examples on how to construct each of the records types shown
    above, please refer to the documentation for :code:`ConfigRecord`,
    :code:`MetricRecord` and :code:`ArrayRecord`.
    """

    def __init__(self, records: dict[str, RecordType] | None = None) -> None:
        super().__init__(_check_key, _check_value)
        if records is not None:
            # Go through ``__setitem__`` so every entry is type-checked.
            for key, record in records.items():
                self[key] = record

    def _synced_view(self, record_type: type[T]) -> TypedDict[str, T]:
        """Build a synced view holding only the records of ``record_type``.

        A new ``_SyncedDict`` is created on every call and filled with the
        entries of this ``RecordDict`` that are instances of ``record_type``.
        """
        synced_dict = _SyncedDict(self, record_type)
        for key, record in self.items():
            if isinstance(record, record_type):
                synced_dict[key] = record
        return synced_dict

    @property
    def array_records(self) -> TypedDict[str, ArrayRecord]:
        """Dictionary holding only ArrayRecord instances."""
        return self._synced_view(ArrayRecord)

    @property
    def metric_records(self) -> TypedDict[str, MetricRecord]:
        """Dictionary holding only MetricRecord instances."""
        return self._synced_view(MetricRecord)

    @property
    def config_records(self) -> TypedDict[str, ConfigRecord]:
        """Dictionary holding only ConfigRecord instances."""
        return self._synced_view(ConfigRecord)

    def __repr__(self) -> str:
        """Return a string representation of this instance."""
        flds = ("array_records", "metric_records", "config_records")
        fld_views = [f"{fld}={dict(getattr(self, fld))!r}" for fld in flds]
        view = indent(",\n".join(fld_views), " ")
        return f"{self.__class__.__qualname__}(\n{view}\n)"

    def __setitem__(self, key: str, value: RecordType) -> None:
        """Set the given key to the given value after type checking.

        A warning is logged when an existing record is replaced by a value
        that is not an instance of the previous record's type.
        """
        # Capture the previous value before it is overwritten.
        original_value = self.get(key, None)
        super().__setitem__(key, value)
        if original_value is not None and not isinstance(value, type(original_value)):
            log(
                WARN,
                "Key '%s' was overwritten: record of type `%s` replaced with type `%s`",
                key,
                type(original_value).__name__,
                type(value).__name__,
            )
|
198
|
+
|
199
|
+
|
200
|
+
class RecordSet(RecordDict):
    """Deprecated alias of ``RecordDict``; use ``RecordDict`` instead.

    The class is kept only so that legacy code written against ``RecordSet``
    keeps working. It has been renamed to ``RecordDict`` and will be removed
    in a future release.

    .. warning::
        ``RecordSet`` is deprecated and will be removed in a future release.
        Use ``RecordDict`` instead.

    Examples
    --------
    Legacy (deprecated) usage::

        from flwr.common import RecordSet

        my_content = RecordSet()

    Updated usage::

        from flwr.common import RecordDict

        my_content = RecordDict()
    """

    # Class-level flags so that each deprecation warning is logged only once.
    _warning_logged = False
    _warning_logged_params = False
    _warning_logged_metrics = False
    _warning_logged_configs = False

    def __init__(self, records: dict[str, RecordType] | None = None) -> None:
        already_warned = RecordSet._warning_logged
        RecordSet._warning_logged = True
        if not already_warned:
            log(
                WARN,
                "The `RecordSet` class has been renamed to `RecordDict`. "
                "Support for `RecordSet` will be removed in a future release. "
                "Please update your code accordingly.",
            )
        super().__init__(records)

    @property
    def parameters_records(self) -> TypedDict[str, ArrayRecord]:
        """Deprecated property.

        Use ``array_records`` instead.
        """
        already_warned = RecordSet._warning_logged_params
        RecordSet._warning_logged_params = True
        if not already_warned:
            log(
                WARN,
                "`RecordSet.parameters_records` has been deprecated "
                "and will be removed in a future release. Please use "
                "`RecordDict.array_records` instead.",
            )
        return self.array_records

    @property
    def metrics_records(self) -> TypedDict[str, MetricRecord]:
        """Deprecated property.

        Use ``metric_records`` instead.
        """
        already_warned = RecordSet._warning_logged_metrics
        RecordSet._warning_logged_metrics = True
        if not already_warned:
            log(
                WARN,
                "`RecordSet.metrics_records` has been deprecated "
                "and will be removed in a future release. Please use "
                "`RecordDict.metric_records` instead.",
            )
        return self.metric_records

    @property
    def configs_records(self) -> TypedDict[str, ConfigRecord]:
        """Deprecated property.

        Use ``config_records`` instead.
        """
        already_warned = RecordSet._warning_logged_configs
        RecordSet._warning_logged_configs = True
        if not already_warned:
            log(
                WARN,
                "`RecordSet.configs_records` has been deprecated "
                "and will be removed in a future release. Please use "
                "`RecordDict.config_records` instead.",
            )
        return self.config_records
|
@@ -0,0 +1,410 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""RecordDict utilities."""
|
16
|
+
|
17
|
+
|
18
|
+
from collections import OrderedDict
|
19
|
+
from collections.abc import Mapping
|
20
|
+
from typing import Union, cast, get_args
|
21
|
+
|
22
|
+
from . import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
23
|
+
from .typing import (
|
24
|
+
Code,
|
25
|
+
ConfigRecordValues,
|
26
|
+
EvaluateIns,
|
27
|
+
EvaluateRes,
|
28
|
+
FitIns,
|
29
|
+
FitRes,
|
30
|
+
GetParametersIns,
|
31
|
+
GetParametersRes,
|
32
|
+
GetPropertiesIns,
|
33
|
+
GetPropertiesRes,
|
34
|
+
MetricRecordValues,
|
35
|
+
Parameters,
|
36
|
+
Scalar,
|
37
|
+
Status,
|
38
|
+
)
|
39
|
+
|
40
|
+
# Key of the placeholder Array stored when a Parameters object without tensors
# is converted to an ArrayRecord; its data is not added to the tensors when
# converting back.
EMPTY_TENSOR_KEY = "_empty"
|
41
|
+
|
42
|
+
|
43
|
+
def arrayrecord_to_parameters(record: ArrayRecord, keep_input: bool) -> Parameters:
    """Convert an ArrayRecord to legacy Parameters.

    Warnings
    --------
    Because `Array`s in `ArrayRecord` encode more information of the
    array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it
    might not be possible to reconstruct such data structures from `Parameters` objects
    alone. Additional information or metadata must be provided from elsewhere.

    Parameters
    ----------
    record : ArrayRecord
        The record to be converted into Parameters.
    keep_input : bool
        If ``True``, ``record`` is left untouched. If ``False``, each entry is
        deleted from ``record`` right after it has been processed.

    Returns
    -------
    parameters : Parameters
        The parameters in the legacy format Parameters.
    """
    parameters = Parameters(tensors=[], tensor_type="")

    # Iterate over a snapshot of the keys: entries may be deleted in the loop.
    for key in list(record.keys()):
        array = record[key]

        # The placeholder entry only carries the tensor type, not a tensor.
        if key != EMPTY_TENSOR_KEY:
            parameters.tensors.append(array.data)

        if not parameters.tensor_type:
            # Setting from first array in record. Recall the warning in the docstrings
            # of this function.
            parameters.tensor_type = array.stype

        if not keep_input:
            del record[key]

    return parameters
|
81
|
+
|
82
|
+
|
83
|
+
def parameters_to_arrayrecord(parameters: Parameters, keep_input: bool) -> ArrayRecord:
    """Convert legacy Parameters into a single ArrayRecord.

    Because there is no concept of names in the legacy Parameters, arbitrary keys will
    be used when constructing the ArrayRecord. Similarly, the shape and data type
    won't be recorded in the Array objects.

    Parameters
    ----------
    parameters : Parameters
        Parameters object to be represented as an ArrayRecord.
    keep_input : bool
        If ``True``, the tensors stay in the input Parameters object. If
        ``False``, each tensor is removed from ``parameters.tensors`` as soon as
        it is taken for the record. The flag is also forwarded to the
        ``ArrayRecord`` constructor.

    Returns
    -------
    ArrayRecord
        The ArrayRecord containing the provided parameters.
    """
    tensor_type = parameters.tensor_type

    # Fix the count up front: the list shrinks while looping if tensors are popped.
    num_arrays = len(parameters.tensors)
    ordered_dict = OrderedDict()
    for idx in range(num_arrays):
        if keep_input:
            tensor = parameters.tensors[idx]
        else:
            # Always take the head, so tensors are consumed in their original order.
            tensor = parameters.tensors.pop(0)
        # Keys are the positional index; dtype and shape are unknown here.
        ordered_dict[str(idx)] = Array(
            data=tensor, dtype="", stype=tensor_type, shape=[]
        )

    if num_arrays == 0:
        # Keep the tensor type by storing a placeholder entry without data.
        ordered_dict[EMPTY_TENSOR_KEY] = Array(
            data=b"", dtype="", stype=tensor_type, shape=[]
        )
    return ArrayRecord(ordered_dict, keep_input=keep_input)
|
122
|
+
|
123
|
+
|
124
|
+
def _check_mapping_from_recordscalartype_to_scalar(
    record_data: Mapping[str, Union[ConfigRecordValues, MetricRecordValues]]
) -> dict[str, Scalar]:
    """Check mapping `common.*RecordValues` into `common.Scalar` is possible.

    Parameters
    ----------
    record_data : Mapping[str, Union[ConfigRecordValues, MetricRecordValues]]
        The mapping whose values are checked.

    Returns
    -------
    dict[str, Scalar]
        The same ``record_data`` object, typed as a mapping to ``Scalar``.

    Raises
    ------
    TypeError
        If a value is not an instance of one of the ``Scalar`` member types.
    """
    # The member types of the `Scalar` union do not change between iterations.
    scalar_types = get_args(Scalar)
    for value in record_data.values():
        if not isinstance(value, scalar_types):
            raise TypeError(
                "There is not a 1:1 mapping between `common.Scalar` types and those "
                "supported in `common.ConfigRecordValues` or "
                "`common.MetricRecordValues`. Consider casting your values to a type "
                "supported by the `common.RecordDict` infrastructure. "
                f"You used type: {type(value)}"
            )
    return cast(dict[str, Scalar], record_data)
|
138
|
+
|
139
|
+
|
140
|
+
def _recorddict_to_fit_or_evaluate_ins_components(
    recorddict: RecordDict,
    ins_str: str,
    keep_input: bool,
) -> tuple[Parameters, dict[str, Scalar]]:
    """Extract the parameters and config of a Fit/Evaluate Ins from a RecordDict."""
    params_key = f"{ins_str}.parameters"
    config_key = f"{ins_str}.config"

    # Array record -> legacy Parameters
    parameters = arrayrecord_to_parameters(
        recorddict.array_records[params_key], keep_input=keep_input
    )

    # Config record -> dict of Scalars
    config_dict = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[config_key]
    )

    return parameters, config_dict
|
157
|
+
|
158
|
+
|
159
|
+
def _fit_or_evaluate_ins_to_recorddict(
    ins: Union[FitIns, EvaluateIns], keep_input: bool
) -> RecordDict:
    """Build a RecordDict holding the parameters and config of ``ins``."""
    recorddict = RecordDict()

    # The key prefix depends on which kind of instruction was passed in.
    if isinstance(ins, FitIns):
        prefix = "fitins"
    else:
        prefix = "evaluateins"

    recorddict.array_records[f"{prefix}.parameters"] = parameters_to_arrayrecord(
        ins.parameters, keep_input
    )
    recorddict.config_records[f"{prefix}.config"] = ConfigRecord(
        ins.config  # type: ignore
    )

    return recorddict
|
173
|
+
|
174
|
+
|
175
|
+
def _embed_status_into_recorddict(
    res_str: str, status: Status, recorddict: RecordDict
) -> RecordDict:
    """Store ``status`` in ``recorddict`` under ``"<res_str>.status"``."""
    # A `ConfigRecord` is used because the `status.message` is a string
    # and `str` values aren't supported in `MetricRecords`
    payload: dict[str, ConfigRecordValues] = {
        "code": int(status.code.value),
        "message": status.message,
    }
    status_key = f"{res_str}.status"
    recorddict.config_records[status_key] = ConfigRecord(payload)
    return recorddict
|
186
|
+
|
187
|
+
|
188
|
+
def _extract_status_from_recorddict(res_str: str, recorddict: RecordDict) -> Status:
    """Rebuild the Status stored under ``"<res_str>.status"`` in ``recorddict``."""
    status_record = recorddict.config_records[f"{res_str}.status"]
    status_code = Code(cast(int, status_record["code"]))
    status_message = str(status_record["message"])
    return Status(code=status_code, message=status_message)
|
192
|
+
|
193
|
+
|
194
|
+
def recorddict_to_fitins(recorddict: RecordDict, keep_input: bool) -> FitIns:
    """Build a FitIns from the ``fitins.*`` entries of a RecordDict."""
    components = _recorddict_to_fit_or_evaluate_ins_components(
        recorddict, ins_str="fitins", keep_input=keep_input
    )
    params, cfg = components
    return FitIns(parameters=params, config=cfg)
|
203
|
+
|
204
|
+
|
205
|
+
def fitins_to_recorddict(fitins: FitIns, keep_input: bool) -> RecordDict:
    """Build a RecordDict from the given FitIns."""
    recorddict = _fit_or_evaluate_ins_to_recorddict(fitins, keep_input)
    return recorddict
|
208
|
+
|
209
|
+
|
210
|
+
def recorddict_to_fitres(recorddict: RecordDict, keep_input: bool) -> FitRes:
    """Build a FitRes from the ``fitres.*`` entries of a RecordDict."""
    prefix = "fitres"

    # Parameters
    array_record = recorddict.array_records[f"{prefix}.parameters"]
    parameters = arrayrecord_to_parameters(array_record, keep_input=keep_input)

    # Number of examples
    examples_record = recorddict.metric_records[f"{prefix}.num_examples"]
    num_examples = cast(int, examples_record["num_examples"])

    # Metrics are stored in a ConfigRecord under "<prefix>.metrics"
    metrics = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[f"{prefix}.metrics"]
    )

    # Status
    status = _extract_status_from_recorddict(prefix, recorddict)

    return FitRes(
        status=status,
        parameters=parameters,
        num_examples=num_examples,
        metrics=metrics,
    )
|
228
|
+
|
229
|
+
|
230
|
+
def fitres_to_recorddict(fitres: FitRes, keep_input: bool) -> RecordDict:
    """Build a RecordDict from the given FitRes."""
    prefix = "fitres"
    recorddict = RecordDict()

    # metrics
    metrics_record = ConfigRecord(fitres.metrics)  # type: ignore
    recorddict.config_records[f"{prefix}.metrics"] = metrics_record

    # num_examples
    examples_record = MetricRecord({"num_examples": fitres.num_examples})
    recorddict.metric_records[f"{prefix}.num_examples"] = examples_record

    # parameters
    array_record = parameters_to_arrayrecord(fitres.parameters, keep_input)
    recorddict.array_records[f"{prefix}.parameters"] = array_record

    # status
    return _embed_status_into_recorddict(prefix, fitres.status, recorddict)
|
251
|
+
|
252
|
+
|
253
|
+
def recorddict_to_evaluateins(recorddict: RecordDict, keep_input: bool) -> EvaluateIns:
    """Build an EvaluateIns from the ``evaluateins.*`` entries of a RecordDict."""
    components = _recorddict_to_fit_or_evaluate_ins_components(
        recorddict, ins_str="evaluateins", keep_input=keep_input
    )
    params, cfg = components
    return EvaluateIns(parameters=params, config=cfg)
|
262
|
+
|
263
|
+
|
264
|
+
def evaluateins_to_recorddict(evaluateins: EvaluateIns, keep_input: bool) -> RecordDict:
    """Build a RecordDict from the given EvaluateIns."""
    recorddict = _fit_or_evaluate_ins_to_recorddict(evaluateins, keep_input)
    return recorddict
|
267
|
+
|
268
|
+
|
269
|
+
def recorddict_to_evaluateres(recorddict: RecordDict) -> EvaluateRes:
    """Derive EvaluateRes from a RecordDict object.

    Reads the ``evaluateres.loss``, ``evaluateres.num_examples``,
    ``evaluateres.metrics`` and ``evaluateres.status`` entries of ``recorddict``.
    """
    res_str = "evaluateres"

    # The loss is a floating point value, so it is cast to `float` (not `int`).
    loss = cast(float, recorddict.metric_records[f"{res_str}.loss"]["loss"])

    num_examples = cast(
        int, recorddict.metric_records[f"{res_str}.num_examples"]["num_examples"]
    )
    config_record = recorddict.config_records[f"{res_str}.metrics"]

    metrics = _check_mapping_from_recordscalartype_to_scalar(config_record)
    status = _extract_status_from_recorddict(res_str, recorddict)

    return EvaluateRes(
        status=status, loss=loss, num_examples=num_examples, metrics=metrics
    )
|
287
|
+
|
288
|
+
|
289
|
+
def evaluateres_to_recorddict(evaluateres: EvaluateRes) -> RecordDict:
    """Build a RecordDict from the given EvaluateRes."""
    prefix = "evaluateres"
    recorddict = RecordDict()

    # loss
    loss_record = MetricRecord({"loss": evaluateres.loss})
    recorddict.metric_records[f"{prefix}.loss"] = loss_record

    # num_examples
    examples_record = MetricRecord({"num_examples": evaluateres.num_examples})
    recorddict.metric_records[f"{prefix}.num_examples"] = examples_record

    # metrics
    metrics_record = ConfigRecord(evaluateres.metrics)  # type: ignore
    recorddict.config_records[f"{prefix}.metrics"] = metrics_record

    # status
    return _embed_status_into_recorddict(prefix, evaluateres.status, recorddict)
|
315
|
+
|
316
|
+
|
317
|
+
def recorddict_to_getparametersins(recorddict: RecordDict) -> GetParametersIns:
    """Build a GetParametersIns from the ``getparametersins.config`` entry."""
    config = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records["getparametersins.config"]
    )
    return GetParametersIns(config=config)
|
324
|
+
|
325
|
+
|
326
|
+
def getparametersins_to_recorddict(getparameters_ins: GetParametersIns) -> RecordDict:
    """Build a RecordDict from the given GetParametersIns."""
    config_record = ConfigRecord(getparameters_ins.config)  # type: ignore

    recorddict = RecordDict()
    recorddict.config_records["getparametersins.config"] = config_record
    return recorddict
|
334
|
+
|
335
|
+
|
336
|
+
def getparametersres_to_recorddict(
    getparametersres: GetParametersRes, keep_input: bool
) -> RecordDict:
    """Build a RecordDict from the given GetParametersRes."""
    prefix = "getparametersres"
    recorddict = RecordDict()

    # parameters
    recorddict.array_records[f"{prefix}.parameters"] = parameters_to_arrayrecord(
        getparametersres.parameters, keep_input=keep_input
    )

    # status
    return _embed_status_into_recorddict(prefix, getparametersres.status, recorddict)
|
353
|
+
|
354
|
+
|
355
|
+
def recorddict_to_getparametersres(
    recorddict: RecordDict, keep_input: bool
) -> GetParametersRes:
    """Build a GetParametersRes from the ``getparametersres.*`` entries."""
    prefix = "getparametersres"

    array_record = recorddict.array_records[f"{prefix}.parameters"]
    parameters = arrayrecord_to_parameters(array_record, keep_input=keep_input)

    status = _extract_status_from_recorddict(prefix, recorddict)
    return GetParametersRes(status=status, parameters=parameters)
|
366
|
+
|
367
|
+
|
368
|
+
def recorddict_to_getpropertiesins(recorddict: RecordDict) -> GetPropertiesIns:
    """Build a GetPropertiesIns from the ``getpropertiesins.config`` entry."""
    config = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records["getpropertiesins.config"]
    )
    return GetPropertiesIns(config=config)
|
375
|
+
|
376
|
+
|
377
|
+
def getpropertiesins_to_recorddict(getpropertiesins: GetPropertiesIns) -> RecordDict:
    """Construct a RecordDict from a GetPropertiesIns object.

    The config of ``getpropertiesins`` is stored as a ConfigRecord under the
    key ``"getpropertiesins.config"``.
    """
    recorddict = RecordDict()
    recorddict.config_records["getpropertiesins.config"] = ConfigRecord(
        getpropertiesins.config,  # type: ignore
    )
    return recorddict
|
384
|
+
|
385
|
+
|
386
|
+
def recorddict_to_getpropertiesres(recorddict: RecordDict) -> GetPropertiesRes:
    """Build a GetPropertiesRes from the ``getpropertiesres.*`` entries."""
    prefix = "getpropertiesres"

    properties = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[f"{prefix}.properties"]
    )
    status = _extract_status_from_recorddict(prefix, recorddict=recorddict)

    return GetPropertiesRes(status=status, properties=properties)
|
396
|
+
|
397
|
+
|
398
|
+
def getpropertiesres_to_recorddict(getpropertiesres: GetPropertiesRes) -> RecordDict:
    """Build a RecordDict from the given GetPropertiesRes."""
    prefix = "getpropertiesres"
    recorddict = RecordDict()

    # properties
    properties_record = ConfigRecord(getpropertiesres.properties)  # type: ignore
    recorddict.config_records[f"{prefix}.properties"] = properties_record

    # status
    return _embed_status_into_recorddict(prefix, getpropertiesres.status, recorddict)
|