flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.8.0.dev20240327__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the registry's release page for this version for more details.

Files changed (30)
  1. flwr/client/client_app.py +4 -4
  2. flwr/client/grpc_client/connection.py +2 -1
  3. flwr/client/message_handler/message_handler.py +3 -2
  4. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  5. flwr/common/__init__.py +2 -0
  6. flwr/common/message.py +34 -13
  7. flwr/common/serde.py +8 -2
  8. flwr/proto/fleet_pb2.py +19 -15
  9. flwr/proto/fleet_pb2.pyi +28 -0
  10. flwr/proto/fleet_pb2_grpc.py +33 -0
  11. flwr/proto/fleet_pb2_grpc.pyi +10 -0
  12. flwr/proto/task_pb2.py +6 -6
  13. flwr/proto/task_pb2.pyi +8 -5
  14. flwr/server/compat/driver_client_proxy.py +9 -1
  15. flwr/server/driver/driver.py +6 -5
  16. flwr/server/superlink/driver/driver_servicer.py +6 -0
  17. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +11 -1
  18. flwr/server/superlink/fleet/message_handler/message_handler.py +14 -0
  19. flwr/server/superlink/state/in_memory_state.py +38 -26
  20. flwr/server/superlink/state/sqlite_state.py +42 -21
  21. flwr/server/superlink/state/state.py +19 -0
  22. flwr/server/utils/validator.py +23 -9
  23. flwr/server/workflow/default_workflows.py +4 -4
  24. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +5 -4
  25. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  26. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/METADATA +1 -1
  27. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/RECORD +30 -30
  28. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/LICENSE +0 -0
  29. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/WHEEL +0 -0
  30. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/entry_points.txt +0 -0
flwr/client/client_app.py CHANGED
@@ -115,7 +115,7 @@ class ClientApp:
115
115
  >>> def train(message: Message, context: Context) -> Message:
116
116
  >>> print("ClientApp training running")
117
117
  >>> # Create and return an echo reply message
118
- >>> return message.create_reply(content=message.content(), ttl="")
118
+ >>> return message.create_reply(content=message.content())
119
119
  """
120
120
 
121
121
  def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable:
@@ -143,7 +143,7 @@ class ClientApp:
143
143
  >>> def evaluate(message: Message, context: Context) -> Message:
144
144
  >>> print("ClientApp evaluation running")
145
145
  >>> # Create and return an echo reply message
146
- >>> return message.create_reply(content=message.content(), ttl="")
146
+ >>> return message.create_reply(content=message.content())
147
147
  """
148
148
 
149
149
  def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable:
@@ -171,7 +171,7 @@ class ClientApp:
171
171
  >>> def query(message: Message, context: Context) -> Message:
172
172
  >>> print("ClientApp query running")
173
173
  >>> # Create and return an echo reply message
174
- >>> return message.create_reply(content=message.content(), ttl="")
174
+ >>> return message.create_reply(content=message.content())
175
175
  """
176
176
 
177
177
  def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:
@@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError:
218
218
  >>> print("ClientApp {fn_name} running")
219
219
  >>> # Create and return an echo reply message
220
220
  >>> return message.create_reply(
221
- >>> content=message.content(), ttl=""
221
+ >>> content=message.content()
222
222
  >>> )
223
223
  """,
224
224
  )
@@ -23,6 +23,7 @@ from queue import Queue
23
23
  from typing import Callable, Iterator, Optional, Tuple, Union, cast
24
24
 
25
25
  from flwr.common import (
26
+ DEFAULT_TTL,
26
27
  GRPC_MAX_MESSAGE_LENGTH,
27
28
  ConfigsRecord,
28
29
  Message,
@@ -180,7 +181,7 @@ def grpc_connection( # pylint: disable=R0915
180
181
  dst_node_id=0,
181
182
  reply_to_message="",
182
183
  group_id="",
183
- ttl="",
184
+ ttl=DEFAULT_TTL,
184
185
  message_type=message_type,
185
186
  ),
186
187
  content=recordset,
@@ -81,7 +81,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
81
81
  reason = cast(int, disconnect_msg.disconnect_res.reason)
82
82
  recordset = RecordSet()
83
83
  recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
84
- out_message = message.create_reply(recordset, ttl="")
84
+ out_message = message.create_reply(recordset)
85
85
  # Return TaskRes and sleep duration
86
86
  return out_message, sleep_duration
87
87
 
@@ -143,7 +143,7 @@ def handle_legacy_message_from_msgtype(
143
143
  raise ValueError(f"Invalid message type: {message_type}")
144
144
 
145
145
  # Return Message
146
- return message.create_reply(out_recordset, ttl="")
146
+ return message.create_reply(out_recordset)
147
147
 
148
148
 
149
149
  def _reconnect(
@@ -172,6 +172,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
172
172
  and out_meta.reply_to_message == in_meta.message_id
173
173
  and out_meta.group_id == in_meta.group_id
174
174
  and out_meta.message_type == in_meta.message_type
175
+ and out_meta.created_at > in_meta.created_at
175
176
  ):
176
177
  return True
177
178
  return False
@@ -187,7 +187,7 @@ def secaggplus_mod(
187
187
 
188
188
  # Return message
189
189
  out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
190
- return msg.create_reply(out_content, ttl="")
190
+ return msg.create_reply(out_content)
191
191
 
192
192
 
193
193
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
flwr/common/__init__.py CHANGED
@@ -22,6 +22,7 @@ from .date import now as now
22
22
  from .grpc import GRPC_MAX_MESSAGE_LENGTH
23
23
  from .logger import configure as configure
24
24
  from .logger import log as log
25
+ from .message import DEFAULT_TTL
25
26
  from .message import Error as Error
26
27
  from .message import Message as Message
27
28
  from .message import Metadata as Metadata
@@ -87,6 +88,7 @@ __all__ = [
87
88
  "Message",
88
89
  "MessageType",
89
90
  "MessageTypeLegacy",
91
+ "DEFAULT_TTL",
90
92
  "Metadata",
91
93
  "Metrics",
92
94
  "MetricsAggregationFn",
flwr/common/message.py CHANGED
@@ -16,10 +16,13 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
+ import time
19
20
  from dataclasses import dataclass
20
21
 
21
22
  from .record import RecordSet
22
23
 
24
+ DEFAULT_TTL = 3600
25
+
23
26
 
24
27
  @dataclass
25
28
  class Metadata: # pylint: disable=too-many-instance-attributes
@@ -40,8 +43,8 @@ class Metadata: # pylint: disable=too-many-instance-attributes
40
43
  group_id : str
41
44
  An identifier for grouping messages. In some settings,
42
45
  this is used as the FL round.
43
- ttl : str
44
- Time-to-live for this message.
46
+ ttl : float
47
+ Time-to-live for this message in seconds.
45
48
  message_type : str
46
49
  A string that encodes the action to be executed on
47
50
  the receiving end.
@@ -57,9 +60,10 @@ class Metadata: # pylint: disable=too-many-instance-attributes
57
60
  _dst_node_id: int
58
61
  _reply_to_message: str
59
62
  _group_id: str
60
- _ttl: str
63
+ _ttl: float
61
64
  _message_type: str
62
65
  _partition_id: int | None
66
+ _created_at: float # Unix timestamp (in seconds) to be set upon message creation
63
67
 
64
68
  def __init__( # pylint: disable=too-many-arguments
65
69
  self,
@@ -69,7 +73,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
69
73
  dst_node_id: int,
70
74
  reply_to_message: str,
71
75
  group_id: str,
72
- ttl: str,
76
+ ttl: float,
73
77
  message_type: str,
74
78
  partition_id: int | None = None,
75
79
  ) -> None:
@@ -124,12 +128,22 @@ class Metadata: # pylint: disable=too-many-instance-attributes
124
128
  self._group_id = value
125
129
 
126
130
  @property
127
- def ttl(self) -> str:
131
+ def created_at(self) -> float:
132
+ """Unix timestamp when the message was created."""
133
+ return self._created_at
134
+
135
+ @created_at.setter
136
+ def created_at(self, value: float) -> None:
137
+ """Set creation timestamp for this messages."""
138
+ self._created_at = value
139
+
140
+ @property
141
+ def ttl(self) -> float:
128
142
  """Time-to-live for this message."""
129
143
  return self._ttl
130
144
 
131
145
  @ttl.setter
132
- def ttl(self, value: str) -> None:
146
+ def ttl(self, value: float) -> None:
133
147
  """Set ttl."""
134
148
  self._ttl = value
135
149
 
@@ -212,6 +226,9 @@ class Message:
212
226
  ) -> None:
213
227
  self._metadata = metadata
214
228
 
229
+ # Set message creation timestamp
230
+ self._metadata.created_at = time.time()
231
+
215
232
  if not (content is None) ^ (error is None):
216
233
  raise ValueError("Either `content` or `error` must be set, but not both.")
217
234
 
@@ -266,7 +283,7 @@ class Message:
266
283
  """Return True if message has an error, else False."""
267
284
  return self._error is not None
268
285
 
269
- def _create_reply_metadata(self, ttl: str) -> Metadata:
286
+ def _create_reply_metadata(self, ttl: float) -> Metadata:
270
287
  """Construct metadata for a reply message."""
271
288
  return Metadata(
272
289
  run_id=self.metadata.run_id,
@@ -283,7 +300,7 @@ class Message:
283
300
  def create_error_reply(
284
301
  self,
285
302
  error: Error,
286
- ttl: str,
303
+ ttl: float,
287
304
  ) -> Message:
288
305
  """Construct a reply message indicating an error happened.
289
306
 
@@ -291,14 +308,14 @@ class Message:
291
308
  ----------
292
309
  error : Error
293
310
  The error that was encountered.
294
- ttl : str
295
- Time-to-live for this message.
311
+ ttl : float
312
+ Time-to-live for this message in seconds.
296
313
  """
297
314
  # Create reply with error
298
315
  message = Message(metadata=self._create_reply_metadata(ttl), error=error)
299
316
  return message
300
317
 
301
- def create_reply(self, content: RecordSet, ttl: str) -> Message:
318
+ def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
302
319
  """Create a reply to this message with specified content and TTL.
303
320
 
304
321
  The method generates a new `Message` as a reply to this message.
@@ -309,14 +326,18 @@ class Message:
309
326
  ----------
310
327
  content : RecordSet
311
328
  The content for the reply message.
312
- ttl : str
313
- Time-to-live for this message.
329
+ ttl : Optional[float] (default: None)
330
+ Time-to-live for this message in seconds. If unset, it will use
331
+ the `common.DEFAULT_TTL` value.
314
332
 
315
333
  Returns
316
334
  -------
317
335
  Message
318
336
  A new `Message` instance representing the reply.
319
337
  """
338
+ if ttl is None:
339
+ ttl = DEFAULT_TTL
340
+
320
341
  return Message(
321
342
  metadata=self._create_reply_metadata(ttl),
322
343
  content=content,
flwr/common/serde.py CHANGED
@@ -575,6 +575,7 @@ def message_to_taskins(message: Message) -> TaskIns:
575
575
  task=Task(
576
576
  producer=Node(node_id=0, anonymous=True), # Assume driver node
577
577
  consumer=Node(node_id=md.dst_node_id, anonymous=False),
578
+ created_at=md.created_at,
578
579
  ttl=md.ttl,
579
580
  ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
580
581
  task_type=md.message_type,
@@ -601,7 +602,7 @@ def message_from_taskins(taskins: TaskIns) -> Message:
601
602
  )
602
603
 
603
604
  # Construct Message
604
- return Message(
605
+ message = Message(
605
606
  metadata=metadata,
606
607
  content=(
607
608
  recordset_from_proto(taskins.task.recordset)
@@ -614,6 +615,8 @@ def message_from_taskins(taskins: TaskIns) -> Message:
614
615
  else None
615
616
  ),
616
617
  )
618
+ message.metadata.created_at = taskins.task.created_at
619
+ return message
617
620
 
618
621
 
619
622
  def message_to_taskres(message: Message) -> TaskRes:
@@ -626,6 +629,7 @@ def message_to_taskres(message: Message) -> TaskRes:
626
629
  task=Task(
627
630
  producer=Node(node_id=md.src_node_id, anonymous=False),
628
631
  consumer=Node(node_id=0, anonymous=True), # Assume driver node
632
+ created_at=md.created_at,
629
633
  ttl=md.ttl,
630
634
  ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
631
635
  task_type=md.message_type,
@@ -652,7 +656,7 @@ def message_from_taskres(taskres: TaskRes) -> Message:
652
656
  )
653
657
 
654
658
  # Construct the Message
655
- return Message(
659
+ message = Message(
656
660
  metadata=metadata,
657
661
  content=(
658
662
  recordset_from_proto(taskres.task.recordset)
@@ -665,3 +669,5 @@ def message_from_taskres(taskres: TaskRes) -> Message:
665
669
  else None
666
670
  ),
667
671
  )
672
+ message.metadata.created_at = taskres.task.created_at
673
+ return message
flwr/proto/fleet_pb2.py CHANGED
@@ -16,7 +16,7 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
17
17
 
18
18
 
19
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xc9\x02\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
19
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
20
20
 
21
21
  _globals = globals()
22
22
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -33,18 +33,22 @@ if _descriptor._USE_C_DESCRIPTORS == False:
33
33
  _globals['_DELETENODEREQUEST']._serialized_end=210
34
34
  _globals['_DELETENODERESPONSE']._serialized_start=212
35
35
  _globals['_DELETENODERESPONSE']._serialized_end=232
36
- _globals['_PULLTASKINSREQUEST']._serialized_start=234
37
- _globals['_PULLTASKINSREQUEST']._serialized_end=304
38
- _globals['_PULLTASKINSRESPONSE']._serialized_start=306
39
- _globals['_PULLTASKINSRESPONSE']._serialized_end=413
40
- _globals['_PUSHTASKRESREQUEST']._serialized_start=415
41
- _globals['_PUSHTASKRESREQUEST']._serialized_end=479
42
- _globals['_PUSHTASKRESRESPONSE']._serialized_start=482
43
- _globals['_PUSHTASKRESRESPONSE']._serialized_end=656
44
- _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=610
45
- _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=656
46
- _globals['_RECONNECT']._serialized_start=658
47
- _globals['_RECONNECT']._serialized_end=688
48
- _globals['_FLEET']._serialized_start=691
49
- _globals['_FLEET']._serialized_end=1020
36
+ _globals['_PINGREQUEST']._serialized_start=234
37
+ _globals['_PINGREQUEST']._serialized_end=302
38
+ _globals['_PINGRESPONSE']._serialized_start=304
39
+ _globals['_PINGRESPONSE']._serialized_end=335
40
+ _globals['_PULLTASKINSREQUEST']._serialized_start=337
41
+ _globals['_PULLTASKINSREQUEST']._serialized_end=407
42
+ _globals['_PULLTASKINSRESPONSE']._serialized_start=409
43
+ _globals['_PULLTASKINSRESPONSE']._serialized_end=516
44
+ _globals['_PUSHTASKRESREQUEST']._serialized_start=518
45
+ _globals['_PUSHTASKRESREQUEST']._serialized_end=582
46
+ _globals['_PUSHTASKRESRESPONSE']._serialized_start=585
47
+ _globals['_PUSHTASKRESRESPONSE']._serialized_end=759
48
+ _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=713
49
+ _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=759
50
+ _globals['_RECONNECT']._serialized_start=761
51
+ _globals['_RECONNECT']._serialized_end=791
52
+ _globals['_FLEET']._serialized_start=794
53
+ _globals['_FLEET']._serialized_end=1184
50
54
  # @@protoc_insertion_point(module_scope)
flwr/proto/fleet_pb2.pyi CHANGED
@@ -53,6 +53,34 @@ class DeleteNodeResponse(google.protobuf.message.Message):
53
53
  ) -> None: ...
54
54
  global___DeleteNodeResponse = DeleteNodeResponse
55
55
 
56
+ class PingRequest(google.protobuf.message.Message):
57
+ """Ping messages"""
58
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
59
+ NODE_FIELD_NUMBER: builtins.int
60
+ PING_INTERVAL_FIELD_NUMBER: builtins.int
61
+ @property
62
+ def node(self) -> flwr.proto.node_pb2.Node: ...
63
+ ping_interval: builtins.float
64
+ def __init__(self,
65
+ *,
66
+ node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
67
+ ping_interval: builtins.float = ...,
68
+ ) -> None: ...
69
+ def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
70
+ def ClearField(self, field_name: typing_extensions.Literal["node",b"node","ping_interval",b"ping_interval"]) -> None: ...
71
+ global___PingRequest = PingRequest
72
+
73
+ class PingResponse(google.protobuf.message.Message):
74
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
75
+ SUCCESS_FIELD_NUMBER: builtins.int
76
+ success: builtins.bool
77
+ def __init__(self,
78
+ *,
79
+ success: builtins.bool = ...,
80
+ ) -> None: ...
81
+ def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ...
82
+ global___PingResponse = PingResponse
83
+
56
84
  class PullTaskInsRequest(google.protobuf.message.Message):
57
85
  """PullTaskIns messages"""
58
86
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -24,6 +24,11 @@ class FleetStub(object):
24
24
  request_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.SerializeToString,
25
25
  response_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.FromString,
26
26
  )
27
+ self.Ping = channel.unary_unary(
28
+ '/flwr.proto.Fleet/Ping',
29
+ request_serializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString,
30
+ response_deserializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString,
31
+ )
27
32
  self.PullTaskIns = channel.unary_unary(
28
33
  '/flwr.proto.Fleet/PullTaskIns',
29
34
  request_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString,
@@ -51,6 +56,12 @@ class FleetServicer(object):
51
56
  context.set_details('Method not implemented!')
52
57
  raise NotImplementedError('Method not implemented!')
53
58
 
59
+ def Ping(self, request, context):
60
+ """Missing associated documentation comment in .proto file."""
61
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
62
+ context.set_details('Method not implemented!')
63
+ raise NotImplementedError('Method not implemented!')
64
+
54
65
  def PullTaskIns(self, request, context):
55
66
  """Retrieve one or more tasks, if possible
56
67
 
@@ -82,6 +93,11 @@ def add_FleetServicer_to_server(servicer, server):
82
93
  request_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.FromString,
83
94
  response_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.SerializeToString,
84
95
  ),
96
+ 'Ping': grpc.unary_unary_rpc_method_handler(
97
+ servicer.Ping,
98
+ request_deserializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.FromString,
99
+ response_serializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.SerializeToString,
100
+ ),
85
101
  'PullTaskIns': grpc.unary_unary_rpc_method_handler(
86
102
  servicer.PullTaskIns,
87
103
  request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString,
@@ -136,6 +152,23 @@ class Fleet(object):
136
152
  options, channel_credentials,
137
153
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
138
154
 
155
+ @staticmethod
156
+ def Ping(request,
157
+ target,
158
+ options=(),
159
+ channel_credentials=None,
160
+ call_credentials=None,
161
+ insecure=False,
162
+ compression=None,
163
+ wait_for_ready=None,
164
+ timeout=None,
165
+ metadata=None):
166
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/Ping',
167
+ flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString,
168
+ flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString,
169
+ options, channel_credentials,
170
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
171
+
139
172
  @staticmethod
140
173
  def PullTaskIns(request,
141
174
  target,
@@ -16,6 +16,10 @@ class FleetStub:
16
16
  flwr.proto.fleet_pb2.DeleteNodeRequest,
17
17
  flwr.proto.fleet_pb2.DeleteNodeResponse]
18
18
 
19
+ Ping: grpc.UnaryUnaryMultiCallable[
20
+ flwr.proto.fleet_pb2.PingRequest,
21
+ flwr.proto.fleet_pb2.PingResponse]
22
+
19
23
  PullTaskIns: grpc.UnaryUnaryMultiCallable[
20
24
  flwr.proto.fleet_pb2.PullTaskInsRequest,
21
25
  flwr.proto.fleet_pb2.PullTaskInsResponse]
@@ -46,6 +50,12 @@ class FleetServicer(metaclass=abc.ABCMeta):
46
50
  context: grpc.ServicerContext,
47
51
  ) -> flwr.proto.fleet_pb2.DeleteNodeResponse: ...
48
52
 
53
+ @abc.abstractmethod
54
+ def Ping(self,
55
+ request: flwr.proto.fleet_pb2.PingRequest,
56
+ context: grpc.ServicerContext,
57
+ ) -> flwr.proto.fleet_pb2.PingResponse: ...
58
+
49
59
  @abc.abstractmethod
50
60
  def PullTaskIns(self,
51
61
  request: flwr.proto.fleet_pb2.PullTaskInsRequest,
flwr/proto/task_pb2.py CHANGED
@@ -18,7 +18,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
18
  from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
19
19
 
20
20
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
22
22
 
23
23
  _globals = globals()
24
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -26,9 +26,9 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _glob
26
26
  if _descriptor._USE_C_DESCRIPTORS == False:
27
27
  DESCRIPTOR._options = None
28
28
  _globals['_TASK']._serialized_start=141
29
- _globals['_TASK']._serialized_end=387
30
- _globals['_TASKINS']._serialized_start=389
31
- _globals['_TASKINS']._serialized_end=481
32
- _globals['_TASKRES']._serialized_start=483
33
- _globals['_TASKRES']._serialized_end=575
29
+ _globals['_TASK']._serialized_end=406
30
+ _globals['_TASKINS']._serialized_start=408
31
+ _globals['_TASKINS']._serialized_end=500
32
+ _globals['_TASKRES']._serialized_start=502
33
+ _globals['_TASKRES']._serialized_end=594
34
34
  # @@protoc_insertion_point(module_scope)
flwr/proto/task_pb2.pyi CHANGED
@@ -20,6 +20,7 @@ class Task(google.protobuf.message.Message):
20
20
  CONSUMER_FIELD_NUMBER: builtins.int
21
21
  CREATED_AT_FIELD_NUMBER: builtins.int
22
22
  DELIVERED_AT_FIELD_NUMBER: builtins.int
23
+ PUSHED_AT_FIELD_NUMBER: builtins.int
23
24
  TTL_FIELD_NUMBER: builtins.int
24
25
  ANCESTRY_FIELD_NUMBER: builtins.int
25
26
  TASK_TYPE_FIELD_NUMBER: builtins.int
@@ -29,9 +30,10 @@ class Task(google.protobuf.message.Message):
29
30
  def producer(self) -> flwr.proto.node_pb2.Node: ...
30
31
  @property
31
32
  def consumer(self) -> flwr.proto.node_pb2.Node: ...
32
- created_at: typing.Text
33
+ created_at: builtins.float
33
34
  delivered_at: typing.Text
34
- ttl: typing.Text
35
+ pushed_at: builtins.float
36
+ ttl: builtins.float
35
37
  @property
36
38
  def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
37
39
  task_type: typing.Text
@@ -43,16 +45,17 @@ class Task(google.protobuf.message.Message):
43
45
  *,
44
46
  producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
45
47
  consumer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
46
- created_at: typing.Text = ...,
48
+ created_at: builtins.float = ...,
47
49
  delivered_at: typing.Text = ...,
48
- ttl: typing.Text = ...,
50
+ pushed_at: builtins.float = ...,
51
+ ttl: builtins.float = ...,
49
52
  ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
50
53
  task_type: typing.Text = ...,
51
54
  recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
52
55
  error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
53
56
  ) -> None: ...
54
57
  def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
55
- def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
58
+ def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","pushed_at",b"pushed_at","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
56
59
  global___Task = Task
57
60
 
58
61
  class TaskIns(google.protobuf.message.Message):
@@ -19,7 +19,7 @@ import time
19
19
  from typing import List, Optional
20
20
 
21
21
  from flwr import common
22
- from flwr.common import MessageType, MessageTypeLegacy, RecordSet
22
+ from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet
23
23
  from flwr.common import recordset_compat as compat
24
24
  from flwr.common import serde
25
25
  from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
@@ -129,8 +129,16 @@ class DriverClientProxy(ClientProxy):
129
129
  ),
130
130
  task_type=task_type,
131
131
  recordset=serde.recordset_to_proto(recordset),
132
+ ttl=DEFAULT_TTL,
132
133
  ),
133
134
  )
135
+
136
+ # This would normally be recorded upon common.Message creation
137
+ # but this compatibility stack doesn't create Messages,
138
+ # so we need to inject `created_at` manually (needed for
139
+ # taskins validation by server.utils.validator)
140
+ task_ins.task.created_at = time.time()
141
+
134
142
  push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
135
143
  task_ins_list=[task_ins]
136
144
  )
@@ -18,7 +18,7 @@
18
18
  import time
19
19
  from typing import Iterable, List, Optional, Tuple
20
20
 
21
- from flwr.common import Message, Metadata, RecordSet
21
+ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
22
22
  from flwr.common.serde import message_from_taskres, message_to_taskins
23
23
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
24
24
  CreateRunRequest,
@@ -81,6 +81,7 @@ class Driver:
81
81
  and message.metadata.src_node_id == self.node.node_id
82
82
  and message.metadata.message_id == ""
83
83
  and message.metadata.reply_to_message == ""
84
+ and message.metadata.ttl > 0
84
85
  ):
85
86
  raise ValueError(f"Invalid message: {message}")
86
87
 
@@ -90,7 +91,7 @@ class Driver:
90
91
  message_type: str,
91
92
  dst_node_id: int,
92
93
  group_id: str,
93
- ttl: str,
94
+ ttl: float = DEFAULT_TTL,
94
95
  ) -> Message:
95
96
  """Create a new message with specified parameters.
96
97
 
@@ -110,10 +111,10 @@ class Driver:
110
111
  group_id : str
111
112
  The ID of the group to which this message is associated. In some settings,
112
113
  this is used as the FL round.
113
- ttl : str
114
+ ttl : float (default: common.DEFAULT_TTL)
114
115
  Time-to-live for the round trip of this message, i.e., the time from sending
115
- this message to receiving a reply. It specifies the duration for which the
116
- message and its potential reply are considered valid.
116
+ this message to receiving a reply. It specifies in seconds the duration for
117
+ which the message and its potential reply are considered valid.
117
118
 
118
119
  Returns
119
120
  -------
@@ -15,6 +15,7 @@
15
15
  """Driver API servicer."""
16
16
 
17
17
 
18
+ import time
18
19
  from logging import DEBUG, INFO
19
20
  from typing import List, Optional, Set
20
21
  from uuid import UUID
@@ -72,6 +73,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
72
73
  """Push a set of TaskIns."""
73
74
  log(DEBUG, "DriverServicer.PushTaskIns")
74
75
 
76
+ # Set pushed_at (timestamp in seconds)
77
+ pushed_at = time.time()
78
+ for task_ins in request.task_ins_list:
79
+ task_ins.task.pushed_at = pushed_at
80
+
75
81
  # Validate request
76
82
  _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
77
83
  for task_ins in request.task_ins_list:
@@ -15,7 +15,7 @@
15
15
  """Fleet API gRPC request-response servicer."""
16
16
 
17
17
 
18
- from logging import INFO
18
+ from logging import DEBUG, INFO
19
19
 
20
20
  import grpc
21
21
 
@@ -26,6 +26,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
26
26
  CreateNodeResponse,
27
27
  DeleteNodeRequest,
28
28
  DeleteNodeResponse,
29
+ PingRequest,
30
+ PingResponse,
29
31
  PullTaskInsRequest,
30
32
  PullTaskInsResponse,
31
33
  PushTaskResRequest,
@@ -61,6 +63,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
61
63
  state=self.state_factory.state(),
62
64
  )
63
65
 
66
+ def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
67
+ """."""
68
+ log(DEBUG, "FleetServicer.Ping")
69
+ return message_handler.ping(
70
+ request=request,
71
+ state=self.state_factory.state(),
72
+ )
73
+
64
74
  def PullTaskIns(
65
75
  self, request: PullTaskInsRequest, context: grpc.ServicerContext
66
76
  ) -> PullTaskInsResponse: