flwr-1.18.0-py3-none-any.whl → flwr-1.19.0-py3-none-any.whl

This diff shows the contents of two publicly available versions of a package released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
Files changed (143)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +82 -57
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  9. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  10. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  13. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  14. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  20. flwr/cli/run/run.py +10 -18
  21. flwr/cli/stop.py +2 -2
  22. flwr/cli/utils.py +31 -5
  23. flwr/client/__init__.py +2 -2
  24. flwr/client/client_app.py +1 -1
  25. flwr/client/clientapp/__init__.py +0 -7
  26. flwr/client/grpc_adapter_client/connection.py +4 -4
  27. flwr/client/grpc_rere_client/connection.py +130 -60
  28. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  29. flwr/client/message_handler/message_handler.py +1 -1
  30. flwr/client/mod/comms_mods.py +36 -17
  31. flwr/client/rest_client/connection.py +173 -67
  32. flwr/clientapp/__init__.py +15 -0
  33. flwr/common/__init__.py +2 -2
  34. flwr/common/auth_plugin/__init__.py +2 -0
  35. flwr/common/auth_plugin/auth_plugin.py +29 -3
  36. flwr/common/constant.py +36 -7
  37. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  38. flwr/common/exit_handlers.py +30 -0
  39. flwr/common/heartbeat.py +165 -0
  40. flwr/common/inflatable.py +290 -0
  41. flwr/common/inflatable_grpc_utils.py +99 -0
  42. flwr/common/inflatable_rest_utils.py +99 -0
  43. flwr/common/inflatable_utils.py +341 -0
  44. flwr/common/message.py +110 -242
  45. flwr/common/record/__init__.py +2 -1
  46. flwr/common/record/array.py +323 -0
  47. flwr/common/record/arrayrecord.py +103 -225
  48. flwr/common/record/configrecord.py +59 -4
  49. flwr/common/record/conversion_utils.py +1 -1
  50. flwr/common/record/metricrecord.py +55 -4
  51. flwr/common/record/recorddict.py +69 -1
  52. flwr/common/recorddict_compat.py +2 -2
  53. flwr/common/retry_invoker.py +5 -1
  54. flwr/common/serde.py +59 -183
  55. flwr/common/serde_utils.py +175 -0
  56. flwr/common/typing.py +5 -3
  57. flwr/compat/__init__.py +15 -0
  58. flwr/compat/client/__init__.py +15 -0
  59. flwr/{client → compat/client}/app.py +19 -159
  60. flwr/compat/common/__init__.py +15 -0
  61. flwr/compat/server/__init__.py +15 -0
  62. flwr/compat/server/app.py +174 -0
  63. flwr/compat/simulation/__init__.py +15 -0
  64. flwr/proto/fleet_pb2.py +32 -27
  65. flwr/proto/fleet_pb2.pyi +49 -35
  66. flwr/proto/fleet_pb2_grpc.py +117 -13
  67. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  68. flwr/proto/heartbeat_pb2.py +33 -0
  69. flwr/proto/heartbeat_pb2.pyi +66 -0
  70. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  71. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  72. flwr/proto/message_pb2.py +28 -11
  73. flwr/proto/message_pb2.pyi +125 -0
  74. flwr/proto/recorddict_pb2.py +16 -28
  75. flwr/proto/recorddict_pb2.pyi +46 -64
  76. flwr/proto/run_pb2.py +24 -32
  77. flwr/proto/run_pb2.pyi +4 -52
  78. flwr/proto/serverappio_pb2.py +32 -23
  79. flwr/proto/serverappio_pb2.pyi +45 -3
  80. flwr/proto/serverappio_pb2_grpc.py +138 -34
  81. flwr/proto/serverappio_pb2_grpc.pyi +54 -13
  82. flwr/proto/simulationio_pb2.py +12 -11
  83. flwr/proto/simulationio_pb2_grpc.py +35 -0
  84. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  85. flwr/server/__init__.py +1 -1
  86. flwr/server/app.py +68 -186
  87. flwr/server/compat/app_utils.py +50 -28
  88. flwr/server/fleet_event_log_interceptor.py +2 -2
  89. flwr/server/grid/grpc_grid.py +104 -34
  90. flwr/server/grid/inmemory_grid.py +5 -4
  91. flwr/server/serverapp/app.py +18 -0
  92. flwr/server/superlink/ffs/__init__.py +2 -0
  93. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
  94. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
  95. flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
  96. flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
  97. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  98. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  99. flwr/server/superlink/linkstate/linkstate.py +53 -20
  100. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  101. flwr/server/superlink/linkstate/utils.py +33 -29
  102. flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
  103. flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
  104. flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
  105. flwr/server/superlink/utils.py +44 -2
  106. flwr/server/utils/validator.py +2 -2
  107. flwr/serverapp/__init__.py +15 -0
  108. flwr/simulation/app.py +17 -0
  109. flwr/supercore/__init__.py +15 -0
  110. flwr/supercore/object_store/__init__.py +24 -0
  111. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  112. flwr/supercore/object_store/object_store.py +192 -0
  113. flwr/supercore/object_store/object_store_factory.py +44 -0
  114. flwr/superexec/deployment.py +6 -2
  115. flwr/superexec/exec_event_log_interceptor.py +4 -4
  116. flwr/superexec/exec_grpc.py +7 -3
  117. flwr/superexec/exec_servicer.py +125 -23
  118. flwr/superexec/exec_user_auth_interceptor.py +37 -8
  119. flwr/superexec/executor.py +4 -0
  120. flwr/superexec/simulation.py +7 -1
  121. flwr/superlink/__init__.py +15 -0
  122. flwr/{client/supernode → supernode}/__init__.py +0 -7
  123. flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
  124. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
  125. flwr/supernode/cli/flwr_clientapp.py +81 -0
  126. flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
  127. flwr/supernode/nodestate/nodestate.py +212 -0
  128. flwr/supernode/runtime/__init__.py +15 -0
  129. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
  130. flwr/supernode/servicer/__init__.py +15 -0
  131. flwr/supernode/servicer/clientappio/__init__.py +24 -0
  132. flwr/supernode/start_client_internal.py +491 -0
  133. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
  134. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
  135. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
  136. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
  137. flwr/client/heartbeat.py +0 -74
  138. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  139. flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  140. flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  141. flwr/{client → supernode}/nodestate/__init__.py +0 -0
  142. flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
  143. flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
flwr/common/message.py CHANGED
@@ -18,14 +18,32 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  from logging import WARNING
21
- from typing import Any, Optional, cast, overload
21
+ from typing import Any, cast, overload
22
22
 
23
23
  from flwr.common.date import now
24
24
  from flwr.common.logger import warn_deprecated_feature
25
-
26
- from .constant import MESSAGE_TTL_TOLERANCE, MessageType, MessageTypeLegacy
25
+ from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
26
+ from flwr.proto.message_pb2 import Metadata as ProtoMetadata # pylint: disable=E0611
27
+ from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
28
+
29
+ from ..app.error import Error
30
+ from ..app.metadata import Metadata
31
+ from .constant import MESSAGE_TTL_TOLERANCE
32
+ from .inflatable import (
33
+ InflatableObject,
34
+ add_header_to_object_body,
35
+ get_descendant_object_ids,
36
+ get_object_body,
37
+ get_object_children_ids_from_object_content,
38
+ )
27
39
  from .logger import log
28
40
  from .record import RecordDict
41
+ from .serde_utils import (
42
+ error_from_proto,
43
+ error_to_proto,
44
+ metadata_from_proto,
45
+ metadata_to_proto,
46
+ )
29
47
 
30
48
  DEFAULT_TTL = 43200 # This is 12 hours
31
49
  MESSAGE_INIT_ERROR_MESSAGE = (
@@ -56,203 +74,7 @@ class MessageInitializationError(TypeError):
56
74
  super().__init__(message or MESSAGE_INIT_ERROR_MESSAGE)
57
75
 
58
76
 
59
- class Metadata: # pylint: disable=too-many-instance-attributes
60
- """The class representing metadata associated with the current message.
61
-
62
- Parameters
63
- ----------
64
- run_id : int
65
- An identifier for the current run.
66
- message_id : str
67
- An identifier for the current message.
68
- src_node_id : int
69
- An identifier for the node sending this message.
70
- dst_node_id : int
71
- An identifier for the node receiving this message.
72
- reply_to_message_id : str
73
- An identifier for the message to which this message is a reply.
74
- group_id : str
75
- An identifier for grouping messages. In some settings,
76
- this is used as the FL round.
77
- created_at : float
78
- Unix timestamp when the message was created.
79
- ttl : float
80
- Time-to-live for this message in seconds.
81
- message_type : str
82
- A string that encodes the action to be executed on
83
- the receiving end.
84
- """
85
-
86
- def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
87
- self,
88
- run_id: int,
89
- message_id: str,
90
- src_node_id: int,
91
- dst_node_id: int,
92
- reply_to_message_id: str,
93
- group_id: str,
94
- created_at: float,
95
- ttl: float,
96
- message_type: str,
97
- ) -> None:
98
- var_dict = {
99
- "_run_id": run_id,
100
- "_message_id": message_id,
101
- "_src_node_id": src_node_id,
102
- "_dst_node_id": dst_node_id,
103
- "_reply_to_message_id": reply_to_message_id,
104
- "_group_id": group_id,
105
- "_created_at": created_at,
106
- "_ttl": ttl,
107
- "_message_type": message_type,
108
- }
109
- self.__dict__.update(var_dict)
110
- self.message_type = message_type # Trigger validation
111
-
112
- @property
113
- def run_id(self) -> int:
114
- """An identifier for the current run."""
115
- return cast(int, self.__dict__["_run_id"])
116
-
117
- @property
118
- def message_id(self) -> str:
119
- """An identifier for the current message."""
120
- return cast(str, self.__dict__["_message_id"])
121
-
122
- @property
123
- def src_node_id(self) -> int:
124
- """An identifier for the node sending this message."""
125
- return cast(int, self.__dict__["_src_node_id"])
126
-
127
- @property
128
- def reply_to_message_id(self) -> str:
129
- """An identifier for the message to which this message is a reply."""
130
- return cast(str, self.__dict__["_reply_to_message_id"])
131
-
132
- @property
133
- def dst_node_id(self) -> int:
134
- """An identifier for the node receiving this message."""
135
- return cast(int, self.__dict__["_dst_node_id"])
136
-
137
- @dst_node_id.setter
138
- def dst_node_id(self, value: int) -> None:
139
- """Set dst_node_id."""
140
- self.__dict__["_dst_node_id"] = value
141
-
142
- @property
143
- def group_id(self) -> str:
144
- """An identifier for grouping messages."""
145
- return cast(str, self.__dict__["_group_id"])
146
-
147
- @group_id.setter
148
- def group_id(self, value: str) -> None:
149
- """Set group_id."""
150
- self.__dict__["_group_id"] = value
151
-
152
- @property
153
- def created_at(self) -> float:
154
- """Unix timestamp when the message was created."""
155
- return cast(float, self.__dict__["_created_at"])
156
-
157
- @created_at.setter
158
- def created_at(self, value: float) -> None:
159
- """Set creation timestamp of this message."""
160
- self.__dict__["_created_at"] = value
161
-
162
- @property
163
- def delivered_at(self) -> str:
164
- """Unix timestamp when the message was delivered."""
165
- return cast(str, self.__dict__["_delivered_at"])
166
-
167
- @delivered_at.setter
168
- def delivered_at(self, value: str) -> None:
169
- """Set delivery timestamp of this message."""
170
- self.__dict__["_delivered_at"] = value
171
-
172
- @property
173
- def ttl(self) -> float:
174
- """Time-to-live for this message."""
175
- return cast(float, self.__dict__["_ttl"])
176
-
177
- @ttl.setter
178
- def ttl(self, value: float) -> None:
179
- """Set ttl."""
180
- self.__dict__["_ttl"] = value
181
-
182
- @property
183
- def message_type(self) -> str:
184
- """A string that encodes the action to be executed on the receiving end."""
185
- return cast(str, self.__dict__["_message_type"])
186
-
187
- @message_type.setter
188
- def message_type(self, value: str) -> None:
189
- """Set message_type."""
190
- # Validate message type
191
- if validate_legacy_message_type(value):
192
- pass # Backward compatibility for legacy message types
193
- elif not validate_message_type(value):
194
- raise ValueError(
195
- f"Invalid message type: '{value}'. "
196
- "Expected format: '<category>' or '<category>.<action>', "
197
- "where <category> must be 'train', 'evaluate', or 'query', "
198
- "and <action> must be a valid Python identifier."
199
- )
200
-
201
- self.__dict__["_message_type"] = value
202
-
203
- def __repr__(self) -> str:
204
- """Return a string representation of this instance."""
205
- view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
206
- return f"{self.__class__.__qualname__}({view})"
207
-
208
- def __eq__(self, other: object) -> bool:
209
- """Compare two instances of the class."""
210
- if not isinstance(other, self.__class__):
211
- raise NotImplementedError
212
- return self.__dict__ == other.__dict__
213
-
214
-
215
- class Error:
216
- """The class storing information about an error that occurred.
217
-
218
- Parameters
219
- ----------
220
- code : int
221
- An identifier for the error.
222
- reason : Optional[str]
223
- A reason for why the error arose (e.g. an exception stack-trace)
224
- """
225
-
226
- def __init__(self, code: int, reason: str | None = None) -> None:
227
- var_dict = {
228
- "_code": code,
229
- "_reason": reason,
230
- }
231
- self.__dict__.update(var_dict)
232
-
233
- @property
234
- def code(self) -> int:
235
- """Error code."""
236
- return cast(int, self.__dict__["_code"])
237
-
238
- @property
239
- def reason(self) -> str | None:
240
- """Reason reported about the error."""
241
- return cast(Optional[str], self.__dict__["_reason"])
242
-
243
- def __repr__(self) -> str:
244
- """Return a string representation of this instance."""
245
- view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
246
- return f"{self.__class__.__qualname__}({view})"
247
-
248
- def __eq__(self, other: object) -> bool:
249
- """Compare two instances of the class."""
250
- if not isinstance(other, self.__class__):
251
- raise NotImplementedError
252
- return self.__dict__ == other.__dict__
253
-
254
-
255
- class Message:
77
+ class Message(InflatableObject):
256
78
  """Represents a message exchanged between ClientApp and ServerApp.
257
79
 
258
80
  This class encapsulates the payload and metadata necessary for communication
@@ -525,6 +347,77 @@ class Message:
525
347
  )
526
348
  return f"{self.__class__.__qualname__}({view})"
527
349
 
350
+ @property
351
+ def children(self) -> dict[str, InflatableObject] | None:
352
+ """Return a dictionary of a single RecordDict with its Object IDs as key."""
353
+ return {self.content.object_id: self.content} if self.has_content() else None
354
+
355
+ def deflate(self) -> bytes:
356
+ """Deflate message."""
357
+ # Exclude message_id from serialization
358
+ proto_metadata: ProtoMetadata = metadata_to_proto(self.metadata)
359
+ proto_metadata.message_id = ""
360
+ # Store message metadata and error in object body
361
+ obj_body = ProtoMessage(
362
+ metadata=proto_metadata,
363
+ content=None,
364
+ error=error_to_proto(self.error) if self.has_error() else None,
365
+ ).SerializeToString(deterministic=True)
366
+
367
+ return add_header_to_object_body(object_body=obj_body, obj=self)
368
+
369
+ @classmethod
370
+ def inflate(
371
+ cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
372
+ ) -> Message:
373
+ """Inflate an Message from bytes.
374
+
375
+ Parameters
376
+ ----------
377
+ object_content : bytes
378
+ The deflated object content of the Message.
379
+ children : Optional[dict[str, InflatableObject]] (default: None)
380
+ Dictionary of children InflatableObjects mapped to their Object IDs.
381
+ These children enable the full inflation of the Message.
382
+
383
+ Returns
384
+ -------
385
+ Message
386
+ The inflated Message.
387
+ """
388
+ if children is None:
389
+ children = {}
390
+
391
+ # Get the children id from the deflated message
392
+ children_ids = get_object_children_ids_from_object_content(object_content)
393
+
394
+ # If the message had content, only one children is possible
395
+ # If the message carried an error, the returned listed should be empty
396
+ if children_ids != list(children.keys()):
397
+ raise ValueError(
398
+ f"Mismatch in children object IDs: expected {children_ids}, but "
399
+ f"received {list(children.keys())}. The provided children must exactly "
400
+ "match the IDs specified in the object head."
401
+ )
402
+
403
+ # Inflate content
404
+ obj_body = get_object_body(object_content, cls)
405
+ proto_message = ProtoMessage.FromString(obj_body)
406
+
407
+ # Prepare content if error wasn't set in protobuf message
408
+ if proto_message.HasField("error"):
409
+ content = None
410
+ error = error_from_proto(proto_message.error)
411
+ else:
412
+ content = cast(RecordDict, children[children_ids[0]])
413
+ error = None
414
+ # Return message
415
+ return make_message(
416
+ metadata=metadata_from_proto(proto_message.metadata),
417
+ content=content,
418
+ error=error,
419
+ )
420
+
528
421
 
529
422
  def make_message(
530
423
  metadata: Metadata, content: RecordDict | None = None, error: Error | None = None
@@ -533,6 +426,17 @@ def make_message(
533
426
  return Message(metadata=metadata, content=content, error=error) # type: ignore
534
427
 
535
428
 
429
+ def remove_content_from_message(message: Message) -> Message:
430
+ """Return a copy of the Message but with an empty RecordDict as content.
431
+
432
+ If message has no content, it returns itself.
433
+ """
434
+ if message.has_error():
435
+ return message
436
+
437
+ return make_message(metadata=message.metadata, content=RecordDict())
438
+
439
+
536
440
  def _limit_reply_ttl(
537
441
  current: float, reply_ttl: float | None, reply_to: Message
538
442
  ) -> float:
@@ -616,46 +520,10 @@ def _check_arg_types( # pylint: disable=too-many-arguments, R0917
616
520
  raise MessageInitializationError()
617
521
 
618
522
 
619
- def validate_message_type(message_type: str) -> bool:
620
- """Validate if the message type is valid.
621
-
622
- A valid message type format must be one of the following:
623
-
624
- - "<category>"
625
- - "<category>.<action>"
626
-
627
- where `category` must be one of "train", "evaluate", or "query",
628
- and `action` must be a valid Python identifier.
629
- """
630
- # Check if conforming to the format "<category>"
631
- valid_types = {
632
- MessageType.TRAIN,
633
- MessageType.EVALUATE,
634
- MessageType.QUERY,
635
- MessageType.SYSTEM,
523
+ def get_message_to_descendant_id_mapping(message: Message) -> dict[str, ObjectIDs]:
524
+ """Construct a mapping between message object_id and that of its descendants."""
525
+ return {
526
+ message.object_id: ObjectIDs(
527
+ object_ids=list(get_descendant_object_ids(message))
528
+ )
636
529
  }
637
- if message_type in valid_types:
638
- return True
639
-
640
- # Check if conforming to the format "<category>.<action>"
641
- if message_type.count(".") != 1:
642
- return False
643
-
644
- category, action = message_type.split(".")
645
- if category in valid_types and action.isidentifier():
646
- return True
647
-
648
- return False
649
-
650
-
651
- def validate_legacy_message_type(message_type: str) -> bool:
652
- """Validate if the legacy message type is valid."""
653
- # Backward compatibility for legacy message types
654
- if message_type in (
655
- MessageTypeLegacy.GET_PARAMETERS,
656
- MessageTypeLegacy.GET_PROPERTIES,
657
- "reconnect",
658
- ):
659
- return True
660
-
661
- return False
@@ -15,7 +15,8 @@
15
15
  """Record APIs."""
16
16
 
17
17
 
18
- from .arrayrecord import Array, ArrayRecord, ParametersRecord
18
+ from .array import Array
19
+ from .arrayrecord import ArrayRecord, ParametersRecord
19
20
  from .configrecord import ConfigRecord, ConfigsRecord
20
21
  from .conversion_utils import array_from_numpy
21
22
  from .metricrecord import MetricRecord, MetricsRecord