flwr 1.15.1__py3-none-any.whl → 1.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/new.py +1 -1
  4. flwr/cli/new/templates/app/README.baseline.md.tpl +4 -4
  5. flwr/cli/new/templates/app/README.md.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  13. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  14. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  15. flwr/client/client_app.py +147 -36
  16. flwr/client/clientapp/app.py +4 -0
  17. flwr/client/message_handler/message_handler.py +1 -1
  18. flwr/client/rest_client/connection.py +4 -6
  19. flwr/client/supernode/__init__.py +0 -2
  20. flwr/client/supernode/app.py +1 -11
  21. flwr/common/address.py +35 -0
  22. flwr/common/args.py +8 -2
  23. flwr/common/auth_plugin/auth_plugin.py +2 -1
  24. flwr/common/constant.py +16 -0
  25. flwr/common/event_log_plugin/__init__.py +22 -0
  26. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  27. flwr/common/grpc.py +1 -1
  28. flwr/common/message.py +18 -7
  29. flwr/common/object_ref.py +0 -10
  30. flwr/common/record/conversion_utils.py +8 -17
  31. flwr/common/record/parametersrecord.py +151 -16
  32. flwr/common/record/recordset.py +95 -88
  33. flwr/common/secure_aggregation/quantization.py +5 -1
  34. flwr/common/serde.py +8 -126
  35. flwr/common/telemetry.py +0 -10
  36. flwr/common/typing.py +36 -0
  37. flwr/server/app.py +18 -2
  38. flwr/server/compat/app.py +4 -1
  39. flwr/server/compat/app_utils.py +10 -2
  40. flwr/server/compat/driver_client_proxy.py +2 -2
  41. flwr/server/driver/driver.py +1 -1
  42. flwr/server/driver/grpc_driver.py +10 -1
  43. flwr/server/driver/inmemory_driver.py +17 -21
  44. flwr/server/run_serverapp.py +2 -13
  45. flwr/server/server_app.py +93 -20
  46. flwr/server/superlink/driver/serverappio_servicer.py +27 -33
  47. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  48. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -16
  49. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  50. flwr/server/superlink/fleet/vce/vce_api.py +32 -36
  51. flwr/server/superlink/linkstate/in_memory_linkstate.py +140 -126
  52. flwr/server/superlink/linkstate/linkstate.py +47 -60
  53. flwr/server/superlink/linkstate/sqlite_linkstate.py +210 -282
  54. flwr/server/superlink/linkstate/utils.py +91 -119
  55. flwr/server/utils/__init__.py +2 -2
  56. flwr/server/utils/validator.py +53 -71
  57. flwr/server/workflow/default_workflows.py +4 -1
  58. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
  59. flwr/superexec/app.py +0 -14
  60. flwr/superexec/exec_servicer.py +4 -4
  61. flwr/superexec/exec_user_auth_interceptor.py +5 -3
  62. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/METADATA +5 -5
  63. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/RECORD +66 -69
  64. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/entry_points.txt +0 -3
  65. flwr/client/message_handler/task_handler.py +0 -37
  66. flwr/proto/task_pb2.py +0 -33
  67. flwr/proto/task_pb2.pyi +0 -103
  68. flwr/proto/task_pb2_grpc.py +0 -4
  69. flwr/proto/task_pb2_grpc.pyi +0 -4
  70. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/LICENSE +0 -0
  71. {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/WHEEL +0 -0
@@ -17,82 +17,79 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from dataclasses import dataclass
21
- from typing import cast
20
+ from logging import WARN
21
+ from textwrap import indent
22
+ from typing import TypeVar, Union, cast
22
23
 
24
+ from ..logger import log
23
25
  from .configsrecord import ConfigsRecord
24
26
  from .metricsrecord import MetricsRecord
25
27
  from .parametersrecord import ParametersRecord
26
28
  from .typeddict import TypedDict
27
29
 
30
+ RecordType = Union[ParametersRecord, MetricsRecord, ConfigsRecord]
28
31
 
29
- @dataclass
30
- class RecordSetData:
31
- """Inner data container for the RecordSet class."""
32
+ T = TypeVar("T")
32
33
 
33
- parameters_records: TypedDict[str, ParametersRecord]
34
- metrics_records: TypedDict[str, MetricsRecord]
35
- configs_records: TypedDict[str, ConfigsRecord]
36
34
 
37
- def __init__(
38
- self,
39
- parameters_records: dict[str, ParametersRecord] | None = None,
40
- metrics_records: dict[str, MetricsRecord] | None = None,
41
- configs_records: dict[str, ConfigsRecord] | None = None,
42
- ) -> None:
43
- self.parameters_records = TypedDict[str, ParametersRecord](
44
- self._check_fn_str, self._check_fn_params
35
+ def _check_key(key: str) -> None:
36
+ if not isinstance(key, str):
37
+ raise TypeError(
38
+ f"Expected `{str.__name__}`, but "
39
+ f"received `{type(key).__name__}` for the key."
45
40
  )
46
- self.metrics_records = TypedDict[str, MetricsRecord](
47
- self._check_fn_str, self._check_fn_metrics
48
- )
49
- self.configs_records = TypedDict[str, ConfigsRecord](
50
- self._check_fn_str, self._check_fn_configs
41
+
42
+
43
+ def _check_value(value: RecordType) -> None:
44
+ if not isinstance(value, (ParametersRecord, MetricsRecord, ConfigsRecord)):
45
+ raise TypeError(
46
+ f"Expected `{ParametersRecord.__name__}`, `{MetricsRecord.__name__}`, "
47
+ f"or `{ConfigsRecord.__name__}` but received "
48
+ f"`{type(value).__name__}` for the value."
51
49
  )
52
- if parameters_records is not None:
53
- self.parameters_records.update(parameters_records)
54
- if metrics_records is not None:
55
- self.metrics_records.update(metrics_records)
56
- if configs_records is not None:
57
- self.configs_records.update(configs_records)
58
-
59
- def _check_fn_str(self, key: str) -> None:
60
- if not isinstance(key, str):
61
- raise TypeError(
62
- f"Expected `{str.__name__}`, but "
63
- f"received `{type(key).__name__}` for the key."
64
- )
65
50
 
66
- def _check_fn_params(self, record: ParametersRecord) -> None:
67
- if not isinstance(record, ParametersRecord):
68
- raise TypeError(
69
- f"Expected `{ParametersRecord.__name__}`, but "
70
- f"received `{type(record).__name__}` for the value."
71
- )
72
51
 
73
- def _check_fn_metrics(self, record: MetricsRecord) -> None:
74
- if not isinstance(record, MetricsRecord):
75
- raise TypeError(
76
- f"Expected `{MetricsRecord.__name__}`, but "
77
- f"received `{type(record).__name__}` for the value."
78
- )
52
+ class _SyncedDict(TypedDict[str, T]):
53
+ """A synchronized dictionary that mirrors changes to an underlying RecordSet.
54
+
55
+ This dictionary ensures that any modifications (set or delete operations)
56
+ are automatically reflected in the associated `RecordSet`. Only values of
57
+ the specified `allowed_type` are permitted.
58
+ """
79
59
 
80
- def _check_fn_configs(self, record: ConfigsRecord) -> None:
81
- if not isinstance(record, ConfigsRecord):
60
+ def __init__(self, ref_recordset: RecordSet, allowed_type: type[T]) -> None:
61
+ if not issubclass(
62
+ allowed_type, (ParametersRecord, MetricsRecord, ConfigsRecord)
63
+ ):
64
+ raise TypeError(f"{allowed_type} is not a valid type.")
65
+ super().__init__(_check_key, self.check_value)
66
+ self.recordset = ref_recordset
67
+ self.allowed_type = allowed_type
68
+
69
+ def __setitem__(self, key: str, value: T) -> None:
70
+ super().__setitem__(key, value)
71
+ self.recordset[key] = cast(RecordType, value)
72
+
73
+ def __delitem__(self, key: str) -> None:
74
+ super().__delitem__(key)
75
+ del self.recordset[key]
76
+
77
+ def check_value(self, value: T) -> None:
78
+ """Check if value is of expected type."""
79
+ if not isinstance(value, self.allowed_type):
82
80
  raise TypeError(
83
- f"Expected `{ConfigsRecord.__name__}`, but "
84
- f"received `{type(record).__name__}` for the value."
81
+ f"Expected `{self.allowed_type.__name__}`, but "
82
+ f"received `{type(value).__name__}` for the value."
85
83
  )
86
84
 
87
85
 
88
- class RecordSet:
86
+ class RecordSet(TypedDict[str, RecordType]):
89
87
  """RecordSet stores groups of parameters, metrics and configs.
90
88
 
91
- A :code:`RecordSet` is the unified mechanism by which parameters,
92
- metrics and configs can be either stored as part of a
93
- `flwr.common.Context <flwr.common.Context.html>`_ in your apps
94
- or communicated as part of a
95
- `flwr.common.Message <flwr.common.Message.html>`_ between your apps.
89
+ A :class:`RecordSet` is the unified mechanism by which parameters,
90
+ metrics and configs can be either stored as part of a :class:`Context`
91
+ in your apps or communicated as part of a :class:`Message` between
92
+ your apps.
96
93
 
97
94
  Parameters
98
95
  ----------
@@ -127,12 +124,12 @@ class RecordSet:
127
124
  >>> # We can create a ConfigsRecord
128
125
  >>> c_record = ConfigsRecord({"lr": 0.1, "batch-size": 128})
129
126
  >>> # Adding it to the record_set would look like this
130
- >>> my_recordset.configs_records["my_config"] = c_record
127
+ >>> my_recordset["my_config"] = c_record
131
128
  >>>
132
129
  >>> # We can create a MetricsRecord following a similar process
133
130
  >>> m_record = MetricsRecord({"accuracy": 0.93, "losses": [0.23, 0.1]})
134
131
  >>> # Adding it to the record_set would look like this
135
- >>> my_recordset.metrics_records["my_metrics"] = m_record
132
+ >>> my_recordset["my_metrics"] = m_record
136
133
 
137
134
  Adding a :code:`ParametersRecord` follows the same steps as above but first,
138
135
  the array needs to be serialized and represented as a :code:`flwr.common.Array`.
@@ -151,52 +148,62 @@ class RecordSet:
151
148
  >>> p_record = ParametersRecord({"my_array": arr})
152
149
  >>>
153
150
  >>> # Adding it to the record_set would look like this
154
- >>> my_recordset.parameters_records["my_parameters"] = p_record
151
+ >>> my_recordset["my_parameters"] = p_record
155
152
 
156
153
  For additional examples on how to construct each of the records types shown
157
154
  above, please refer to the documentation for :code:`ConfigsRecord`,
158
155
  :code:`MetricsRecord` and :code:`ParametersRecord`.
159
156
  """
160
157
 
161
- def __init__(
162
- self,
163
- parameters_records: dict[str, ParametersRecord] | None = None,
164
- metrics_records: dict[str, MetricsRecord] | None = None,
165
- configs_records: dict[str, ConfigsRecord] | None = None,
166
- ) -> None:
167
- data = RecordSetData(
168
- parameters_records=parameters_records,
169
- metrics_records=metrics_records,
170
- configs_records=configs_records,
171
- )
172
- self.__dict__["_data"] = data
158
+ def __init__(self, records: dict[str, RecordType] | None = None) -> None:
159
+ super().__init__(_check_key, _check_value)
160
+ if records is not None:
161
+ for key, record in records.items():
162
+ self[key] = record
173
163
 
174
164
  @property
175
165
  def parameters_records(self) -> TypedDict[str, ParametersRecord]:
176
- """Dictionary holding ParametersRecord instances."""
177
- data = cast(RecordSetData, self.__dict__["_data"])
178
- return data.parameters_records
166
+ """Dictionary holding only ParametersRecord instances."""
167
+ synced_dict = _SyncedDict[ParametersRecord](self, ParametersRecord)
168
+ for key, record in self.items():
169
+ if isinstance(record, ParametersRecord):
170
+ synced_dict[key] = record
171
+ return synced_dict
179
172
 
180
173
  @property
181
174
  def metrics_records(self) -> TypedDict[str, MetricsRecord]:
182
- """Dictionary holding MetricsRecord instances."""
183
- data = cast(RecordSetData, self.__dict__["_data"])
184
- return data.metrics_records
175
+ """Dictionary holding only MetricsRecord instances."""
176
+ synced_dict = _SyncedDict[MetricsRecord](self, MetricsRecord)
177
+ for key, record in self.items():
178
+ if isinstance(record, MetricsRecord):
179
+ synced_dict[key] = record
180
+ return synced_dict
185
181
 
186
182
  @property
187
183
  def configs_records(self) -> TypedDict[str, ConfigsRecord]:
188
- """Dictionary holding ConfigsRecord instances."""
189
- data = cast(RecordSetData, self.__dict__["_data"])
190
- return data.configs_records
184
+ """Dictionary holding only ConfigsRecord instances."""
185
+ synced_dict = _SyncedDict[ConfigsRecord](self, ConfigsRecord)
186
+ for key, record in self.items():
187
+ if isinstance(record, ConfigsRecord):
188
+ synced_dict[key] = record
189
+ return synced_dict
191
190
 
192
191
  def __repr__(self) -> str:
193
192
  """Return a string representation of this instance."""
194
193
  flds = ("parameters_records", "metrics_records", "configs_records")
195
- view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
196
- return f"{self.__class__.__qualname__}({view})"
197
-
198
- def __eq__(self, other: object) -> bool:
199
- """Compare two instances of the class."""
200
- if not isinstance(other, self.__class__):
201
- raise NotImplementedError
202
- return self.__dict__ == other.__dict__
194
+ fld_views = [f"{fld}={dict(getattr(self, fld))!r}" for fld in flds]
195
+ view = indent(",\n".join(fld_views), " ")
196
+ return f"{self.__class__.__qualname__}(\n{view}\n)"
197
+
198
+ def __setitem__(self, key: str, value: RecordType) -> None:
199
+ """Set the given key to the given value after type checking."""
200
+ original_value = self.get(key, None)
201
+ super().__setitem__(key, value)
202
+ if original_value is not None and not isinstance(value, type(original_value)):
203
+ log(
204
+ WARN,
205
+ "Key '%s' was overwritten: record of type `%s` replaced with type `%s`",
206
+ key,
207
+ type(original_value).__name__,
208
+ type(value).__name__,
209
+ )
@@ -25,7 +25,11 @@ from flwr.common.typing import NDArrayFloat, NDArrayInt
25
25
  def _stochastic_round(arr: NDArrayFloat) -> NDArrayInt:
26
26
  ret: NDArrayInt = np.ceil(arr).astype(np.int32)
27
27
  rand_arr = np.random.rand(*ret.shape)
28
- ret[rand_arr < ret - arr] -= 1
28
+ if len(ret.shape) == 0:
29
+ if rand_arr < ret - arr:
30
+ ret -= 1
31
+ else:
32
+ ret[rand_arr < ret - arr] -= 1
29
33
  return ret
30
34
 
31
35
 
flwr/common/serde.py CHANGED
@@ -21,8 +21,6 @@ from typing import Any, TypeVar, cast
21
21
 
22
22
  from google.protobuf.message import Message as GrpcMessage
23
23
 
24
- from flwr.common.constant import SUPERLINK_NODE_ID
25
-
26
24
  # pylint: disable=E0611
27
25
  from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
28
26
  from flwr.proto.error_pb2 import Error as ProtoError
@@ -30,7 +28,6 @@ from flwr.proto.fab_pb2 import Fab as ProtoFab
30
28
  from flwr.proto.message_pb2 import Context as ProtoContext
31
29
  from flwr.proto.message_pb2 import Message as ProtoMessage
32
30
  from flwr.proto.message_pb2 import Metadata as ProtoMetadata
33
- from flwr.proto.node_pb2 import Node
34
31
  from flwr.proto.recordset_pb2 import Array as ProtoArray
35
32
  from flwr.proto.recordset_pb2 import BoolList, BytesList
36
33
  from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
@@ -43,7 +40,6 @@ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
43
40
  from flwr.proto.recordset_pb2 import SintList, StringList, UintList
44
41
  from flwr.proto.run_pb2 import Run as ProtoRun
45
42
  from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
46
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
47
43
  from flwr.proto.transport_pb2 import (
48
44
  ClientMessage,
49
45
  Code,
@@ -583,128 +579,14 @@ def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
583
579
 
584
580
  def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
585
581
  """Deserialize RecordSet from ProtoBuf."""
586
- return RecordSet(
587
- parameters_records={
588
- k: parameters_record_from_proto(v)
589
- for k, v in recordset_proto.parameters.items()
590
- },
591
- metrics_records={
592
- k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items()
593
- },
594
- configs_records={
595
- k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
596
- },
597
- )
598
-
599
-
600
- # === Message ===
601
-
602
-
603
- def message_to_taskins(message: Message) -> TaskIns:
604
- """Create a TaskIns from the Message."""
605
- md = message.metadata
606
- return TaskIns(
607
- group_id=md.group_id,
608
- run_id=md.run_id,
609
- task=Task(
610
- producer=Node(node_id=SUPERLINK_NODE_ID), # Assume driver node
611
- consumer=Node(node_id=md.dst_node_id),
612
- created_at=md.created_at,
613
- ttl=md.ttl,
614
- ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
615
- task_type=md.message_type,
616
- recordset=(
617
- recordset_to_proto(message.content) if message.has_content() else None
618
- ),
619
- error=error_to_proto(message.error) if message.has_error() else None,
620
- ),
621
- )
622
-
623
-
624
- def message_from_taskins(taskins: TaskIns) -> Message:
625
- """Create a Message from the TaskIns."""
626
- # Retrieve the Metadata
627
- metadata = Metadata(
628
- run_id=taskins.run_id,
629
- message_id=taskins.task_id,
630
- src_node_id=taskins.task.producer.node_id,
631
- dst_node_id=taskins.task.consumer.node_id,
632
- reply_to_message=taskins.task.ancestry[0] if taskins.task.ancestry else "",
633
- group_id=taskins.group_id,
634
- ttl=taskins.task.ttl,
635
- message_type=taskins.task.task_type,
636
- )
637
-
638
- # Construct Message
639
- message = Message(
640
- metadata=metadata,
641
- content=(
642
- recordset_from_proto(taskins.task.recordset)
643
- if taskins.task.HasField("recordset")
644
- else None
645
- ),
646
- error=(
647
- error_from_proto(taskins.task.error)
648
- if taskins.task.HasField("error")
649
- else None
650
- ),
651
- )
652
- message.metadata.created_at = taskins.task.created_at
653
- return message
654
-
655
-
656
- def message_to_taskres(message: Message) -> TaskRes:
657
- """Create a TaskRes from the Message."""
658
- md = message.metadata
659
- return TaskRes(
660
- task_id="", # This will be generated by the server
661
- group_id=md.group_id,
662
- run_id=md.run_id,
663
- task=Task(
664
- producer=Node(node_id=md.src_node_id),
665
- consumer=Node(node_id=SUPERLINK_NODE_ID), # Assume driver node
666
- created_at=md.created_at,
667
- ttl=md.ttl,
668
- ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
669
- task_type=md.message_type,
670
- recordset=(
671
- recordset_to_proto(message.content) if message.has_content() else None
672
- ),
673
- error=error_to_proto(message.error) if message.has_error() else None,
674
- ),
675
- )
676
-
677
-
678
- def message_from_taskres(taskres: TaskRes) -> Message:
679
- """Create a Message from the TaskIns."""
680
- # Retrieve the MetaData
681
- metadata = Metadata(
682
- run_id=taskres.run_id,
683
- message_id=taskres.task_id,
684
- src_node_id=taskres.task.producer.node_id,
685
- dst_node_id=taskres.task.consumer.node_id,
686
- reply_to_message=taskres.task.ancestry[0] if taskres.task.ancestry else "",
687
- group_id=taskres.group_id,
688
- ttl=taskres.task.ttl,
689
- message_type=taskres.task.task_type,
690
- )
691
-
692
- # Construct the Message
693
- message = Message(
694
- metadata=metadata,
695
- content=(
696
- recordset_from_proto(taskres.task.recordset)
697
- if taskres.task.HasField("recordset")
698
- else None
699
- ),
700
- error=(
701
- error_from_proto(taskres.task.error)
702
- if taskres.task.HasField("error")
703
- else None
704
- ),
705
- )
706
- message.metadata.created_at = taskres.task.created_at
707
- return message
582
+ ret = RecordSet()
583
+ for k, p_record_proto in recordset_proto.parameters.items():
584
+ ret[k] = parameters_record_from_proto(p_record_proto)
585
+ for k, m_record_proto in recordset_proto.metrics.items():
586
+ ret[k] = metrics_record_from_proto(m_record_proto)
587
+ for k, c_record_proto in recordset_proto.configs.items():
588
+ ret[k] = configs_record_from_proto(c_record_proto)
589
+ return ret
708
590
 
709
591
 
710
592
  # === FAB ===
flwr/common/telemetry.py CHANGED
@@ -181,16 +181,6 @@ class EventType(str, Enum):
181
181
  RUN_SUPERNODE_ENTER = auto()
182
182
  RUN_SUPERNODE_LEAVE = auto()
183
183
 
184
- # --- DEPRECATED -------------------------------------------------------------------
185
-
186
- # [DEPRECATED] CLI: `flower-server-app`
187
- RUN_SERVER_APP_ENTER = auto()
188
- RUN_SERVER_APP_LEAVE = auto()
189
-
190
- # [DEPRECATED] CLI: `flower-client-app`
191
- RUN_CLIENT_APP_ENTER = auto()
192
- RUN_CLIENT_APP_LEAVE = auto()
193
-
194
184
 
195
185
  # Use the ThreadPoolExecutor with max_workers=1 to have a queue
196
186
  # and also ensure that telemetry calls are not blocking.
flwr/common/typing.py CHANGED
@@ -286,3 +286,39 @@ class UserAuthCredentials:
286
286
 
287
287
  access_token: str
288
288
  refresh_token: str
289
+
290
+
291
+ @dataclass
292
+ class UserInfo:
293
+ """User information for event log."""
294
+
295
+ user_id: Optional[str]
296
+ user_name: Optional[str]
297
+
298
+
299
+ @dataclass
300
+ class Actor:
301
+ """Event log actor."""
302
+
303
+ actor_id: Optional[str]
304
+ description: Optional[str]
305
+ ip_address: str
306
+
307
+
308
+ @dataclass
309
+ class Event:
310
+ """Event log description."""
311
+
312
+ action: str
313
+ run_id: Optional[int]
314
+ fab_hash: Optional[str]
315
+
316
+
317
+ @dataclass
318
+ class LogEntry:
319
+ """Event log record."""
320
+
321
+ timestamp: str
322
+ actor: Actor
323
+ event: Event
324
+ status: str
flwr/server/app.py CHANGED
@@ -90,7 +90,11 @@ BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
90
90
 
91
91
 
92
92
  try:
93
- from flwr.ee import add_ee_args_superlink, get_exec_auth_plugins
93
+ from flwr.ee import (
94
+ add_ee_args_superlink,
95
+ get_dashboard_server,
96
+ get_exec_auth_plugins,
97
+ )
94
98
  except ImportError:
95
99
 
96
100
  # pylint: disable-next=unused-argument
@@ -431,6 +435,17 @@ def run_superlink() -> None:
431
435
  scheduler_th.start()
432
436
  bckg_threads.append(scheduler_th)
433
437
 
438
+ # Add Dashboard server if available
439
+ if dashboard_address := getattr(args, "dashboard_address", None):
440
+ dashboard_address_str, _, _ = _format_address(dashboard_address)
441
+ dashboard_server = get_dashboard_server(
442
+ address=dashboard_address_str,
443
+ state_factory=state_factory,
444
+ certificates=None,
445
+ )
446
+
447
+ grpc_servers.append(dashboard_server)
448
+
434
449
  # Graceful shutdown
435
450
  register_exit_handlers(
436
451
  event_type=EventType.RUN_SUPERLINK_LEAVE,
@@ -710,7 +725,8 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
710
725
  "--insecure",
711
726
  action="store_true",
712
727
  help="Run the server without HTTPS, regardless of whether certificate "
713
- "paths are provided. By default, the server runs with HTTPS enabled. "
728
+ "paths are provided. Data transmitted between the gRPC client and server "
729
+ "is not encrypted. By default, the server runs with HTTPS enabled. "
714
730
  "Use this flag only if you understand the risks.",
715
731
  )
716
732
  parser.add_argument(
flwr/server/compat/app.py CHANGED
@@ -79,10 +79,13 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
79
79
  log(INFO, "")
80
80
 
81
81
  # Start the thread updating nodes
82
- thread, f_stop = start_update_client_manager_thread(
82
+ thread, f_stop, c_done = start_update_client_manager_thread(
83
83
  driver, initialized_server.client_manager()
84
84
  )
85
85
 
86
+ # Wait until the node registration done
87
+ c_done.wait()
88
+
86
89
  # Start training
87
90
  hist = run_fl(
88
91
  server=initialized_server,
@@ -27,7 +27,7 @@ from ..driver import Driver
27
27
  def start_update_client_manager_thread(
28
28
  driver: Driver,
29
29
  client_manager: ClientManager,
30
- ) -> tuple[threading.Thread, threading.Event]:
30
+ ) -> tuple[threading.Thread, threading.Event, threading.Event]:
31
31
  """Periodically update the nodes list in the client manager in a thread.
32
32
 
33
33
  This function starts a thread that periodically uses the associated driver to
@@ -51,26 +51,31 @@ def start_update_client_manager_thread(
51
51
  A thread that updates the ClientManager and handles the stop event.
52
52
  threading.Event
53
53
  An event that, when set, signals the thread to stop.
54
+ threading.Event
55
+ An event that, when set, signals the node registration done.
54
56
  """
55
57
  f_stop = threading.Event()
58
+ c_done = threading.Event()
56
59
  thread = threading.Thread(
57
60
  target=_update_client_manager,
58
61
  args=(
59
62
  driver,
60
63
  client_manager,
61
64
  f_stop,
65
+ c_done,
62
66
  ),
63
67
  daemon=True,
64
68
  )
65
69
  thread.start()
66
70
 
67
- return thread, f_stop
71
+ return thread, f_stop, c_done
68
72
 
69
73
 
70
74
  def _update_client_manager(
71
75
  driver: Driver,
72
76
  client_manager: ClientManager,
73
77
  f_stop: threading.Event,
78
+ c_done: threading.Event,
74
79
  ) -> None:
75
80
  """Update the nodes list in the client manager."""
76
81
  # Loop until the driver is disconnected
@@ -102,6 +107,9 @@ def _update_client_manager(
102
107
  else:
103
108
  raise RuntimeError("Could not register node.")
104
109
 
110
+ # Flag first pass for nodes registration is completed
111
+ c_done.set()
112
+
105
113
  # Sleep for 3 seconds
106
114
  if not f_stop.is_set():
107
115
  f_stop.wait(3)
@@ -104,7 +104,7 @@ class DriverClientProxy(ClientProxy):
104
104
  def _send_receive_recordset(
105
105
  self,
106
106
  recordset: RecordSet,
107
- task_type: str,
107
+ message_type: str,
108
108
  timeout: Optional[float],
109
109
  group_id: Optional[int],
110
110
  ) -> RecordSet:
@@ -112,7 +112,7 @@ class DriverClientProxy(ClientProxy):
112
112
  # Create message
113
113
  message = self.driver.create_message(
114
114
  content=recordset,
115
- message_type=task_type,
115
+ message_type=message_type,
116
116
  dst_node_id=self.node_id,
117
117
  group_id=str(group_id) if group_id else "",
118
118
  ttl=timeout,
@@ -85,7 +85,7 @@ class Driver(ABC):
85
85
  """
86
86
 
87
87
  @abstractmethod
88
- def get_node_ids(self) -> list[int]:
88
+ def get_node_ids(self) -> Iterable[int]:
89
89
  """Get node IDs."""
90
90
 
91
91
  @abstractmethod
@@ -183,7 +183,7 @@ class GrpcDriver(Driver):
183
183
  )
184
184
  return Message(metadata=metadata, content=content)
185
185
 
186
- def get_node_ids(self) -> list[int]:
186
+ def get_node_ids(self) -> Iterable[int]:
187
187
  """Get node IDs."""
188
188
  # Call GrpcDriverStub method
189
189
  res: GetNodesResponse = self._stub.GetNodes(
@@ -212,6 +212,15 @@ class GrpcDriver(Driver):
212
212
  messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
213
213
  )
214
214
  )
215
+ if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
216
+ list(message_proto_list)
217
+ ):
218
+ log(
219
+ WARNING,
220
+ "Not all messages could be pushed to the SuperLink. The returned "
221
+ "list has `None` for those messages (the order is preserved as passed "
222
+ "to `push_messages`). This could be due to a malformed message.",
223
+ )
215
224
  return list(res.message_ids)
216
225
 
217
226
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: