flwr-nightly 1.11.0.dev20240811__py3-none-any.whl → 1.11.0.dev20240821__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. (The original page linked to more details here; that link was not preserved in this copy.)
- flwr/cli/config_utils.py +2 -2
- flwr/cli/install.py +3 -1
- flwr/cli/run/run.py +15 -11
- flwr/client/app.py +134 -15
- flwr/client/clientapp/__init__.py +22 -0
- flwr/client/clientapp/app.py +233 -0
- flwr/client/clientapp/clientappio_servicer.py +244 -0
- flwr/client/clientapp/utils.py +108 -0
- flwr/client/grpc_adapter_client/connection.py +3 -1
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +15 -2
- flwr/client/node_state.py +17 -4
- flwr/client/rest_client/connection.py +21 -3
- flwr/client/supernode/app.py +37 -97
- flwr/common/__init__.py +4 -0
- flwr/common/config.py +31 -10
- flwr/common/record/configsrecord.py +49 -15
- flwr/common/record/metricsrecord.py +54 -14
- flwr/common/record/parametersrecord.py +84 -17
- flwr/common/record/recordset.py +80 -8
- flwr/common/record/typeddict.py +20 -58
- flwr/common/recordset_compat.py +6 -6
- flwr/common/serde.py +178 -1
- flwr/common/typing.py +17 -0
- flwr/proto/clientappio_pb2.py +45 -0
- flwr/proto/clientappio_pb2.pyi +132 -0
- flwr/proto/clientappio_pb2_grpc.py +135 -0
- flwr/proto/clientappio_pb2_grpc.pyi +53 -0
- flwr/proto/exec_pb2.py +16 -15
- flwr/proto/exec_pb2.pyi +7 -4
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/server/app.py +15 -0
- flwr/server/driver/grpc_driver.py +1 -0
- flwr/server/run_serverapp.py +18 -2
- flwr/server/server.py +3 -1
- flwr/server/superlink/driver/driver_grpc.py +3 -0
- flwr/server/superlink/driver/driver_servicer.py +32 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +107 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/ffs/ffs_factory.py +47 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +12 -4
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +8 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +16 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -2
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/state/in_memory_state.py +7 -5
- flwr/server/superlink/state/sqlite_state.py +17 -7
- flwr/server/superlink/state/state.py +4 -3
- flwr/server/workflow/default_workflows.py +3 -1
- flwr/simulation/run_simulation.py +5 -67
- flwr/superexec/app.py +3 -3
- flwr/superexec/deployment.py +8 -9
- flwr/superexec/exec_servicer.py +1 -1
- {flwr_nightly-1.11.0.dev20240811.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/METADATA +2 -2
- {flwr_nightly-1.11.0.dev20240811.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/RECORD +62 -46
- {flwr_nightly-1.11.0.dev20240811.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.11.0.dev20240811.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240811.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/WHEEL +0 -0
flwr/common/record/recordset.py
CHANGED
|
@@ -15,8 +15,10 @@
|
|
|
15
15
|
"""RecordSet."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
from dataclasses import dataclass
|
|
19
|
-
from typing import
|
|
21
|
+
from typing import cast
|
|
20
22
|
|
|
21
23
|
from .configsrecord import ConfigsRecord
|
|
22
24
|
from .metricsrecord import MetricsRecord
|
|
@@ -34,9 +36,9 @@ class RecordSetData:
|
|
|
34
36
|
|
|
35
37
|
def __init__(
|
|
36
38
|
self,
|
|
37
|
-
parameters_records:
|
|
38
|
-
metrics_records:
|
|
39
|
-
configs_records:
|
|
39
|
+
parameters_records: dict[str, ParametersRecord] | None = None,
|
|
40
|
+
metrics_records: dict[str, MetricsRecord] | None = None,
|
|
41
|
+
configs_records: dict[str, ConfigsRecord] | None = None,
|
|
40
42
|
) -> None:
|
|
41
43
|
self.parameters_records = TypedDict[str, ParametersRecord](
|
|
42
44
|
self._check_fn_str, self._check_fn_params
|
|
@@ -84,13 +86,83 @@ class RecordSetData:
|
|
|
84
86
|
|
|
85
87
|
|
|
86
88
|
class RecordSet:
|
|
87
|
-
"""RecordSet stores groups of parameters, metrics and configs.
|
|
89
|
+
"""RecordSet stores groups of parameters, metrics and configs.
|
|
90
|
+
|
|
91
|
+
A :code:`RecordSet` is the unified mechanism by which parameters,
|
|
92
|
+
metrics and configs can be either stored as part of a
|
|
93
|
+
`flwr.common.Context <flwr.common.Context.html>`_ in your apps
|
|
94
|
+
or communicated as part of a
|
|
95
|
+
`flwr.common.Message <flwr.common.Message.html>`_ between your apps.
|
|
96
|
+
|
|
97
|
+
Parameters
|
|
98
|
+
----------
|
|
99
|
+
parameters_records : Optional[Dict[str, ParametersRecord]]
|
|
100
|
+
A dictionary of :code:`ParametersRecords` that can be used to record
|
|
101
|
+
and communicate model parameters and high-dimensional arrays.
|
|
102
|
+
metrics_records : Optional[Dict[str, MetricsRecord]]
|
|
103
|
+
A dictionary of :code:`MetricsRecord` that can be used to record
|
|
104
|
+
and communicate scalar-valued metrics that are the result of performing
|
|
105
|
+
and action, for example, by a :code:`ClientApp`.
|
|
106
|
+
configs_records : Optional[Dict[str, ConfigsRecord]]
|
|
107
|
+
A dictionary of :code:`ConfigsRecord` that can be used to record
|
|
108
|
+
and communicate configuration values to an entity (e.g. to a
|
|
109
|
+
:code:`ClientApp`)
|
|
110
|
+
for it to adjust how an action is performed.
|
|
111
|
+
|
|
112
|
+
Examples
|
|
113
|
+
--------
|
|
114
|
+
A :code:`RecordSet` can hold three types of records, each designed
|
|
115
|
+
with an specific purpose. What is common to all of them is that they
|
|
116
|
+
are Python dictionaries designed to ensure that each key-value pair
|
|
117
|
+
adheres to specified data types.
|
|
118
|
+
|
|
119
|
+
Let's see an example.
|
|
120
|
+
|
|
121
|
+
>>> from flwr.common import RecordSet
|
|
122
|
+
>>> from flwr.common import ConfigsRecords, MetricsRecords, ParametersRecord
|
|
123
|
+
>>>
|
|
124
|
+
>>> # Let's begin with an empty record
|
|
125
|
+
>>> my_recordset = RecordSet()
|
|
126
|
+
>>>
|
|
127
|
+
>>> # We can create a ConfigsRecord
|
|
128
|
+
>>> c_record = ConfigsRecord({"lr": 0.1, "batch-size": 128})
|
|
129
|
+
>>> # Adding it to the record_set would look like this
|
|
130
|
+
>>> my_recordset.configs_records["my_config"] = c_record
|
|
131
|
+
>>>
|
|
132
|
+
>>> # We can create a MetricsRecord following a similar process
|
|
133
|
+
>>> m_record = MetricsRecord({"accuracy": 0.93, "losses": [0.23, 0.1]})
|
|
134
|
+
>>> # Adding it to the record_set would look like this
|
|
135
|
+
>>> my_recordset.metrics_records["my_metrics"] = m_record
|
|
136
|
+
|
|
137
|
+
Adding a :code:`ParametersRecord` follows the same steps as above but first,
|
|
138
|
+
the array needs to be serialized and represented as a :code:`flwr.common.Array`.
|
|
139
|
+
If the array is a :code:`NumPy` array, you can use the built-in utility function
|
|
140
|
+
`array_from_numpy <flwr.common.array_from_numpy.html>`_. It is often possible to
|
|
141
|
+
convert an array first to :code:`NumPy` and then use the aforementioned function.
|
|
142
|
+
|
|
143
|
+
>>> from flwr.common import array_from_numpy
|
|
144
|
+
>>> # Creating a ParametersRecord would look like this
|
|
145
|
+
>>> arr_np = np.random.randn(3, 3)
|
|
146
|
+
>>>
|
|
147
|
+
>>> # You can use the built-in tool to serialize the array
|
|
148
|
+
>>> arr = array_from_numpy(arr_np)
|
|
149
|
+
>>>
|
|
150
|
+
>>> # Finally, create the record
|
|
151
|
+
>>> p_record = ParametersRecord({"my_array": arr})
|
|
152
|
+
>>>
|
|
153
|
+
>>> # Adding it to the record_set would look like this
|
|
154
|
+
>>> my_recordset.configs_records["my_config"] = c_record
|
|
155
|
+
|
|
156
|
+
For additional examples on how to construct each of the records types shown
|
|
157
|
+
above, please refer to the documentation for :code:`ConfigsRecord`,
|
|
158
|
+
:code:`MetricsRecord` and :code:`ParametersRecord`.
|
|
159
|
+
"""
|
|
88
160
|
|
|
89
161
|
def __init__(
|
|
90
162
|
self,
|
|
91
|
-
parameters_records:
|
|
92
|
-
metrics_records:
|
|
93
|
-
configs_records:
|
|
163
|
+
parameters_records: dict[str, ParametersRecord] | None = None,
|
|
164
|
+
metrics_records: dict[str, MetricsRecord] | None = None,
|
|
165
|
+
configs_records: dict[str, ConfigsRecord] | None = None,
|
|
94
166
|
) -> None:
|
|
95
167
|
data = RecordSetData(
|
|
96
168
|
parameters_records=parameters_records,
|
flwr/common/record/typeddict.py
CHANGED
|
@@ -15,99 +15,61 @@
|
|
|
15
15
|
"""Typed dict base class for *Records."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import Callable, Dict, Generic, Iterator, MutableMapping, TypeVar, cast
|
|
19
19
|
|
|
20
20
|
K = TypeVar("K") # Key type
|
|
21
21
|
V = TypeVar("V") # Value type
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
class TypedDict(Generic[K, V]):
|
|
24
|
+
class TypedDict(MutableMapping[K, V], Generic[K, V]):
|
|
25
25
|
"""Typed dictionary."""
|
|
26
26
|
|
|
27
27
|
def __init__(
|
|
28
28
|
self, check_key_fn: Callable[[K], None], check_value_fn: Callable[[V], None]
|
|
29
29
|
):
|
|
30
|
-
self.
|
|
31
|
-
self.
|
|
32
|
-
self.
|
|
30
|
+
self.__dict__["_check_key_fn"] = check_key_fn
|
|
31
|
+
self.__dict__["_check_value_fn"] = check_value_fn
|
|
32
|
+
self.__dict__["_data"] = {}
|
|
33
33
|
|
|
34
34
|
def __setitem__(self, key: K, value: V) -> None:
|
|
35
35
|
"""Set the given key to the given value after type checking."""
|
|
36
36
|
# Check the types of key and value
|
|
37
|
-
self._check_key_fn(key)
|
|
38
|
-
self._check_value_fn(value)
|
|
37
|
+
cast(Callable[[K], None], self.__dict__["_check_key_fn"])(key)
|
|
38
|
+
cast(Callable[[V], None], self.__dict__["_check_value_fn"])(value)
|
|
39
|
+
|
|
39
40
|
# Set key-value pair
|
|
40
|
-
self._data[key] = value
|
|
41
|
+
cast(Dict[K, V], self.__dict__["_data"])[key] = value
|
|
41
42
|
|
|
42
43
|
def __delitem__(self, key: K) -> None:
|
|
43
44
|
"""Remove the item with the specified key."""
|
|
44
|
-
del self._data[key]
|
|
45
|
+
del cast(Dict[K, V], self.__dict__["_data"])[key]
|
|
45
46
|
|
|
46
47
|
def __getitem__(self, item: K) -> V:
|
|
47
48
|
"""Return the value for the specified key."""
|
|
48
|
-
return self._data[item]
|
|
49
|
+
return cast(Dict[K, V], self.__dict__["_data"])[item]
|
|
49
50
|
|
|
50
51
|
def __iter__(self) -> Iterator[K]:
|
|
51
52
|
"""Yield an iterator over the keys of the dictionary."""
|
|
52
|
-
return iter(self._data)
|
|
53
|
+
return iter(cast(Dict[K, V], self.__dict__["_data"]))
|
|
53
54
|
|
|
54
55
|
def __repr__(self) -> str:
|
|
55
56
|
"""Return a string representation of the dictionary."""
|
|
56
|
-
return self._data.__repr__()
|
|
57
|
+
return cast(Dict[K, V], self.__dict__["_data"]).__repr__()
|
|
57
58
|
|
|
58
59
|
def __len__(self) -> int:
|
|
59
60
|
"""Return the number of items in the dictionary."""
|
|
60
|
-
return len(self._data)
|
|
61
|
+
return len(cast(Dict[K, V], self.__dict__["_data"]))
|
|
61
62
|
|
|
62
|
-
def __contains__(self, key:
|
|
63
|
+
def __contains__(self, key: object) -> bool:
|
|
63
64
|
"""Check if the dictionary contains the specified key."""
|
|
64
|
-
return key in self._data
|
|
65
|
+
return key in cast(Dict[K, V], self.__dict__["_data"])
|
|
65
66
|
|
|
66
67
|
def __eq__(self, other: object) -> bool:
|
|
67
68
|
"""Compare this instance to another dictionary or TypedDict."""
|
|
69
|
+
data = cast(Dict[K, V], self.__dict__["_data"])
|
|
68
70
|
if isinstance(other, TypedDict):
|
|
69
|
-
|
|
71
|
+
other_data = cast(Dict[K, V], other.__dict__["_data"])
|
|
72
|
+
return data == other_data
|
|
70
73
|
if isinstance(other, dict):
|
|
71
|
-
return
|
|
74
|
+
return data == other
|
|
72
75
|
return NotImplemented
|
|
73
|
-
|
|
74
|
-
def items(self) -> Iterator[Tuple[K, V]]:
|
|
75
|
-
"""R.items() -> a set-like object providing a view on R's items."""
|
|
76
|
-
return cast(Iterator[Tuple[K, V]], self._data.items())
|
|
77
|
-
|
|
78
|
-
def keys(self) -> Iterator[K]:
|
|
79
|
-
"""R.keys() -> a set-like object providing a view on R's keys."""
|
|
80
|
-
return cast(Iterator[K], self._data.keys())
|
|
81
|
-
|
|
82
|
-
def values(self) -> Iterator[V]:
|
|
83
|
-
"""R.values() -> an object providing a view on R's values."""
|
|
84
|
-
return cast(Iterator[V], self._data.values())
|
|
85
|
-
|
|
86
|
-
def update(self, *args: Any, **kwargs: Any) -> None:
|
|
87
|
-
"""R.update([E, ]**F) -> None.
|
|
88
|
-
|
|
89
|
-
Update R from dict/iterable E and F.
|
|
90
|
-
"""
|
|
91
|
-
for key, value in dict(*args, **kwargs).items():
|
|
92
|
-
self[key] = value
|
|
93
|
-
|
|
94
|
-
def pop(self, key: K) -> V:
|
|
95
|
-
"""R.pop(k[,d]) -> v, remove specified key and return the corresponding value.
|
|
96
|
-
|
|
97
|
-
If key is not found, d is returned if given, otherwise KeyError is raised.
|
|
98
|
-
"""
|
|
99
|
-
return self._data.pop(key)
|
|
100
|
-
|
|
101
|
-
def get(self, key: K, default: V) -> V:
|
|
102
|
-
"""R.get(k[,d]) -> R[k] if k in R, else d.
|
|
103
|
-
|
|
104
|
-
d defaults to None.
|
|
105
|
-
"""
|
|
106
|
-
return self._data.get(key, default)
|
|
107
|
-
|
|
108
|
-
def clear(self) -> None:
|
|
109
|
-
"""R.clear() -> None.
|
|
110
|
-
|
|
111
|
-
Remove all items from R.
|
|
112
|
-
"""
|
|
113
|
-
self._data.clear()
|
flwr/common/recordset_compat.py
CHANGED
|
@@ -145,7 +145,7 @@ def _recordset_to_fit_or_evaluate_ins_components(
|
|
|
145
145
|
# get config dict
|
|
146
146
|
config_record = recordset.configs_records[f"{ins_str}.config"]
|
|
147
147
|
# pylint: disable-next=protected-access
|
|
148
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
148
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
149
149
|
|
|
150
150
|
return parameters, config_dict
|
|
151
151
|
|
|
@@ -213,7 +213,7 @@ def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes:
|
|
|
213
213
|
)
|
|
214
214
|
configs_record = recordset.configs_records[f"{ins_str}.metrics"]
|
|
215
215
|
# pylint: disable-next=protected-access
|
|
216
|
-
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record
|
|
216
|
+
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record)
|
|
217
217
|
status = _extract_status_from_recordset(ins_str, recordset)
|
|
218
218
|
|
|
219
219
|
return FitRes(
|
|
@@ -274,7 +274,7 @@ def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes:
|
|
|
274
274
|
configs_record = recordset.configs_records[f"{ins_str}.metrics"]
|
|
275
275
|
|
|
276
276
|
# pylint: disable-next=protected-access
|
|
277
|
-
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record
|
|
277
|
+
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record)
|
|
278
278
|
status = _extract_status_from_recordset(ins_str, recordset)
|
|
279
279
|
|
|
280
280
|
return EvaluateRes(
|
|
@@ -314,7 +314,7 @@ def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns:
|
|
|
314
314
|
"""Derive GetParametersIns from a RecordSet object."""
|
|
315
315
|
config_record = recordset.configs_records["getparametersins.config"]
|
|
316
316
|
# pylint: disable-next=protected-access
|
|
317
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
317
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
318
318
|
|
|
319
319
|
return GetParametersIns(config=config_dict)
|
|
320
320
|
|
|
@@ -365,7 +365,7 @@ def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns:
|
|
|
365
365
|
"""Derive GetPropertiesIns from a RecordSet object."""
|
|
366
366
|
config_record = recordset.configs_records["getpropertiesins.config"]
|
|
367
367
|
# pylint: disable-next=protected-access
|
|
368
|
-
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
368
|
+
config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
369
369
|
|
|
370
370
|
return GetPropertiesIns(config=config_dict)
|
|
371
371
|
|
|
@@ -384,7 +384,7 @@ def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes:
|
|
|
384
384
|
res_str = "getpropertiesres"
|
|
385
385
|
config_record = recordset.configs_records[f"{res_str}.properties"]
|
|
386
386
|
# pylint: disable-next=protected-access
|
|
387
|
-
properties = _check_mapping_from_recordscalartype_to_scalar(config_record
|
|
387
|
+
properties = _check_mapping_from_recordscalartype_to_scalar(config_record)
|
|
388
388
|
|
|
389
389
|
status = _extract_status_from_recordset(res_str, recordset=recordset)
|
|
390
390
|
|
flwr/common/serde.py
CHANGED
|
@@ -20,7 +20,12 @@ from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar,
|
|
|
20
20
|
from google.protobuf.message import Message as GrpcMessage
|
|
21
21
|
|
|
22
22
|
# pylint: disable=E0611
|
|
23
|
+
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
|
23
24
|
from flwr.proto.error_pb2 import Error as ProtoError
|
|
25
|
+
from flwr.proto.fab_pb2 import Fab as ProtoFab
|
|
26
|
+
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
27
|
+
from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
28
|
+
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
|
24
29
|
from flwr.proto.node_pb2 import Node
|
|
25
30
|
from flwr.proto.recordset_pb2 import Array as ProtoArray
|
|
26
31
|
from flwr.proto.recordset_pb2 import BoolList, BytesList
|
|
@@ -32,6 +37,7 @@ from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordVal
|
|
|
32
37
|
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
|
|
33
38
|
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
34
39
|
from flwr.proto.recordset_pb2 import Sint64List, StringList
|
|
40
|
+
from flwr.proto.run_pb2 import Run as ProtoRun
|
|
35
41
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
36
42
|
from flwr.proto.transport_pb2 import (
|
|
37
43
|
ClientMessage,
|
|
@@ -44,7 +50,15 @@ from flwr.proto.transport_pb2 import (
|
|
|
44
50
|
)
|
|
45
51
|
|
|
46
52
|
# pylint: enable=E0611
|
|
47
|
-
from . import
|
|
53
|
+
from . import (
|
|
54
|
+
Array,
|
|
55
|
+
ConfigsRecord,
|
|
56
|
+
Context,
|
|
57
|
+
MetricsRecord,
|
|
58
|
+
ParametersRecord,
|
|
59
|
+
RecordSet,
|
|
60
|
+
typing,
|
|
61
|
+
)
|
|
48
62
|
from .message import Error, Message, Metadata
|
|
49
63
|
from .record.typeddict import TypedDict
|
|
50
64
|
|
|
@@ -673,6 +687,19 @@ def message_from_taskres(taskres: TaskRes) -> Message:
|
|
|
673
687
|
return message
|
|
674
688
|
|
|
675
689
|
|
|
690
|
+
# === FAB ===
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def fab_to_proto(fab: typing.Fab) -> ProtoFab:
|
|
694
|
+
"""Create a proto Fab object from a Python Fab."""
|
|
695
|
+
return ProtoFab(hash_str=fab.hash_str, content=fab.content)
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def fab_from_proto(fab: ProtoFab) -> typing.Fab:
|
|
699
|
+
"""Create a Python Fab object from a proto Fab."""
|
|
700
|
+
return typing.Fab(fab.hash_str, fab.content)
|
|
701
|
+
|
|
702
|
+
|
|
676
703
|
# === User configs ===
|
|
677
704
|
|
|
678
705
|
|
|
@@ -716,3 +743,153 @@ def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
|
|
|
716
743
|
scalar_field = scalar_msg.WhichOneof("scalar")
|
|
717
744
|
scalar = getattr(scalar_msg, cast(str, scalar_field))
|
|
718
745
|
return cast(typing.UserConfigValue, scalar)
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
# === Metadata messages ===
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
|
|
752
|
+
"""Serialize `Metadata` to ProtoBuf."""
|
|
753
|
+
proto = ProtoMetadata( # pylint: disable=E1101
|
|
754
|
+
run_id=metadata.run_id,
|
|
755
|
+
message_id=metadata.message_id,
|
|
756
|
+
src_node_id=metadata.src_node_id,
|
|
757
|
+
dst_node_id=metadata.dst_node_id,
|
|
758
|
+
reply_to_message=metadata.reply_to_message,
|
|
759
|
+
group_id=metadata.group_id,
|
|
760
|
+
ttl=metadata.ttl,
|
|
761
|
+
message_type=metadata.message_type,
|
|
762
|
+
created_at=metadata.created_at,
|
|
763
|
+
)
|
|
764
|
+
return proto
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
|
|
768
|
+
"""Deserialize `Metadata` from ProtoBuf."""
|
|
769
|
+
metadata = Metadata(
|
|
770
|
+
run_id=metadata_proto.run_id,
|
|
771
|
+
message_id=metadata_proto.message_id,
|
|
772
|
+
src_node_id=metadata_proto.src_node_id,
|
|
773
|
+
dst_node_id=metadata_proto.dst_node_id,
|
|
774
|
+
reply_to_message=metadata_proto.reply_to_message,
|
|
775
|
+
group_id=metadata_proto.group_id,
|
|
776
|
+
ttl=metadata_proto.ttl,
|
|
777
|
+
message_type=metadata_proto.message_type,
|
|
778
|
+
)
|
|
779
|
+
return metadata
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
# === Message messages ===
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
def message_to_proto(message: Message) -> ProtoMessage:
|
|
786
|
+
"""Serialize `Message` to ProtoBuf."""
|
|
787
|
+
proto = ProtoMessage(
|
|
788
|
+
metadata=metadata_to_proto(message.metadata),
|
|
789
|
+
content=(
|
|
790
|
+
recordset_to_proto(message.content) if message.has_content() else None
|
|
791
|
+
),
|
|
792
|
+
error=error_to_proto(message.error) if message.has_error() else None,
|
|
793
|
+
)
|
|
794
|
+
return proto
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
def message_from_proto(message_proto: ProtoMessage) -> Message:
|
|
798
|
+
"""Deserialize `Message` from ProtoBuf."""
|
|
799
|
+
created_at = message_proto.metadata.created_at
|
|
800
|
+
message = Message(
|
|
801
|
+
metadata=metadata_from_proto(message_proto.metadata),
|
|
802
|
+
content=(
|
|
803
|
+
recordset_from_proto(message_proto.content)
|
|
804
|
+
if message_proto.HasField("content")
|
|
805
|
+
else None
|
|
806
|
+
),
|
|
807
|
+
error=(
|
|
808
|
+
error_from_proto(message_proto.error)
|
|
809
|
+
if message_proto.HasField("error")
|
|
810
|
+
else None
|
|
811
|
+
),
|
|
812
|
+
)
|
|
813
|
+
# `.created_at` is set upon Message object construction
|
|
814
|
+
# we need to manually set it to the original value
|
|
815
|
+
message.metadata.created_at = created_at
|
|
816
|
+
return message
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
# === Context messages ===
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
def context_to_proto(context: Context) -> ProtoContext:
|
|
823
|
+
"""Serialize `Context` to ProtoBuf."""
|
|
824
|
+
proto = ProtoContext(
|
|
825
|
+
node_id=context.node_id,
|
|
826
|
+
node_config=user_config_to_proto(context.node_config),
|
|
827
|
+
state=recordset_to_proto(context.state),
|
|
828
|
+
run_config=user_config_to_proto(context.run_config),
|
|
829
|
+
)
|
|
830
|
+
return proto
|
|
831
|
+
|
|
832
|
+
|
|
833
|
+
def context_from_proto(context_proto: ProtoContext) -> Context:
|
|
834
|
+
"""Deserialize `Context` from ProtoBuf."""
|
|
835
|
+
context = Context(
|
|
836
|
+
node_id=context_proto.node_id,
|
|
837
|
+
node_config=user_config_from_proto(context_proto.node_config),
|
|
838
|
+
state=recordset_from_proto(context_proto.state),
|
|
839
|
+
run_config=user_config_from_proto(context_proto.run_config),
|
|
840
|
+
)
|
|
841
|
+
return context
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
# === Run messages ===
|
|
845
|
+
|
|
846
|
+
|
|
847
|
+
def run_to_proto(run: typing.Run) -> ProtoRun:
|
|
848
|
+
"""Serialize `Run` to ProtoBuf."""
|
|
849
|
+
proto = ProtoRun(
|
|
850
|
+
run_id=run.run_id,
|
|
851
|
+
fab_id=run.fab_id,
|
|
852
|
+
fab_version=run.fab_version,
|
|
853
|
+
fab_hash=run.fab_hash,
|
|
854
|
+
override_config=user_config_to_proto(run.override_config),
|
|
855
|
+
)
|
|
856
|
+
return proto
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
def run_from_proto(run_proto: ProtoRun) -> typing.Run:
|
|
860
|
+
"""Deserialize `Run` from ProtoBuf."""
|
|
861
|
+
run = typing.Run(
|
|
862
|
+
run_id=run_proto.run_id,
|
|
863
|
+
fab_id=run_proto.fab_id,
|
|
864
|
+
fab_version=run_proto.fab_version,
|
|
865
|
+
fab_hash=run_proto.fab_hash,
|
|
866
|
+
override_config=user_config_from_proto(run_proto.override_config),
|
|
867
|
+
)
|
|
868
|
+
return run
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
# === ClientApp status messages ===
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
def clientappstatus_to_proto(
|
|
875
|
+
status: typing.ClientAppOutputStatus,
|
|
876
|
+
) -> ClientAppOutputStatus:
|
|
877
|
+
"""Serialize `ClientAppOutputStatus` to ProtoBuf."""
|
|
878
|
+
code = ClientAppOutputCode.SUCCESS
|
|
879
|
+
if status.code == typing.ClientAppOutputCode.DEADLINE_EXCEEDED:
|
|
880
|
+
code = ClientAppOutputCode.DEADLINE_EXCEEDED
|
|
881
|
+
if status.code == typing.ClientAppOutputCode.UNKNOWN_ERROR:
|
|
882
|
+
code = ClientAppOutputCode.UNKNOWN_ERROR
|
|
883
|
+
return ClientAppOutputStatus(code=code, message=status.message)
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
def clientappstatus_from_proto(
|
|
887
|
+
msg: ClientAppOutputStatus,
|
|
888
|
+
) -> typing.ClientAppOutputStatus:
|
|
889
|
+
"""Deserialize `ClientAppOutputStatus` from ProtoBuf."""
|
|
890
|
+
code = typing.ClientAppOutputCode.SUCCESS
|
|
891
|
+
if msg.code == ClientAppOutputCode.DEADLINE_EXCEEDED:
|
|
892
|
+
code = typing.ClientAppOutputCode.DEADLINE_EXCEEDED
|
|
893
|
+
if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
|
|
894
|
+
code = typing.ClientAppOutputCode.UNKNOWN_ERROR
|
|
895
|
+
return typing.ClientAppOutputStatus(code=code, message=msg.message)
|
flwr/common/typing.py
CHANGED
|
@@ -83,6 +83,22 @@ class Status:
|
|
|
83
83
|
message: str
|
|
84
84
|
|
|
85
85
|
|
|
86
|
+
class ClientAppOutputCode(Enum):
|
|
87
|
+
"""ClientAppIO status codes."""
|
|
88
|
+
|
|
89
|
+
SUCCESS = 0
|
|
90
|
+
DEADLINE_EXCEEDED = 1
|
|
91
|
+
UNKNOWN_ERROR = 2
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
|
|
95
|
+
class ClientAppOutputStatus:
|
|
96
|
+
"""ClientAppIO status."""
|
|
97
|
+
|
|
98
|
+
code: ClientAppOutputCode
|
|
99
|
+
message: str
|
|
100
|
+
|
|
101
|
+
|
|
86
102
|
@dataclass
|
|
87
103
|
class Parameters:
|
|
88
104
|
"""Model parameters."""
|
|
@@ -198,6 +214,7 @@ class Run:
|
|
|
198
214
|
run_id: int
|
|
199
215
|
fab_id: str
|
|
200
216
|
fab_version: str
|
|
217
|
+
fab_hash: str
|
|
201
218
|
override_config: UserConfig
|
|
202
219
|
|
|
203
220
|
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# source: flwr/proto/clientappio.proto
|
|
4
|
+
# Protobuf Python Version: 4.25.0
|
|
5
|
+
"""Generated protocol buffer code."""
|
|
6
|
+
from google.protobuf import descriptor as _descriptor
|
|
7
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
8
|
+
from google.protobuf import symbol_database as _symbol_database
|
|
9
|
+
from google.protobuf.internal import builder as _builder
|
|
10
|
+
# @@protoc_insertion_point(imports)
|
|
11
|
+
|
|
12
|
+
_sym_db = _symbol_database.Default()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
|
16
|
+
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
17
|
+
from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/clientappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x18\x66lwr/proto/message.proto\"W\n\x15\x43lientAppOutputStatus\x12-\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1f.flwr.proto.ClientAppOutputCode\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x11\n\x0fGetTokenRequest\"!\n\x10GetTokenResponse\x12\r\n\x05token\x18\x01 \x01(\x12\"+\n\x1aPullClientAppInputsRequest\x12\r\n\x05token\x18\x01 \x01(\x12\"\xa5\x01\n\x1bPullClientAppInputsResponse\x12$\n\x07message\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Message\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\"x\n\x1bPushClientAppOutputsRequest\x12\r\n\x05token\x18\x01 \x01(\x12\x12$\n\x07message\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Message\x12$\n\x07\x63ontext\x18\x03 \x01(\x0b\x32\x13.flwr.proto.Context\"Q\n\x1cPushClientAppOutputsResponse\x12\x31\n\x06status\x18\x01 \x01(\x0b\x32!.flwr.proto.ClientAppOutputStatus*L\n\x13\x43lientAppOutputCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x15\n\x11\x44\x45\x41\x44LINE_EXCEEDED\x10\x01\x12\x11\n\rUNKNOWN_ERROR\x10\x02\x32\xad\x02\n\x0b\x43lientAppIo\x12G\n\x08GetToken\x12\x1b.flwr.proto.GetTokenRequest\x1a\x1c.flwr.proto.GetTokenResponse\"\x00\x12h\n\x13PullClientAppInputs\x12&.flwr.proto.PullClientAppInputsRequest\x1a\'.flwr.proto.PullClientAppInputsResponse\"\x00\x12k\n\x14PushClientAppOutputs\x12\'.flwr.proto.PushClientAppOutputsRequest\x1a(.flwr.proto.PushClientAppOutputsResponse\"\x00\x62\x06proto3')
|
|
21
|
+
|
|
22
|
+
_globals = globals()
|
|
23
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
24
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.clientappio_pb2', _globals)
|
|
25
|
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
|
26
|
+
DESCRIPTOR._options = None
|
|
27
|
+
_globals['_CLIENTAPPOUTPUTCODE']._serialized_start=675
|
|
28
|
+
_globals['_CLIENTAPPOUTPUTCODE']._serialized_end=751
|
|
29
|
+
_globals['_CLIENTAPPOUTPUTSTATUS']._serialized_start=114
|
|
30
|
+
_globals['_CLIENTAPPOUTPUTSTATUS']._serialized_end=201
|
|
31
|
+
_globals['_GETTOKENREQUEST']._serialized_start=203
|
|
32
|
+
_globals['_GETTOKENREQUEST']._serialized_end=220
|
|
33
|
+
_globals['_GETTOKENRESPONSE']._serialized_start=222
|
|
34
|
+
_globals['_GETTOKENRESPONSE']._serialized_end=255
|
|
35
|
+
_globals['_PULLCLIENTAPPINPUTSREQUEST']._serialized_start=257
|
|
36
|
+
_globals['_PULLCLIENTAPPINPUTSREQUEST']._serialized_end=300
|
|
37
|
+
_globals['_PULLCLIENTAPPINPUTSRESPONSE']._serialized_start=303
|
|
38
|
+
_globals['_PULLCLIENTAPPINPUTSRESPONSE']._serialized_end=468
|
|
39
|
+
_globals['_PUSHCLIENTAPPOUTPUTSREQUEST']._serialized_start=470
|
|
40
|
+
_globals['_PUSHCLIENTAPPOUTPUTSREQUEST']._serialized_end=590
|
|
41
|
+
_globals['_PUSHCLIENTAPPOUTPUTSRESPONSE']._serialized_start=592
|
|
42
|
+
_globals['_PUSHCLIENTAPPOUTPUTSRESPONSE']._serialized_end=673
|
|
43
|
+
_globals['_CLIENTAPPIO']._serialized_start=754
|
|
44
|
+
_globals['_CLIENTAPPIO']._serialized_end=1055
|
|
45
|
+
# @@protoc_insertion_point(module_scope)
|