flwr 1.15.1__py3-none-any.whl → 1.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/new.py +1 -1
- flwr/cli/new/templates/app/README.baseline.md.tpl +4 -4
- flwr/cli/new/templates/app/README.md.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/client/client_app.py +147 -36
- flwr/client/clientapp/app.py +4 -0
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/constant.py +16 -0
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/message.py +18 -7
- flwr/common/object_ref.py +0 -10
- flwr/common/record/conversion_utils.py +8 -17
- flwr/common/record/parametersrecord.py +151 -16
- flwr/common/record/recordset.py +95 -88
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/serde.py +8 -126
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +36 -0
- flwr/server/app.py +18 -2
- flwr/server/compat/app.py +4 -1
- flwr/server/compat/app_utils.py +10 -2
- flwr/server/compat/driver_client_proxy.py +2 -2
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +10 -1
- flwr/server/driver/inmemory_driver.py +17 -21
- flwr/server/run_serverapp.py +2 -13
- flwr/server/server_app.py +93 -20
- flwr/server/superlink/driver/serverappio_servicer.py +27 -33
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -16
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +32 -36
- flwr/server/superlink/linkstate/in_memory_linkstate.py +140 -126
- flwr/server/superlink/linkstate/linkstate.py +47 -60
- flwr/server/superlink/linkstate/sqlite_linkstate.py +210 -282
- flwr/server/superlink/linkstate/utils.py +91 -119
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -71
- flwr/server/workflow/default_workflows.py +4 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
- flwr/superexec/app.py +0 -14
- flwr/superexec/exec_servicer.py +4 -4
- flwr/superexec/exec_user_auth_interceptor.py +5 -3
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/METADATA +5 -5
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/RECORD +66 -69
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -103
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/WHEEL +0 -0
flwr/common/record/recordset.py
CHANGED
|
@@ -17,82 +17,79 @@
|
|
|
17
17
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
|
-
from
|
|
21
|
-
from
|
|
20
|
+
from logging import WARN
|
|
21
|
+
from textwrap import indent
|
|
22
|
+
from typing import TypeVar, Union, cast
|
|
22
23
|
|
|
24
|
+
from ..logger import log
|
|
23
25
|
from .configsrecord import ConfigsRecord
|
|
24
26
|
from .metricsrecord import MetricsRecord
|
|
25
27
|
from .parametersrecord import ParametersRecord
|
|
26
28
|
from .typeddict import TypedDict
|
|
27
29
|
|
|
30
|
+
RecordType = Union[ParametersRecord, MetricsRecord, ConfigsRecord]
|
|
28
31
|
|
|
29
|
-
|
|
30
|
-
class RecordSetData:
|
|
31
|
-
"""Inner data container for the RecordSet class."""
|
|
32
|
+
T = TypeVar("T")
|
|
32
33
|
|
|
33
|
-
parameters_records: TypedDict[str, ParametersRecord]
|
|
34
|
-
metrics_records: TypedDict[str, MetricsRecord]
|
|
35
|
-
configs_records: TypedDict[str, ConfigsRecord]
|
|
36
34
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
) -> None:
|
|
43
|
-
self.parameters_records = TypedDict[str, ParametersRecord](
|
|
44
|
-
self._check_fn_str, self._check_fn_params
|
|
35
|
+
def _check_key(key: str) -> None:
|
|
36
|
+
if not isinstance(key, str):
|
|
37
|
+
raise TypeError(
|
|
38
|
+
f"Expected `{str.__name__}`, but "
|
|
39
|
+
f"received `{type(key).__name__}` for the key."
|
|
45
40
|
)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _check_value(value: RecordType) -> None:
|
|
44
|
+
if not isinstance(value, (ParametersRecord, MetricsRecord, ConfigsRecord)):
|
|
45
|
+
raise TypeError(
|
|
46
|
+
f"Expected `{ParametersRecord.__name__}`, `{MetricsRecord.__name__}`, "
|
|
47
|
+
f"or `{ConfigsRecord.__name__}` but received "
|
|
48
|
+
f"`{type(value).__name__}` for the value."
|
|
51
49
|
)
|
|
52
|
-
if parameters_records is not None:
|
|
53
|
-
self.parameters_records.update(parameters_records)
|
|
54
|
-
if metrics_records is not None:
|
|
55
|
-
self.metrics_records.update(metrics_records)
|
|
56
|
-
if configs_records is not None:
|
|
57
|
-
self.configs_records.update(configs_records)
|
|
58
|
-
|
|
59
|
-
def _check_fn_str(self, key: str) -> None:
|
|
60
|
-
if not isinstance(key, str):
|
|
61
|
-
raise TypeError(
|
|
62
|
-
f"Expected `{str.__name__}`, but "
|
|
63
|
-
f"received `{type(key).__name__}` for the key."
|
|
64
|
-
)
|
|
65
50
|
|
|
66
|
-
def _check_fn_params(self, record: ParametersRecord) -> None:
|
|
67
|
-
if not isinstance(record, ParametersRecord):
|
|
68
|
-
raise TypeError(
|
|
69
|
-
f"Expected `{ParametersRecord.__name__}`, but "
|
|
70
|
-
f"received `{type(record).__name__}` for the value."
|
|
71
|
-
)
|
|
72
51
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
52
|
+
class _SyncedDict(TypedDict[str, T]):
|
|
53
|
+
"""A synchronized dictionary that mirrors changes to an underlying RecordSet.
|
|
54
|
+
|
|
55
|
+
This dictionary ensures that any modifications (set or delete operations)
|
|
56
|
+
are automatically reflected in the associated `RecordSet`. Only values of
|
|
57
|
+
the specified `allowed_type` are permitted.
|
|
58
|
+
"""
|
|
79
59
|
|
|
80
|
-
def
|
|
81
|
-
if not
|
|
60
|
+
def __init__(self, ref_recordset: RecordSet, allowed_type: type[T]) -> None:
|
|
61
|
+
if not issubclass(
|
|
62
|
+
allowed_type, (ParametersRecord, MetricsRecord, ConfigsRecord)
|
|
63
|
+
):
|
|
64
|
+
raise TypeError(f"{allowed_type} is not a valid type.")
|
|
65
|
+
super().__init__(_check_key, self.check_value)
|
|
66
|
+
self.recordset = ref_recordset
|
|
67
|
+
self.allowed_type = allowed_type
|
|
68
|
+
|
|
69
|
+
def __setitem__(self, key: str, value: T) -> None:
|
|
70
|
+
super().__setitem__(key, value)
|
|
71
|
+
self.recordset[key] = cast(RecordType, value)
|
|
72
|
+
|
|
73
|
+
def __delitem__(self, key: str) -> None:
|
|
74
|
+
super().__delitem__(key)
|
|
75
|
+
del self.recordset[key]
|
|
76
|
+
|
|
77
|
+
def check_value(self, value: T) -> None:
|
|
78
|
+
"""Check if value is of expected type."""
|
|
79
|
+
if not isinstance(value, self.allowed_type):
|
|
82
80
|
raise TypeError(
|
|
83
|
-
f"Expected `{
|
|
84
|
-
f"received `{type(
|
|
81
|
+
f"Expected `{self.allowed_type.__name__}`, but "
|
|
82
|
+
f"received `{type(value).__name__}` for the value."
|
|
85
83
|
)
|
|
86
84
|
|
|
87
85
|
|
|
88
|
-
class RecordSet:
|
|
86
|
+
class RecordSet(TypedDict[str, RecordType]):
|
|
89
87
|
"""RecordSet stores groups of parameters, metrics and configs.
|
|
90
88
|
|
|
91
|
-
A :
|
|
92
|
-
metrics and configs can be either stored as part of a
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
`flwr.common.Message <flwr.common.Message.html>`_ between your apps.
|
|
89
|
+
A :class:`RecordSet` is the unified mechanism by which parameters,
|
|
90
|
+
metrics and configs can be either stored as part of a :class:`Context`
|
|
91
|
+
in your apps or communicated as part of a :class:`Message` between
|
|
92
|
+
your apps.
|
|
96
93
|
|
|
97
94
|
Parameters
|
|
98
95
|
----------
|
|
@@ -127,12 +124,12 @@ class RecordSet:
|
|
|
127
124
|
>>> # We can create a ConfigsRecord
|
|
128
125
|
>>> c_record = ConfigsRecord({"lr": 0.1, "batch-size": 128})
|
|
129
126
|
>>> # Adding it to the record_set would look like this
|
|
130
|
-
>>> my_recordset
|
|
127
|
+
>>> my_recordset["my_config"] = c_record
|
|
131
128
|
>>>
|
|
132
129
|
>>> # We can create a MetricsRecord following a similar process
|
|
133
130
|
>>> m_record = MetricsRecord({"accuracy": 0.93, "losses": [0.23, 0.1]})
|
|
134
131
|
>>> # Adding it to the record_set would look like this
|
|
135
|
-
>>> my_recordset
|
|
132
|
+
>>> my_recordset["my_metrics"] = m_record
|
|
136
133
|
|
|
137
134
|
Adding a :code:`ParametersRecord` follows the same steps as above but first,
|
|
138
135
|
the array needs to be serialized and represented as a :code:`flwr.common.Array`.
|
|
@@ -151,52 +148,62 @@ class RecordSet:
|
|
|
151
148
|
>>> p_record = ParametersRecord({"my_array": arr})
|
|
152
149
|
>>>
|
|
153
150
|
>>> # Adding it to the record_set would look like this
|
|
154
|
-
>>> my_recordset
|
|
151
|
+
>>> my_recordset["my_parameters"] = p_record
|
|
155
152
|
|
|
156
153
|
For additional examples on how to construct each of the records types shown
|
|
157
154
|
above, please refer to the documentation for :code:`ConfigsRecord`,
|
|
158
155
|
:code:`MetricsRecord` and :code:`ParametersRecord`.
|
|
159
156
|
"""
|
|
160
157
|
|
|
161
|
-
def __init__(
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
) -> None:
|
|
167
|
-
data = RecordSetData(
|
|
168
|
-
parameters_records=parameters_records,
|
|
169
|
-
metrics_records=metrics_records,
|
|
170
|
-
configs_records=configs_records,
|
|
171
|
-
)
|
|
172
|
-
self.__dict__["_data"] = data
|
|
158
|
+
def __init__(self, records: dict[str, RecordType] | None = None) -> None:
|
|
159
|
+
super().__init__(_check_key, _check_value)
|
|
160
|
+
if records is not None:
|
|
161
|
+
for key, record in records.items():
|
|
162
|
+
self[key] = record
|
|
173
163
|
|
|
174
164
|
@property
|
|
175
165
|
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
|
|
176
|
-
"""Dictionary holding ParametersRecord instances."""
|
|
177
|
-
|
|
178
|
-
|
|
166
|
+
"""Dictionary holding only ParametersRecord instances."""
|
|
167
|
+
synced_dict = _SyncedDict[ParametersRecord](self, ParametersRecord)
|
|
168
|
+
for key, record in self.items():
|
|
169
|
+
if isinstance(record, ParametersRecord):
|
|
170
|
+
synced_dict[key] = record
|
|
171
|
+
return synced_dict
|
|
179
172
|
|
|
180
173
|
@property
|
|
181
174
|
def metrics_records(self) -> TypedDict[str, MetricsRecord]:
|
|
182
|
-
"""Dictionary holding MetricsRecord instances."""
|
|
183
|
-
|
|
184
|
-
|
|
175
|
+
"""Dictionary holding only MetricsRecord instances."""
|
|
176
|
+
synced_dict = _SyncedDict[MetricsRecord](self, MetricsRecord)
|
|
177
|
+
for key, record in self.items():
|
|
178
|
+
if isinstance(record, MetricsRecord):
|
|
179
|
+
synced_dict[key] = record
|
|
180
|
+
return synced_dict
|
|
185
181
|
|
|
186
182
|
@property
|
|
187
183
|
def configs_records(self) -> TypedDict[str, ConfigsRecord]:
|
|
188
|
-
"""Dictionary holding ConfigsRecord instances."""
|
|
189
|
-
|
|
190
|
-
|
|
184
|
+
"""Dictionary holding only ConfigsRecord instances."""
|
|
185
|
+
synced_dict = _SyncedDict[ConfigsRecord](self, ConfigsRecord)
|
|
186
|
+
for key, record in self.items():
|
|
187
|
+
if isinstance(record, ConfigsRecord):
|
|
188
|
+
synced_dict[key] = record
|
|
189
|
+
return synced_dict
|
|
191
190
|
|
|
192
191
|
def __repr__(self) -> str:
|
|
193
192
|
"""Return a string representation of this instance."""
|
|
194
193
|
flds = ("parameters_records", "metrics_records", "configs_records")
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
194
|
+
fld_views = [f"{fld}={dict(getattr(self, fld))!r}" for fld in flds]
|
|
195
|
+
view = indent(",\n".join(fld_views), " ")
|
|
196
|
+
return f"{self.__class__.__qualname__}(\n{view}\n)"
|
|
197
|
+
|
|
198
|
+
def __setitem__(self, key: str, value: RecordType) -> None:
|
|
199
|
+
"""Set the given key to the given value after type checking."""
|
|
200
|
+
original_value = self.get(key, None)
|
|
201
|
+
super().__setitem__(key, value)
|
|
202
|
+
if original_value is not None and not isinstance(value, type(original_value)):
|
|
203
|
+
log(
|
|
204
|
+
WARN,
|
|
205
|
+
"Key '%s' was overwritten: record of type `%s` replaced with type `%s`",
|
|
206
|
+
key,
|
|
207
|
+
type(original_value).__name__,
|
|
208
|
+
type(value).__name__,
|
|
209
|
+
)
|
|
@@ -25,7 +25,11 @@ from flwr.common.typing import NDArrayFloat, NDArrayInt
|
|
|
25
25
|
def _stochastic_round(arr: NDArrayFloat) -> NDArrayInt:
|
|
26
26
|
ret: NDArrayInt = np.ceil(arr).astype(np.int32)
|
|
27
27
|
rand_arr = np.random.rand(*ret.shape)
|
|
28
|
-
|
|
28
|
+
if len(ret.shape) == 0:
|
|
29
|
+
if rand_arr < ret - arr:
|
|
30
|
+
ret -= 1
|
|
31
|
+
else:
|
|
32
|
+
ret[rand_arr < ret - arr] -= 1
|
|
29
33
|
return ret
|
|
30
34
|
|
|
31
35
|
|
flwr/common/serde.py
CHANGED
|
@@ -21,8 +21,6 @@ from typing import Any, TypeVar, cast
|
|
|
21
21
|
|
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
|
23
23
|
|
|
24
|
-
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
25
|
-
|
|
26
24
|
# pylint: disable=E0611
|
|
27
25
|
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
|
28
26
|
from flwr.proto.error_pb2 import Error as ProtoError
|
|
@@ -30,7 +28,6 @@ from flwr.proto.fab_pb2 import Fab as ProtoFab
|
|
|
30
28
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
31
29
|
from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
32
30
|
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
|
33
|
-
from flwr.proto.node_pb2 import Node
|
|
34
31
|
from flwr.proto.recordset_pb2 import Array as ProtoArray
|
|
35
32
|
from flwr.proto.recordset_pb2 import BoolList, BytesList
|
|
36
33
|
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
|
@@ -43,7 +40,6 @@ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
|
43
40
|
from flwr.proto.recordset_pb2 import SintList, StringList, UintList
|
|
44
41
|
from flwr.proto.run_pb2 import Run as ProtoRun
|
|
45
42
|
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
|
|
46
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
47
43
|
from flwr.proto.transport_pb2 import (
|
|
48
44
|
ClientMessage,
|
|
49
45
|
Code,
|
|
@@ -583,128 +579,14 @@ def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
|
|
|
583
579
|
|
|
584
580
|
def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
|
|
585
581
|
"""Deserialize RecordSet from ProtoBuf."""
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
configs_records={
|
|
595
|
-
k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
|
|
596
|
-
},
|
|
597
|
-
)
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
# === Message ===
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
def message_to_taskins(message: Message) -> TaskIns:
|
|
604
|
-
"""Create a TaskIns from the Message."""
|
|
605
|
-
md = message.metadata
|
|
606
|
-
return TaskIns(
|
|
607
|
-
group_id=md.group_id,
|
|
608
|
-
run_id=md.run_id,
|
|
609
|
-
task=Task(
|
|
610
|
-
producer=Node(node_id=SUPERLINK_NODE_ID), # Assume driver node
|
|
611
|
-
consumer=Node(node_id=md.dst_node_id),
|
|
612
|
-
created_at=md.created_at,
|
|
613
|
-
ttl=md.ttl,
|
|
614
|
-
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
|
|
615
|
-
task_type=md.message_type,
|
|
616
|
-
recordset=(
|
|
617
|
-
recordset_to_proto(message.content) if message.has_content() else None
|
|
618
|
-
),
|
|
619
|
-
error=error_to_proto(message.error) if message.has_error() else None,
|
|
620
|
-
),
|
|
621
|
-
)
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
def message_from_taskins(taskins: TaskIns) -> Message:
|
|
625
|
-
"""Create a Message from the TaskIns."""
|
|
626
|
-
# Retrieve the Metadata
|
|
627
|
-
metadata = Metadata(
|
|
628
|
-
run_id=taskins.run_id,
|
|
629
|
-
message_id=taskins.task_id,
|
|
630
|
-
src_node_id=taskins.task.producer.node_id,
|
|
631
|
-
dst_node_id=taskins.task.consumer.node_id,
|
|
632
|
-
reply_to_message=taskins.task.ancestry[0] if taskins.task.ancestry else "",
|
|
633
|
-
group_id=taskins.group_id,
|
|
634
|
-
ttl=taskins.task.ttl,
|
|
635
|
-
message_type=taskins.task.task_type,
|
|
636
|
-
)
|
|
637
|
-
|
|
638
|
-
# Construct Message
|
|
639
|
-
message = Message(
|
|
640
|
-
metadata=metadata,
|
|
641
|
-
content=(
|
|
642
|
-
recordset_from_proto(taskins.task.recordset)
|
|
643
|
-
if taskins.task.HasField("recordset")
|
|
644
|
-
else None
|
|
645
|
-
),
|
|
646
|
-
error=(
|
|
647
|
-
error_from_proto(taskins.task.error)
|
|
648
|
-
if taskins.task.HasField("error")
|
|
649
|
-
else None
|
|
650
|
-
),
|
|
651
|
-
)
|
|
652
|
-
message.metadata.created_at = taskins.task.created_at
|
|
653
|
-
return message
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
def message_to_taskres(message: Message) -> TaskRes:
|
|
657
|
-
"""Create a TaskRes from the Message."""
|
|
658
|
-
md = message.metadata
|
|
659
|
-
return TaskRes(
|
|
660
|
-
task_id="", # This will be generated by the server
|
|
661
|
-
group_id=md.group_id,
|
|
662
|
-
run_id=md.run_id,
|
|
663
|
-
task=Task(
|
|
664
|
-
producer=Node(node_id=md.src_node_id),
|
|
665
|
-
consumer=Node(node_id=SUPERLINK_NODE_ID), # Assume driver node
|
|
666
|
-
created_at=md.created_at,
|
|
667
|
-
ttl=md.ttl,
|
|
668
|
-
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
|
|
669
|
-
task_type=md.message_type,
|
|
670
|
-
recordset=(
|
|
671
|
-
recordset_to_proto(message.content) if message.has_content() else None
|
|
672
|
-
),
|
|
673
|
-
error=error_to_proto(message.error) if message.has_error() else None,
|
|
674
|
-
),
|
|
675
|
-
)
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
def message_from_taskres(taskres: TaskRes) -> Message:
|
|
679
|
-
"""Create a Message from the TaskIns."""
|
|
680
|
-
# Retrieve the MetaData
|
|
681
|
-
metadata = Metadata(
|
|
682
|
-
run_id=taskres.run_id,
|
|
683
|
-
message_id=taskres.task_id,
|
|
684
|
-
src_node_id=taskres.task.producer.node_id,
|
|
685
|
-
dst_node_id=taskres.task.consumer.node_id,
|
|
686
|
-
reply_to_message=taskres.task.ancestry[0] if taskres.task.ancestry else "",
|
|
687
|
-
group_id=taskres.group_id,
|
|
688
|
-
ttl=taskres.task.ttl,
|
|
689
|
-
message_type=taskres.task.task_type,
|
|
690
|
-
)
|
|
691
|
-
|
|
692
|
-
# Construct the Message
|
|
693
|
-
message = Message(
|
|
694
|
-
metadata=metadata,
|
|
695
|
-
content=(
|
|
696
|
-
recordset_from_proto(taskres.task.recordset)
|
|
697
|
-
if taskres.task.HasField("recordset")
|
|
698
|
-
else None
|
|
699
|
-
),
|
|
700
|
-
error=(
|
|
701
|
-
error_from_proto(taskres.task.error)
|
|
702
|
-
if taskres.task.HasField("error")
|
|
703
|
-
else None
|
|
704
|
-
),
|
|
705
|
-
)
|
|
706
|
-
message.metadata.created_at = taskres.task.created_at
|
|
707
|
-
return message
|
|
582
|
+
ret = RecordSet()
|
|
583
|
+
for k, p_record_proto in recordset_proto.parameters.items():
|
|
584
|
+
ret[k] = parameters_record_from_proto(p_record_proto)
|
|
585
|
+
for k, m_record_proto in recordset_proto.metrics.items():
|
|
586
|
+
ret[k] = metrics_record_from_proto(m_record_proto)
|
|
587
|
+
for k, c_record_proto in recordset_proto.configs.items():
|
|
588
|
+
ret[k] = configs_record_from_proto(c_record_proto)
|
|
589
|
+
return ret
|
|
708
590
|
|
|
709
591
|
|
|
710
592
|
# === FAB ===
|
flwr/common/telemetry.py
CHANGED
|
@@ -181,16 +181,6 @@ class EventType(str, Enum):
|
|
|
181
181
|
RUN_SUPERNODE_ENTER = auto()
|
|
182
182
|
RUN_SUPERNODE_LEAVE = auto()
|
|
183
183
|
|
|
184
|
-
# --- DEPRECATED -------------------------------------------------------------------
|
|
185
|
-
|
|
186
|
-
# [DEPRECATED] CLI: `flower-server-app`
|
|
187
|
-
RUN_SERVER_APP_ENTER = auto()
|
|
188
|
-
RUN_SERVER_APP_LEAVE = auto()
|
|
189
|
-
|
|
190
|
-
# [DEPRECATED] CLI: `flower-client-app`
|
|
191
|
-
RUN_CLIENT_APP_ENTER = auto()
|
|
192
|
-
RUN_CLIENT_APP_LEAVE = auto()
|
|
193
|
-
|
|
194
184
|
|
|
195
185
|
# Use the ThreadPoolExecutor with max_workers=1 to have a queue
|
|
196
186
|
# and also ensure that telemetry calls are not blocking.
|
flwr/common/typing.py
CHANGED
|
@@ -286,3 +286,39 @@ class UserAuthCredentials:
|
|
|
286
286
|
|
|
287
287
|
access_token: str
|
|
288
288
|
refresh_token: str
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@dataclass
|
|
292
|
+
class UserInfo:
|
|
293
|
+
"""User information for event log."""
|
|
294
|
+
|
|
295
|
+
user_id: Optional[str]
|
|
296
|
+
user_name: Optional[str]
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@dataclass
|
|
300
|
+
class Actor:
|
|
301
|
+
"""Event log actor."""
|
|
302
|
+
|
|
303
|
+
actor_id: Optional[str]
|
|
304
|
+
description: Optional[str]
|
|
305
|
+
ip_address: str
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
@dataclass
|
|
309
|
+
class Event:
|
|
310
|
+
"""Event log description."""
|
|
311
|
+
|
|
312
|
+
action: str
|
|
313
|
+
run_id: Optional[int]
|
|
314
|
+
fab_hash: Optional[str]
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
@dataclass
|
|
318
|
+
class LogEntry:
|
|
319
|
+
"""Event log record."""
|
|
320
|
+
|
|
321
|
+
timestamp: str
|
|
322
|
+
actor: Actor
|
|
323
|
+
event: Event
|
|
324
|
+
status: str
|
flwr/server/app.py
CHANGED
|
@@ -90,7 +90,11 @@ BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
|
90
90
|
|
|
91
91
|
|
|
92
92
|
try:
|
|
93
|
-
from flwr.ee import
|
|
93
|
+
from flwr.ee import (
|
|
94
|
+
add_ee_args_superlink,
|
|
95
|
+
get_dashboard_server,
|
|
96
|
+
get_exec_auth_plugins,
|
|
97
|
+
)
|
|
94
98
|
except ImportError:
|
|
95
99
|
|
|
96
100
|
# pylint: disable-next=unused-argument
|
|
@@ -431,6 +435,17 @@ def run_superlink() -> None:
|
|
|
431
435
|
scheduler_th.start()
|
|
432
436
|
bckg_threads.append(scheduler_th)
|
|
433
437
|
|
|
438
|
+
# Add Dashboard server if available
|
|
439
|
+
if dashboard_address := getattr(args, "dashboard_address", None):
|
|
440
|
+
dashboard_address_str, _, _ = _format_address(dashboard_address)
|
|
441
|
+
dashboard_server = get_dashboard_server(
|
|
442
|
+
address=dashboard_address_str,
|
|
443
|
+
state_factory=state_factory,
|
|
444
|
+
certificates=None,
|
|
445
|
+
)
|
|
446
|
+
|
|
447
|
+
grpc_servers.append(dashboard_server)
|
|
448
|
+
|
|
434
449
|
# Graceful shutdown
|
|
435
450
|
register_exit_handlers(
|
|
436
451
|
event_type=EventType.RUN_SUPERLINK_LEAVE,
|
|
@@ -710,7 +725,8 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
710
725
|
"--insecure",
|
|
711
726
|
action="store_true",
|
|
712
727
|
help="Run the server without HTTPS, regardless of whether certificate "
|
|
713
|
-
"paths are provided.
|
|
728
|
+
"paths are provided. Data transmitted between the gRPC client and server "
|
|
729
|
+
"is not encrypted. By default, the server runs with HTTPS enabled. "
|
|
714
730
|
"Use this flag only if you understand the risks.",
|
|
715
731
|
)
|
|
716
732
|
parser.add_argument(
|
flwr/server/compat/app.py
CHANGED
|
@@ -79,10 +79,13 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
79
79
|
log(INFO, "")
|
|
80
80
|
|
|
81
81
|
# Start the thread updating nodes
|
|
82
|
-
thread, f_stop = start_update_client_manager_thread(
|
|
82
|
+
thread, f_stop, c_done = start_update_client_manager_thread(
|
|
83
83
|
driver, initialized_server.client_manager()
|
|
84
84
|
)
|
|
85
85
|
|
|
86
|
+
# Wait until the node registration done
|
|
87
|
+
c_done.wait()
|
|
88
|
+
|
|
86
89
|
# Start training
|
|
87
90
|
hist = run_fl(
|
|
88
91
|
server=initialized_server,
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -27,7 +27,7 @@ from ..driver import Driver
|
|
|
27
27
|
def start_update_client_manager_thread(
|
|
28
28
|
driver: Driver,
|
|
29
29
|
client_manager: ClientManager,
|
|
30
|
-
) -> tuple[threading.Thread, threading.Event]:
|
|
30
|
+
) -> tuple[threading.Thread, threading.Event, threading.Event]:
|
|
31
31
|
"""Periodically update the nodes list in the client manager in a thread.
|
|
32
32
|
|
|
33
33
|
This function starts a thread that periodically uses the associated driver to
|
|
@@ -51,26 +51,31 @@ def start_update_client_manager_thread(
|
|
|
51
51
|
A thread that updates the ClientManager and handles the stop event.
|
|
52
52
|
threading.Event
|
|
53
53
|
An event that, when set, signals the thread to stop.
|
|
54
|
+
threading.Event
|
|
55
|
+
An event that, when set, signals the node registration done.
|
|
54
56
|
"""
|
|
55
57
|
f_stop = threading.Event()
|
|
58
|
+
c_done = threading.Event()
|
|
56
59
|
thread = threading.Thread(
|
|
57
60
|
target=_update_client_manager,
|
|
58
61
|
args=(
|
|
59
62
|
driver,
|
|
60
63
|
client_manager,
|
|
61
64
|
f_stop,
|
|
65
|
+
c_done,
|
|
62
66
|
),
|
|
63
67
|
daemon=True,
|
|
64
68
|
)
|
|
65
69
|
thread.start()
|
|
66
70
|
|
|
67
|
-
return thread, f_stop
|
|
71
|
+
return thread, f_stop, c_done
|
|
68
72
|
|
|
69
73
|
|
|
70
74
|
def _update_client_manager(
|
|
71
75
|
driver: Driver,
|
|
72
76
|
client_manager: ClientManager,
|
|
73
77
|
f_stop: threading.Event,
|
|
78
|
+
c_done: threading.Event,
|
|
74
79
|
) -> None:
|
|
75
80
|
"""Update the nodes list in the client manager."""
|
|
76
81
|
# Loop until the driver is disconnected
|
|
@@ -102,6 +107,9 @@ def _update_client_manager(
|
|
|
102
107
|
else:
|
|
103
108
|
raise RuntimeError("Could not register node.")
|
|
104
109
|
|
|
110
|
+
# Flag first pass for nodes registration is completed
|
|
111
|
+
c_done.set()
|
|
112
|
+
|
|
105
113
|
# Sleep for 3 seconds
|
|
106
114
|
if not f_stop.is_set():
|
|
107
115
|
f_stop.wait(3)
|
|
@@ -104,7 +104,7 @@ class DriverClientProxy(ClientProxy):
|
|
|
104
104
|
def _send_receive_recordset(
|
|
105
105
|
self,
|
|
106
106
|
recordset: RecordSet,
|
|
107
|
-
|
|
107
|
+
message_type: str,
|
|
108
108
|
timeout: Optional[float],
|
|
109
109
|
group_id: Optional[int],
|
|
110
110
|
) -> RecordSet:
|
|
@@ -112,7 +112,7 @@ class DriverClientProxy(ClientProxy):
|
|
|
112
112
|
# Create message
|
|
113
113
|
message = self.driver.create_message(
|
|
114
114
|
content=recordset,
|
|
115
|
-
message_type=
|
|
115
|
+
message_type=message_type,
|
|
116
116
|
dst_node_id=self.node_id,
|
|
117
117
|
group_id=str(group_id) if group_id else "",
|
|
118
118
|
ttl=timeout,
|
flwr/server/driver/driver.py
CHANGED
|
@@ -183,7 +183,7 @@ class GrpcDriver(Driver):
|
|
|
183
183
|
)
|
|
184
184
|
return Message(metadata=metadata, content=content)
|
|
185
185
|
|
|
186
|
-
def get_node_ids(self) ->
|
|
186
|
+
def get_node_ids(self) -> Iterable[int]:
|
|
187
187
|
"""Get node IDs."""
|
|
188
188
|
# Call GrpcDriverStub method
|
|
189
189
|
res: GetNodesResponse = self._stub.GetNodes(
|
|
@@ -212,6 +212,15 @@ class GrpcDriver(Driver):
|
|
|
212
212
|
messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
|
|
213
213
|
)
|
|
214
214
|
)
|
|
215
|
+
if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
|
|
216
|
+
list(message_proto_list)
|
|
217
|
+
):
|
|
218
|
+
log(
|
|
219
|
+
WARNING,
|
|
220
|
+
"Not all messages could be pushed to the SuperLink. The returned "
|
|
221
|
+
"list has `None` for those messages (the order is preserved as passed "
|
|
222
|
+
"to `push_messages`). This could be due to a malformed message.",
|
|
223
|
+
)
|
|
215
224
|
return list(res.message_ids)
|
|
216
225
|
|
|
217
226
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|