flwr 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +82 -57
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +10 -18
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +31 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +4 -4
- flwr/client/grpc_rere_client/connection.py +130 -60
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +173 -67
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +36 -7
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit_handlers.py +30 -0
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -183
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +19 -159
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +68 -186
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +18 -0
- flwr/server/superlink/ffs/__init__.py +2 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
- flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
- flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
- flwr/server/superlink/utils.py +44 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +7 -3
- flwr/superexec/exec_servicer.py +125 -23
- flwr/superexec/exec_user_auth_interceptor.py +37 -8
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
- /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
|
@@ -15,12 +15,21 @@
|
|
|
15
15
|
"""ConfigRecord."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
from logging import WARN
|
|
19
|
-
from typing import
|
|
21
|
+
from typing import cast, get_args
|
|
20
22
|
|
|
21
23
|
from flwr.common.typing import ConfigRecordValues, ConfigScalar
|
|
22
24
|
|
|
25
|
+
# pylint: disable=E0611
|
|
26
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
27
|
+
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
|
|
28
|
+
|
|
29
|
+
# pylint: enable=E0611
|
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
|
23
31
|
from ..logger import log
|
|
32
|
+
from ..serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
|
24
33
|
from .typeddict import TypedDict
|
|
25
34
|
|
|
26
35
|
|
|
@@ -59,7 +68,7 @@ def _check_value(value: ConfigRecordValues) -> None:
|
|
|
59
68
|
is_valid(value)
|
|
60
69
|
|
|
61
70
|
|
|
62
|
-
class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
71
|
+
class ConfigRecord(TypedDict[str, ConfigRecordValues], InflatableObject):
|
|
63
72
|
"""Config record.
|
|
64
73
|
|
|
65
74
|
A :code:`ConfigRecord` is a Python dictionary designed to ensure that
|
|
@@ -111,7 +120,7 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
|
111
120
|
|
|
112
121
|
def __init__(
|
|
113
122
|
self,
|
|
114
|
-
config_dict:
|
|
123
|
+
config_dict: dict[str, ConfigRecordValues] | None = None,
|
|
115
124
|
keep_input: bool = True,
|
|
116
125
|
) -> None:
|
|
117
126
|
|
|
@@ -164,6 +173,52 @@ class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
|
|
164
173
|
|
|
165
174
|
return num_bytes
|
|
166
175
|
|
|
176
|
+
def deflate(self) -> bytes:
|
|
177
|
+
"""Deflate object."""
|
|
178
|
+
protos = record_value_dict_to_proto(
|
|
179
|
+
self,
|
|
180
|
+
[bool, int, float, str, bytes],
|
|
181
|
+
ProtoConfigRecordValue,
|
|
182
|
+
)
|
|
183
|
+
obj_body = ProtoConfigRecord(
|
|
184
|
+
items=[ProtoConfigRecord.Item(key=k, value=v) for k, v in protos.items()]
|
|
185
|
+
).SerializeToString()
|
|
186
|
+
return add_header_to_object_body(object_body=obj_body, obj=self)
|
|
187
|
+
|
|
188
|
+
@classmethod
|
|
189
|
+
def inflate(
|
|
190
|
+
cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
|
|
191
|
+
) -> ConfigRecord:
|
|
192
|
+
"""Inflate a ConfigRecord from bytes.
|
|
193
|
+
|
|
194
|
+
Parameters
|
|
195
|
+
----------
|
|
196
|
+
object_content : bytes
|
|
197
|
+
The deflated object content of the ConfigRecord.
|
|
198
|
+
|
|
199
|
+
children : Optional[dict[str, InflatableObject]] (default: None)
|
|
200
|
+
Must be ``None``. ``ConfigRecord`` does not support child objects.
|
|
201
|
+
Providing any children will raise a ``ValueError``.
|
|
202
|
+
|
|
203
|
+
Returns
|
|
204
|
+
-------
|
|
205
|
+
ConfigRecord
|
|
206
|
+
The inflated ConfigRecord.
|
|
207
|
+
"""
|
|
208
|
+
if children:
|
|
209
|
+
raise ValueError("`ConfigRecord` objects do not have children.")
|
|
210
|
+
|
|
211
|
+
obj_body = get_object_body(object_content, cls)
|
|
212
|
+
config_record_proto = ProtoConfigRecord.FromString(obj_body)
|
|
213
|
+
protos = {item.key: item.value for item in config_record_proto.items}
|
|
214
|
+
return ConfigRecord(
|
|
215
|
+
config_dict=cast(
|
|
216
|
+
dict[str, ConfigRecordValues],
|
|
217
|
+
record_value_dict_from_proto(protos),
|
|
218
|
+
),
|
|
219
|
+
keep_input=False,
|
|
220
|
+
)
|
|
221
|
+
|
|
167
222
|
|
|
168
223
|
class ConfigsRecord(ConfigRecord):
|
|
169
224
|
"""Deprecated class ``ConfigsRecord``, use ``ConfigRecord`` instead.
|
|
@@ -195,7 +250,7 @@ class ConfigsRecord(ConfigRecord):
|
|
|
195
250
|
|
|
196
251
|
def __init__(
|
|
197
252
|
self,
|
|
198
|
-
config_dict:
|
|
253
|
+
config_dict: dict[str, ConfigRecordValues] | None = None,
|
|
199
254
|
keep_input: bool = True,
|
|
200
255
|
):
|
|
201
256
|
if not ConfigsRecord._warning_logged:
|
|
@@ -15,12 +15,21 @@
|
|
|
15
15
|
"""MetricRecord."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
from logging import WARN
|
|
19
|
-
from typing import
|
|
21
|
+
from typing import cast, get_args
|
|
20
22
|
|
|
21
23
|
from flwr.common.typing import MetricRecordValues, MetricScalar
|
|
22
24
|
|
|
25
|
+
# pylint: disable=E0611
|
|
26
|
+
from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
|
|
27
|
+
from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
|
|
28
|
+
|
|
29
|
+
# pylint: enable=E0611
|
|
30
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
|
23
31
|
from ..logger import log
|
|
32
|
+
from ..serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
|
|
24
33
|
from .typeddict import TypedDict
|
|
25
34
|
|
|
26
35
|
|
|
@@ -59,7 +68,7 @@ def _check_value(value: MetricRecordValues) -> None:
|
|
|
59
68
|
is_valid(value)
|
|
60
69
|
|
|
61
70
|
|
|
62
|
-
class MetricRecord(TypedDict[str, MetricRecordValues]):
|
|
71
|
+
class MetricRecord(TypedDict[str, MetricRecordValues], InflatableObject):
|
|
63
72
|
"""Metric record.
|
|
64
73
|
|
|
65
74
|
A :code:`MetricRecord` is a Python dictionary designed to ensure that
|
|
@@ -117,7 +126,7 @@ class MetricRecord(TypedDict[str, MetricRecordValues]):
|
|
|
117
126
|
|
|
118
127
|
def __init__(
|
|
119
128
|
self,
|
|
120
|
-
metric_dict:
|
|
129
|
+
metric_dict: dict[str, MetricRecordValues] | None = None,
|
|
121
130
|
keep_input: bool = True,
|
|
122
131
|
) -> None:
|
|
123
132
|
super().__init__(_check_key, _check_value)
|
|
@@ -143,6 +152,48 @@ class MetricRecord(TypedDict[str, MetricRecordValues]):
|
|
|
143
152
|
num_bytes += len(k)
|
|
144
153
|
return num_bytes
|
|
145
154
|
|
|
155
|
+
def deflate(self) -> bytes:
|
|
156
|
+
"""Deflate object."""
|
|
157
|
+
protos = record_value_dict_to_proto(self, [float, int], ProtoMetricRecordValue)
|
|
158
|
+
obj_body = ProtoMetricRecord(
|
|
159
|
+
items=[ProtoMetricRecord.Item(key=k, value=v) for k, v in protos.items()]
|
|
160
|
+
).SerializeToString()
|
|
161
|
+
return add_header_to_object_body(object_body=obj_body, obj=self)
|
|
162
|
+
|
|
163
|
+
@classmethod
|
|
164
|
+
def inflate(
|
|
165
|
+
cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
|
|
166
|
+
) -> MetricRecord:
|
|
167
|
+
"""Inflate a MetricRecord from bytes.
|
|
168
|
+
|
|
169
|
+
Parameters
|
|
170
|
+
----------
|
|
171
|
+
object_content : bytes
|
|
172
|
+
The deflated object content of the MetricRecord.
|
|
173
|
+
|
|
174
|
+
children : Optional[dict[str, InflatableObject]] (default: None)
|
|
175
|
+
Must be ``None``. ``MetricRecord`` does not support child objects.
|
|
176
|
+
Providing any children will raise a ``ValueError``.
|
|
177
|
+
|
|
178
|
+
Returns
|
|
179
|
+
-------
|
|
180
|
+
MetricRecord
|
|
181
|
+
The inflated MetricRecord.
|
|
182
|
+
"""
|
|
183
|
+
if children:
|
|
184
|
+
raise ValueError("`MetricRecord` objects do not have children.")
|
|
185
|
+
|
|
186
|
+
obj_body = get_object_body(object_content, cls)
|
|
187
|
+
metric_record_proto = ProtoMetricRecord.FromString(obj_body)
|
|
188
|
+
protos = {item.key: item.value for item in metric_record_proto.items}
|
|
189
|
+
return cls(
|
|
190
|
+
metric_dict=cast(
|
|
191
|
+
dict[str, MetricRecordValues],
|
|
192
|
+
record_value_dict_from_proto(protos),
|
|
193
|
+
),
|
|
194
|
+
keep_input=False,
|
|
195
|
+
)
|
|
196
|
+
|
|
146
197
|
|
|
147
198
|
class MetricsRecord(MetricRecord):
|
|
148
199
|
"""Deprecated class ``MetricsRecord``, use ``MetricRecord`` instead.
|
|
@@ -174,7 +225,7 @@ class MetricsRecord(MetricRecord):
|
|
|
174
225
|
|
|
175
226
|
def __init__(
|
|
176
227
|
self,
|
|
177
|
-
metric_dict:
|
|
228
|
+
metric_dict: dict[str, MetricRecordValues] | None = None,
|
|
178
229
|
keep_input: bool = True,
|
|
179
230
|
):
|
|
180
231
|
if not MetricsRecord._warning_logged:
|
flwr/common/record/recorddict.py
CHANGED
|
@@ -17,10 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
|
+
import json
|
|
20
21
|
from logging import WARN
|
|
21
22
|
from textwrap import indent
|
|
22
23
|
from typing import TypeVar, Union, cast
|
|
23
24
|
|
|
25
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
|
24
26
|
from ..logger import log
|
|
25
27
|
from .arrayrecord import ArrayRecord
|
|
26
28
|
from .configrecord import ConfigRecord
|
|
@@ -97,7 +99,7 @@ class _SyncedDict(TypedDict[str, T]):
|
|
|
97
99
|
)
|
|
98
100
|
|
|
99
101
|
|
|
100
|
-
class RecordDict(TypedDict[str, RecordType]):
|
|
102
|
+
class RecordDict(TypedDict[str, RecordType], InflatableObject):
|
|
101
103
|
"""RecordDict stores groups of arrays, metrics and configs.
|
|
102
104
|
|
|
103
105
|
A :class:`RecordDict` is the unified mechanism by which arrays,
|
|
@@ -286,6 +288,72 @@ class RecordDict(TypedDict[str, RecordType]):
|
|
|
286
288
|
)
|
|
287
289
|
return self.config_records
|
|
288
290
|
|
|
291
|
+
@property
|
|
292
|
+
def children(self) -> dict[str, InflatableObject]:
|
|
293
|
+
"""Return a dictionary of records with their Object IDs as keys."""
|
|
294
|
+
return {record.object_id: record for record in self.values()}
|
|
295
|
+
|
|
296
|
+
def deflate(self) -> bytes:
|
|
297
|
+
"""Deflate the RecordDict."""
|
|
298
|
+
# record_name: record_object_id mapping
|
|
299
|
+
record_refs: dict[str, str] = {}
|
|
300
|
+
|
|
301
|
+
for record_name, record in self.items():
|
|
302
|
+
record_refs[record_name] = record.object_id
|
|
303
|
+
|
|
304
|
+
# Serialize references dict
|
|
305
|
+
object_body = json.dumps(record_refs).encode("utf-8")
|
|
306
|
+
return add_header_to_object_body(object_body=object_body, obj=self)
|
|
307
|
+
|
|
308
|
+
@classmethod
|
|
309
|
+
def inflate(
|
|
310
|
+
cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
|
|
311
|
+
) -> RecordDict:
|
|
312
|
+
"""Inflate a RecordDict from bytes.
|
|
313
|
+
|
|
314
|
+
Parameters
|
|
315
|
+
----------
|
|
316
|
+
object_content : bytes
|
|
317
|
+
The deflated object content of the RecordDict.
|
|
318
|
+
children : Optional[dict[str, InflatableObject]] (default: None)
|
|
319
|
+
Dictionary of children InflatableObjects mapped to their Object IDs.
|
|
320
|
+
These children enable the full inflation of the RecordDict. Default is None.
|
|
321
|
+
|
|
322
|
+
Returns
|
|
323
|
+
-------
|
|
324
|
+
RecordDict
|
|
325
|
+
The inflated RecordDict.
|
|
326
|
+
"""
|
|
327
|
+
if children is None:
|
|
328
|
+
children = {}
|
|
329
|
+
|
|
330
|
+
# Inflate mapping of record_names (keys in the RecordDict) to Records' object IDs
|
|
331
|
+
obj_body = get_object_body(object_content, cls)
|
|
332
|
+
record_refs: dict[str, str] = json.loads(obj_body.decode(encoding="utf-8"))
|
|
333
|
+
|
|
334
|
+
unique_records = set(record_refs.values())
|
|
335
|
+
children_obj_ids = set(children.keys())
|
|
336
|
+
if unique_records != children_obj_ids:
|
|
337
|
+
raise ValueError(
|
|
338
|
+
"Unexpected set of `children`. "
|
|
339
|
+
f"Expected {unique_records} but got {children_obj_ids}."
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
# Ensure children are one of the *Record objects expected in a RecordDict
|
|
343
|
+
if not all(
|
|
344
|
+
isinstance(ch, (ArrayRecord, ConfigRecord, MetricRecord))
|
|
345
|
+
for ch in children.values()
|
|
346
|
+
):
|
|
347
|
+
raise ValueError(
|
|
348
|
+
"`Children` are expected to be of type `ArrayRecord`, "
|
|
349
|
+
"`ConfigRecord` or `MetricRecord`."
|
|
350
|
+
)
|
|
351
|
+
|
|
352
|
+
# Instantiate new RecordDict
|
|
353
|
+
return RecordDict(
|
|
354
|
+
{name: children[object_id] for name, object_id in record_refs.items()} # type: ignore
|
|
355
|
+
)
|
|
356
|
+
|
|
289
357
|
|
|
290
358
|
class RecordSet(RecordDict):
|
|
291
359
|
"""Deprecated class ``RecordSet``, use ``RecordDict`` instead.
|
flwr/common/recorddict_compat.py
CHANGED
|
@@ -111,12 +111,12 @@ def parameters_to_arrayrecord(parameters: Parameters, keep_input: bool) -> Array
|
|
|
111
111
|
else:
|
|
112
112
|
tensor = parameters.tensors.pop(0)
|
|
113
113
|
ordered_dict[str(idx)] = Array(
|
|
114
|
-
data=tensor, dtype="", stype=tensor_type, shape=
|
|
114
|
+
data=tensor, dtype="", stype=tensor_type, shape=()
|
|
115
115
|
)
|
|
116
116
|
|
|
117
117
|
if num_arrays == 0:
|
|
118
118
|
ordered_dict[EMPTY_TENSOR_KEY] = Array(
|
|
119
|
-
data=b"", dtype="", stype=tensor_type, shape=
|
|
119
|
+
data=b"", dtype="", stype=tensor_type, shape=()
|
|
120
120
|
)
|
|
121
121
|
return ArrayRecord(ordered_dict, keep_input=keep_input)
|
|
122
122
|
|
flwr/common/retry_invoker.py
CHANGED
|
@@ -25,10 +25,12 @@ from typing import Any, Callable, Optional, Union, cast
|
|
|
25
25
|
|
|
26
26
|
import grpc
|
|
27
27
|
|
|
28
|
+
from flwr.client.grpc_rere_client.grpc_adapter import GrpcAdapter
|
|
28
29
|
from flwr.common.constant import MAX_RETRY_DELAY
|
|
29
30
|
from flwr.common.logger import log
|
|
30
31
|
from flwr.common.typing import RunNotRunningException
|
|
31
32
|
from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub
|
|
33
|
+
from flwr.proto.fleet_pb2_grpc import FleetStub
|
|
32
34
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
|
|
33
35
|
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub
|
|
34
36
|
|
|
@@ -366,7 +368,9 @@ def _make_simple_grpc_retry_invoker() -> RetryInvoker:
|
|
|
366
368
|
|
|
367
369
|
|
|
368
370
|
def _wrap_stub(
|
|
369
|
-
stub: Union[
|
|
371
|
+
stub: Union[
|
|
372
|
+
ServerAppIoStub, ClientAppIoStub, SimulationIoStub, FleetStub, GrpcAdapter
|
|
373
|
+
],
|
|
370
374
|
retry_invoker: RetryInvoker,
|
|
371
375
|
) -> None:
|
|
372
376
|
"""Wrap a gRPC stub with a retry invoker."""
|