flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.8.0.dev20240327__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/client_app.py +4 -4
- flwr/client/grpc_client/connection.py +2 -1
- flwr/client/message_handler/message_handler.py +3 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/common/__init__.py +2 -0
- flwr/common/message.py +34 -13
- flwr/common/serde.py +8 -2
- flwr/proto/fleet_pb2.py +19 -15
- flwr/proto/fleet_pb2.pyi +28 -0
- flwr/proto/fleet_pb2_grpc.py +33 -0
- flwr/proto/fleet_pb2_grpc.pyi +10 -0
- flwr/proto/task_pb2.py +6 -6
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/compat/driver_client_proxy.py +9 -1
- flwr/server/driver/driver.py +6 -5
- flwr/server/superlink/driver/driver_servicer.py +6 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +11 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +14 -0
- flwr/server/superlink/state/in_memory_state.py +38 -26
- flwr/server/superlink/state/sqlite_state.py +42 -21
- flwr/server/superlink/state/state.py +19 -0
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +4 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +5 -4
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/RECORD +30 -30
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/entry_points.txt +0 -0
flwr/client/client_app.py
CHANGED
|
@@ -115,7 +115,7 @@ class ClientApp:
|
|
|
115
115
|
>>> def train(message: Message, context: Context) -> Message:
|
|
116
116
|
>>> print("ClientApp training running")
|
|
117
117
|
>>> # Create and return an echo reply message
|
|
118
|
-
>>> return message.create_reply(content=message.content()
|
|
118
|
+
>>> return message.create_reply(content=message.content())
|
|
119
119
|
"""
|
|
120
120
|
|
|
121
121
|
def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable:
|
|
@@ -143,7 +143,7 @@ class ClientApp:
|
|
|
143
143
|
>>> def evaluate(message: Message, context: Context) -> Message:
|
|
144
144
|
>>> print("ClientApp evaluation running")
|
|
145
145
|
>>> # Create and return an echo reply message
|
|
146
|
-
>>> return message.create_reply(content=message.content()
|
|
146
|
+
>>> return message.create_reply(content=message.content())
|
|
147
147
|
"""
|
|
148
148
|
|
|
149
149
|
def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable:
|
|
@@ -171,7 +171,7 @@ class ClientApp:
|
|
|
171
171
|
>>> def query(message: Message, context: Context) -> Message:
|
|
172
172
|
>>> print("ClientApp query running")
|
|
173
173
|
>>> # Create and return an echo reply message
|
|
174
|
-
>>> return message.create_reply(content=message.content()
|
|
174
|
+
>>> return message.create_reply(content=message.content())
|
|
175
175
|
"""
|
|
176
176
|
|
|
177
177
|
def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:
|
|
@@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError:
|
|
|
218
218
|
>>> print("ClientApp {fn_name} running")
|
|
219
219
|
>>> # Create and return an echo reply message
|
|
220
220
|
>>> return message.create_reply(
|
|
221
|
-
>>> content=message.content()
|
|
221
|
+
>>> content=message.content()
|
|
222
222
|
>>> )
|
|
223
223
|
""",
|
|
224
224
|
)
|
|
@@ -23,6 +23,7 @@ from queue import Queue
|
|
|
23
23
|
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
24
24
|
|
|
25
25
|
from flwr.common import (
|
|
26
|
+
DEFAULT_TTL,
|
|
26
27
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
27
28
|
ConfigsRecord,
|
|
28
29
|
Message,
|
|
@@ -180,7 +181,7 @@ def grpc_connection( # pylint: disable=R0915
|
|
|
180
181
|
dst_node_id=0,
|
|
181
182
|
reply_to_message="",
|
|
182
183
|
group_id="",
|
|
183
|
-
ttl=
|
|
184
|
+
ttl=DEFAULT_TTL,
|
|
184
185
|
message_type=message_type,
|
|
185
186
|
),
|
|
186
187
|
content=recordset,
|
|
@@ -81,7 +81,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
|
|
|
81
81
|
reason = cast(int, disconnect_msg.disconnect_res.reason)
|
|
82
82
|
recordset = RecordSet()
|
|
83
83
|
recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
|
|
84
|
-
out_message = message.create_reply(recordset
|
|
84
|
+
out_message = message.create_reply(recordset)
|
|
85
85
|
# Return TaskRes and sleep duration
|
|
86
86
|
return out_message, sleep_duration
|
|
87
87
|
|
|
@@ -143,7 +143,7 @@ def handle_legacy_message_from_msgtype(
|
|
|
143
143
|
raise ValueError(f"Invalid message type: {message_type}")
|
|
144
144
|
|
|
145
145
|
# Return Message
|
|
146
|
-
return message.create_reply(out_recordset
|
|
146
|
+
return message.create_reply(out_recordset)
|
|
147
147
|
|
|
148
148
|
|
|
149
149
|
def _reconnect(
|
|
@@ -172,6 +172,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
|
|
|
172
172
|
and out_meta.reply_to_message == in_meta.message_id
|
|
173
173
|
and out_meta.group_id == in_meta.group_id
|
|
174
174
|
and out_meta.message_type == in_meta.message_type
|
|
175
|
+
and out_meta.created_at > in_meta.created_at
|
|
175
176
|
):
|
|
176
177
|
return True
|
|
177
178
|
return False
|
|
@@ -187,7 +187,7 @@ def secaggplus_mod(
|
|
|
187
187
|
|
|
188
188
|
# Return message
|
|
189
189
|
out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
|
|
190
|
-
return msg.create_reply(out_content
|
|
190
|
+
return msg.create_reply(out_content)
|
|
191
191
|
|
|
192
192
|
|
|
193
193
|
def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
flwr/common/__init__.py
CHANGED
|
@@ -22,6 +22,7 @@ from .date import now as now
|
|
|
22
22
|
from .grpc import GRPC_MAX_MESSAGE_LENGTH
|
|
23
23
|
from .logger import configure as configure
|
|
24
24
|
from .logger import log as log
|
|
25
|
+
from .message import DEFAULT_TTL
|
|
25
26
|
from .message import Error as Error
|
|
26
27
|
from .message import Message as Message
|
|
27
28
|
from .message import Metadata as Metadata
|
|
@@ -87,6 +88,7 @@ __all__ = [
|
|
|
87
88
|
"Message",
|
|
88
89
|
"MessageType",
|
|
89
90
|
"MessageTypeLegacy",
|
|
91
|
+
"DEFAULT_TTL",
|
|
90
92
|
"Metadata",
|
|
91
93
|
"Metrics",
|
|
92
94
|
"MetricsAggregationFn",
|
flwr/common/message.py
CHANGED
|
@@ -16,10 +16,13 @@
|
|
|
16
16
|
|
|
17
17
|
from __future__ import annotations
|
|
18
18
|
|
|
19
|
+
import time
|
|
19
20
|
from dataclasses import dataclass
|
|
20
21
|
|
|
21
22
|
from .record import RecordSet
|
|
22
23
|
|
|
24
|
+
DEFAULT_TTL = 3600
|
|
25
|
+
|
|
23
26
|
|
|
24
27
|
@dataclass
|
|
25
28
|
class Metadata: # pylint: disable=too-many-instance-attributes
|
|
@@ -40,8 +43,8 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
40
43
|
group_id : str
|
|
41
44
|
An identifier for grouping messages. In some settings,
|
|
42
45
|
this is used as the FL round.
|
|
43
|
-
ttl :
|
|
44
|
-
Time-to-live for this message.
|
|
46
|
+
ttl : float
|
|
47
|
+
Time-to-live for this message in seconds.
|
|
45
48
|
message_type : str
|
|
46
49
|
A string that encodes the action to be executed on
|
|
47
50
|
the receiving end.
|
|
@@ -57,9 +60,10 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
57
60
|
_dst_node_id: int
|
|
58
61
|
_reply_to_message: str
|
|
59
62
|
_group_id: str
|
|
60
|
-
_ttl:
|
|
63
|
+
_ttl: float
|
|
61
64
|
_message_type: str
|
|
62
65
|
_partition_id: int | None
|
|
66
|
+
_created_at: float # Unix timestamp (in seconds) to be set upon message creation
|
|
63
67
|
|
|
64
68
|
def __init__( # pylint: disable=too-many-arguments
|
|
65
69
|
self,
|
|
@@ -69,7 +73,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
69
73
|
dst_node_id: int,
|
|
70
74
|
reply_to_message: str,
|
|
71
75
|
group_id: str,
|
|
72
|
-
ttl:
|
|
76
|
+
ttl: float,
|
|
73
77
|
message_type: str,
|
|
74
78
|
partition_id: int | None = None,
|
|
75
79
|
) -> None:
|
|
@@ -124,12 +128,22 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
124
128
|
self._group_id = value
|
|
125
129
|
|
|
126
130
|
@property
|
|
127
|
-
def
|
|
131
|
+
def created_at(self) -> float:
|
|
132
|
+
"""Unix timestamp when the message was created."""
|
|
133
|
+
return self._created_at
|
|
134
|
+
|
|
135
|
+
@created_at.setter
|
|
136
|
+
def created_at(self, value: float) -> None:
|
|
137
|
+
"""Set creation timestamp for this messages."""
|
|
138
|
+
self._created_at = value
|
|
139
|
+
|
|
140
|
+
@property
|
|
141
|
+
def ttl(self) -> float:
|
|
128
142
|
"""Time-to-live for this message."""
|
|
129
143
|
return self._ttl
|
|
130
144
|
|
|
131
145
|
@ttl.setter
|
|
132
|
-
def ttl(self, value:
|
|
146
|
+
def ttl(self, value: float) -> None:
|
|
133
147
|
"""Set ttl."""
|
|
134
148
|
self._ttl = value
|
|
135
149
|
|
|
@@ -212,6 +226,9 @@ class Message:
|
|
|
212
226
|
) -> None:
|
|
213
227
|
self._metadata = metadata
|
|
214
228
|
|
|
229
|
+
# Set message creation timestamp
|
|
230
|
+
self._metadata.created_at = time.time()
|
|
231
|
+
|
|
215
232
|
if not (content is None) ^ (error is None):
|
|
216
233
|
raise ValueError("Either `content` or `error` must be set, but not both.")
|
|
217
234
|
|
|
@@ -266,7 +283,7 @@ class Message:
|
|
|
266
283
|
"""Return True if message has an error, else False."""
|
|
267
284
|
return self._error is not None
|
|
268
285
|
|
|
269
|
-
def _create_reply_metadata(self, ttl:
|
|
286
|
+
def _create_reply_metadata(self, ttl: float) -> Metadata:
|
|
270
287
|
"""Construct metadata for a reply message."""
|
|
271
288
|
return Metadata(
|
|
272
289
|
run_id=self.metadata.run_id,
|
|
@@ -283,7 +300,7 @@ class Message:
|
|
|
283
300
|
def create_error_reply(
|
|
284
301
|
self,
|
|
285
302
|
error: Error,
|
|
286
|
-
ttl:
|
|
303
|
+
ttl: float,
|
|
287
304
|
) -> Message:
|
|
288
305
|
"""Construct a reply message indicating an error happened.
|
|
289
306
|
|
|
@@ -291,14 +308,14 @@ class Message:
|
|
|
291
308
|
----------
|
|
292
309
|
error : Error
|
|
293
310
|
The error that was encountered.
|
|
294
|
-
ttl :
|
|
295
|
-
Time-to-live for this message.
|
|
311
|
+
ttl : float
|
|
312
|
+
Time-to-live for this message in seconds.
|
|
296
313
|
"""
|
|
297
314
|
# Create reply with error
|
|
298
315
|
message = Message(metadata=self._create_reply_metadata(ttl), error=error)
|
|
299
316
|
return message
|
|
300
317
|
|
|
301
|
-
def create_reply(self, content: RecordSet, ttl:
|
|
318
|
+
def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
|
|
302
319
|
"""Create a reply to this message with specified content and TTL.
|
|
303
320
|
|
|
304
321
|
The method generates a new `Message` as a reply to this message.
|
|
@@ -309,14 +326,18 @@ class Message:
|
|
|
309
326
|
----------
|
|
310
327
|
content : RecordSet
|
|
311
328
|
The content for the reply message.
|
|
312
|
-
ttl :
|
|
313
|
-
Time-to-live for this message.
|
|
329
|
+
ttl : Optional[float] (default: None)
|
|
330
|
+
Time-to-live for this message in seconds. If unset, it will use
|
|
331
|
+
the `common.DEFAULT_TTL` value.
|
|
314
332
|
|
|
315
333
|
Returns
|
|
316
334
|
-------
|
|
317
335
|
Message
|
|
318
336
|
A new `Message` instance representing the reply.
|
|
319
337
|
"""
|
|
338
|
+
if ttl is None:
|
|
339
|
+
ttl = DEFAULT_TTL
|
|
340
|
+
|
|
320
341
|
return Message(
|
|
321
342
|
metadata=self._create_reply_metadata(ttl),
|
|
322
343
|
content=content,
|
flwr/common/serde.py
CHANGED
|
@@ -575,6 +575,7 @@ def message_to_taskins(message: Message) -> TaskIns:
|
|
|
575
575
|
task=Task(
|
|
576
576
|
producer=Node(node_id=0, anonymous=True), # Assume driver node
|
|
577
577
|
consumer=Node(node_id=md.dst_node_id, anonymous=False),
|
|
578
|
+
created_at=md.created_at,
|
|
578
579
|
ttl=md.ttl,
|
|
579
580
|
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
|
|
580
581
|
task_type=md.message_type,
|
|
@@ -601,7 +602,7 @@ def message_from_taskins(taskins: TaskIns) -> Message:
|
|
|
601
602
|
)
|
|
602
603
|
|
|
603
604
|
# Construct Message
|
|
604
|
-
|
|
605
|
+
message = Message(
|
|
605
606
|
metadata=metadata,
|
|
606
607
|
content=(
|
|
607
608
|
recordset_from_proto(taskins.task.recordset)
|
|
@@ -614,6 +615,8 @@ def message_from_taskins(taskins: TaskIns) -> Message:
|
|
|
614
615
|
else None
|
|
615
616
|
),
|
|
616
617
|
)
|
|
618
|
+
message.metadata.created_at = taskins.task.created_at
|
|
619
|
+
return message
|
|
617
620
|
|
|
618
621
|
|
|
619
622
|
def message_to_taskres(message: Message) -> TaskRes:
|
|
@@ -626,6 +629,7 @@ def message_to_taskres(message: Message) -> TaskRes:
|
|
|
626
629
|
task=Task(
|
|
627
630
|
producer=Node(node_id=md.src_node_id, anonymous=False),
|
|
628
631
|
consumer=Node(node_id=0, anonymous=True), # Assume driver node
|
|
632
|
+
created_at=md.created_at,
|
|
629
633
|
ttl=md.ttl,
|
|
630
634
|
ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
|
|
631
635
|
task_type=md.message_type,
|
|
@@ -652,7 +656,7 @@ def message_from_taskres(taskres: TaskRes) -> Message:
|
|
|
652
656
|
)
|
|
653
657
|
|
|
654
658
|
# Construct the Message
|
|
655
|
-
|
|
659
|
+
message = Message(
|
|
656
660
|
metadata=metadata,
|
|
657
661
|
content=(
|
|
658
662
|
recordset_from_proto(taskres.task.recordset)
|
|
@@ -665,3 +669,5 @@ def message_from_taskres(taskres: TaskRes) -> Message:
|
|
|
665
669
|
else None
|
|
666
670
|
),
|
|
667
671
|
)
|
|
672
|
+
message.metadata.created_at = taskres.task.created_at
|
|
673
|
+
return message
|
flwr/proto/fleet_pb2.py
CHANGED
|
@@ -16,7 +16,7 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
|
|
16
16
|
from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\
|
|
19
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
|
|
20
20
|
|
|
21
21
|
_globals = globals()
|
|
22
22
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -33,18 +33,22 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
33
33
|
_globals['_DELETENODEREQUEST']._serialized_end=210
|
|
34
34
|
_globals['_DELETENODERESPONSE']._serialized_start=212
|
|
35
35
|
_globals['_DELETENODERESPONSE']._serialized_end=232
|
|
36
|
-
_globals['
|
|
37
|
-
_globals['
|
|
38
|
-
_globals['
|
|
39
|
-
_globals['
|
|
40
|
-
_globals['
|
|
41
|
-
_globals['
|
|
42
|
-
_globals['
|
|
43
|
-
_globals['
|
|
44
|
-
_globals['
|
|
45
|
-
_globals['
|
|
46
|
-
_globals['
|
|
47
|
-
_globals['
|
|
48
|
-
_globals['
|
|
49
|
-
_globals['
|
|
36
|
+
_globals['_PINGREQUEST']._serialized_start=234
|
|
37
|
+
_globals['_PINGREQUEST']._serialized_end=302
|
|
38
|
+
_globals['_PINGRESPONSE']._serialized_start=304
|
|
39
|
+
_globals['_PINGRESPONSE']._serialized_end=335
|
|
40
|
+
_globals['_PULLTASKINSREQUEST']._serialized_start=337
|
|
41
|
+
_globals['_PULLTASKINSREQUEST']._serialized_end=407
|
|
42
|
+
_globals['_PULLTASKINSRESPONSE']._serialized_start=409
|
|
43
|
+
_globals['_PULLTASKINSRESPONSE']._serialized_end=516
|
|
44
|
+
_globals['_PUSHTASKRESREQUEST']._serialized_start=518
|
|
45
|
+
_globals['_PUSHTASKRESREQUEST']._serialized_end=582
|
|
46
|
+
_globals['_PUSHTASKRESRESPONSE']._serialized_start=585
|
|
47
|
+
_globals['_PUSHTASKRESRESPONSE']._serialized_end=759
|
|
48
|
+
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=713
|
|
49
|
+
_globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=759
|
|
50
|
+
_globals['_RECONNECT']._serialized_start=761
|
|
51
|
+
_globals['_RECONNECT']._serialized_end=791
|
|
52
|
+
_globals['_FLEET']._serialized_start=794
|
|
53
|
+
_globals['_FLEET']._serialized_end=1184
|
|
50
54
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/fleet_pb2.pyi
CHANGED
|
@@ -53,6 +53,34 @@ class DeleteNodeResponse(google.protobuf.message.Message):
|
|
|
53
53
|
) -> None: ...
|
|
54
54
|
global___DeleteNodeResponse = DeleteNodeResponse
|
|
55
55
|
|
|
56
|
+
class PingRequest(google.protobuf.message.Message):
|
|
57
|
+
"""Ping messages"""
|
|
58
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
59
|
+
NODE_FIELD_NUMBER: builtins.int
|
|
60
|
+
PING_INTERVAL_FIELD_NUMBER: builtins.int
|
|
61
|
+
@property
|
|
62
|
+
def node(self) -> flwr.proto.node_pb2.Node: ...
|
|
63
|
+
ping_interval: builtins.float
|
|
64
|
+
def __init__(self,
|
|
65
|
+
*,
|
|
66
|
+
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
|
67
|
+
ping_interval: builtins.float = ...,
|
|
68
|
+
) -> None: ...
|
|
69
|
+
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
|
|
70
|
+
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","ping_interval",b"ping_interval"]) -> None: ...
|
|
71
|
+
global___PingRequest = PingRequest
|
|
72
|
+
|
|
73
|
+
class PingResponse(google.protobuf.message.Message):
|
|
74
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
75
|
+
SUCCESS_FIELD_NUMBER: builtins.int
|
|
76
|
+
success: builtins.bool
|
|
77
|
+
def __init__(self,
|
|
78
|
+
*,
|
|
79
|
+
success: builtins.bool = ...,
|
|
80
|
+
) -> None: ...
|
|
81
|
+
def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ...
|
|
82
|
+
global___PingResponse = PingResponse
|
|
83
|
+
|
|
56
84
|
class PullTaskInsRequest(google.protobuf.message.Message):
|
|
57
85
|
"""PullTaskIns messages"""
|
|
58
86
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
flwr/proto/fleet_pb2_grpc.py
CHANGED
|
@@ -24,6 +24,11 @@ class FleetStub(object):
|
|
|
24
24
|
request_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.SerializeToString,
|
|
25
25
|
response_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.FromString,
|
|
26
26
|
)
|
|
27
|
+
self.Ping = channel.unary_unary(
|
|
28
|
+
'/flwr.proto.Fleet/Ping',
|
|
29
|
+
request_serializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString,
|
|
30
|
+
response_deserializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString,
|
|
31
|
+
)
|
|
27
32
|
self.PullTaskIns = channel.unary_unary(
|
|
28
33
|
'/flwr.proto.Fleet/PullTaskIns',
|
|
29
34
|
request_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString,
|
|
@@ -51,6 +56,12 @@ class FleetServicer(object):
|
|
|
51
56
|
context.set_details('Method not implemented!')
|
|
52
57
|
raise NotImplementedError('Method not implemented!')
|
|
53
58
|
|
|
59
|
+
def Ping(self, request, context):
|
|
60
|
+
"""Missing associated documentation comment in .proto file."""
|
|
61
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
62
|
+
context.set_details('Method not implemented!')
|
|
63
|
+
raise NotImplementedError('Method not implemented!')
|
|
64
|
+
|
|
54
65
|
def PullTaskIns(self, request, context):
|
|
55
66
|
"""Retrieve one or more tasks, if possible
|
|
56
67
|
|
|
@@ -82,6 +93,11 @@ def add_FleetServicer_to_server(servicer, server):
|
|
|
82
93
|
request_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.FromString,
|
|
83
94
|
response_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.SerializeToString,
|
|
84
95
|
),
|
|
96
|
+
'Ping': grpc.unary_unary_rpc_method_handler(
|
|
97
|
+
servicer.Ping,
|
|
98
|
+
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.FromString,
|
|
99
|
+
response_serializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.SerializeToString,
|
|
100
|
+
),
|
|
85
101
|
'PullTaskIns': grpc.unary_unary_rpc_method_handler(
|
|
86
102
|
servicer.PullTaskIns,
|
|
87
103
|
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString,
|
|
@@ -136,6 +152,23 @@ class Fleet(object):
|
|
|
136
152
|
options, channel_credentials,
|
|
137
153
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
138
154
|
|
|
155
|
+
@staticmethod
|
|
156
|
+
def Ping(request,
|
|
157
|
+
target,
|
|
158
|
+
options=(),
|
|
159
|
+
channel_credentials=None,
|
|
160
|
+
call_credentials=None,
|
|
161
|
+
insecure=False,
|
|
162
|
+
compression=None,
|
|
163
|
+
wait_for_ready=None,
|
|
164
|
+
timeout=None,
|
|
165
|
+
metadata=None):
|
|
166
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/Ping',
|
|
167
|
+
flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString,
|
|
168
|
+
flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString,
|
|
169
|
+
options, channel_credentials,
|
|
170
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
171
|
+
|
|
139
172
|
@staticmethod
|
|
140
173
|
def PullTaskIns(request,
|
|
141
174
|
target,
|
flwr/proto/fleet_pb2_grpc.pyi
CHANGED
|
@@ -16,6 +16,10 @@ class FleetStub:
|
|
|
16
16
|
flwr.proto.fleet_pb2.DeleteNodeRequest,
|
|
17
17
|
flwr.proto.fleet_pb2.DeleteNodeResponse]
|
|
18
18
|
|
|
19
|
+
Ping: grpc.UnaryUnaryMultiCallable[
|
|
20
|
+
flwr.proto.fleet_pb2.PingRequest,
|
|
21
|
+
flwr.proto.fleet_pb2.PingResponse]
|
|
22
|
+
|
|
19
23
|
PullTaskIns: grpc.UnaryUnaryMultiCallable[
|
|
20
24
|
flwr.proto.fleet_pb2.PullTaskInsRequest,
|
|
21
25
|
flwr.proto.fleet_pb2.PullTaskInsResponse]
|
|
@@ -46,6 +50,12 @@ class FleetServicer(metaclass=abc.ABCMeta):
|
|
|
46
50
|
context: grpc.ServicerContext,
|
|
47
51
|
) -> flwr.proto.fleet_pb2.DeleteNodeResponse: ...
|
|
48
52
|
|
|
53
|
+
@abc.abstractmethod
|
|
54
|
+
def Ping(self,
|
|
55
|
+
request: flwr.proto.fleet_pb2.PingRequest,
|
|
56
|
+
context: grpc.ServicerContext,
|
|
57
|
+
) -> flwr.proto.fleet_pb2.PingResponse: ...
|
|
58
|
+
|
|
49
59
|
@abc.abstractmethod
|
|
50
60
|
def PullTaskIns(self,
|
|
51
61
|
request: flwr.proto.fleet_pb2.PullTaskInsRequest,
|
flwr/proto/task_pb2.py
CHANGED
|
@@ -18,7 +18,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
|
|
18
18
|
from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\
|
|
21
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
|
|
22
22
|
|
|
23
23
|
_globals = globals()
|
|
24
24
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -26,9 +26,9 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _glob
|
|
|
26
26
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
|
27
27
|
DESCRIPTOR._options = None
|
|
28
28
|
_globals['_TASK']._serialized_start=141
|
|
29
|
-
_globals['_TASK']._serialized_end=
|
|
30
|
-
_globals['_TASKINS']._serialized_start=
|
|
31
|
-
_globals['_TASKINS']._serialized_end=
|
|
32
|
-
_globals['_TASKRES']._serialized_start=
|
|
33
|
-
_globals['_TASKRES']._serialized_end=
|
|
29
|
+
_globals['_TASK']._serialized_end=406
|
|
30
|
+
_globals['_TASKINS']._serialized_start=408
|
|
31
|
+
_globals['_TASKINS']._serialized_end=500
|
|
32
|
+
_globals['_TASKRES']._serialized_start=502
|
|
33
|
+
_globals['_TASKRES']._serialized_end=594
|
|
34
34
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/task_pb2.pyi
CHANGED
|
@@ -20,6 +20,7 @@ class Task(google.protobuf.message.Message):
|
|
|
20
20
|
CONSUMER_FIELD_NUMBER: builtins.int
|
|
21
21
|
CREATED_AT_FIELD_NUMBER: builtins.int
|
|
22
22
|
DELIVERED_AT_FIELD_NUMBER: builtins.int
|
|
23
|
+
PUSHED_AT_FIELD_NUMBER: builtins.int
|
|
23
24
|
TTL_FIELD_NUMBER: builtins.int
|
|
24
25
|
ANCESTRY_FIELD_NUMBER: builtins.int
|
|
25
26
|
TASK_TYPE_FIELD_NUMBER: builtins.int
|
|
@@ -29,9 +30,10 @@ class Task(google.protobuf.message.Message):
|
|
|
29
30
|
def producer(self) -> flwr.proto.node_pb2.Node: ...
|
|
30
31
|
@property
|
|
31
32
|
def consumer(self) -> flwr.proto.node_pb2.Node: ...
|
|
32
|
-
created_at:
|
|
33
|
+
created_at: builtins.float
|
|
33
34
|
delivered_at: typing.Text
|
|
34
|
-
|
|
35
|
+
pushed_at: builtins.float
|
|
36
|
+
ttl: builtins.float
|
|
35
37
|
@property
|
|
36
38
|
def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
|
|
37
39
|
task_type: typing.Text
|
|
@@ -43,16 +45,17 @@ class Task(google.protobuf.message.Message):
|
|
|
43
45
|
*,
|
|
44
46
|
producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
|
45
47
|
consumer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
|
46
|
-
created_at:
|
|
48
|
+
created_at: builtins.float = ...,
|
|
47
49
|
delivered_at: typing.Text = ...,
|
|
48
|
-
|
|
50
|
+
pushed_at: builtins.float = ...,
|
|
51
|
+
ttl: builtins.float = ...,
|
|
49
52
|
ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
|
|
50
53
|
task_type: typing.Text = ...,
|
|
51
54
|
recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
|
|
52
55
|
error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
|
|
53
56
|
) -> None: ...
|
|
54
57
|
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
|
55
|
-
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
|
58
|
+
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","pushed_at",b"pushed_at","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
|
56
59
|
global___Task = Task
|
|
57
60
|
|
|
58
61
|
class TaskIns(google.protobuf.message.Message):
|
|
@@ -19,7 +19,7 @@ import time
|
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
|
|
21
21
|
from flwr import common
|
|
22
|
-
from flwr.common import MessageType, MessageTypeLegacy, RecordSet
|
|
22
|
+
from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet
|
|
23
23
|
from flwr.common import recordset_compat as compat
|
|
24
24
|
from flwr.common import serde
|
|
25
25
|
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
@@ -129,8 +129,16 @@ class DriverClientProxy(ClientProxy):
|
|
|
129
129
|
),
|
|
130
130
|
task_type=task_type,
|
|
131
131
|
recordset=serde.recordset_to_proto(recordset),
|
|
132
|
+
ttl=DEFAULT_TTL,
|
|
132
133
|
),
|
|
133
134
|
)
|
|
135
|
+
|
|
136
|
+
# This would normally be recorded upon common.Message creation
|
|
137
|
+
# but this compatibility stack doesn't create Messages,
|
|
138
|
+
# so we need to inject `created_at` manually (needed for
|
|
139
|
+
# taskins validation by server.utils.validator)
|
|
140
|
+
task_ins.task.created_at = time.time()
|
|
141
|
+
|
|
134
142
|
push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
|
|
135
143
|
task_ins_list=[task_ins]
|
|
136
144
|
)
|
flwr/server/driver/driver.py
CHANGED
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import time
|
|
19
19
|
from typing import Iterable, List, Optional, Tuple
|
|
20
20
|
|
|
21
|
-
from flwr.common import Message, Metadata, RecordSet
|
|
21
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
22
22
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
23
23
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
24
24
|
CreateRunRequest,
|
|
@@ -81,6 +81,7 @@ class Driver:
|
|
|
81
81
|
and message.metadata.src_node_id == self.node.node_id
|
|
82
82
|
and message.metadata.message_id == ""
|
|
83
83
|
and message.metadata.reply_to_message == ""
|
|
84
|
+
and message.metadata.ttl > 0
|
|
84
85
|
):
|
|
85
86
|
raise ValueError(f"Invalid message: {message}")
|
|
86
87
|
|
|
@@ -90,7 +91,7 @@ class Driver:
|
|
|
90
91
|
message_type: str,
|
|
91
92
|
dst_node_id: int,
|
|
92
93
|
group_id: str,
|
|
93
|
-
ttl:
|
|
94
|
+
ttl: float = DEFAULT_TTL,
|
|
94
95
|
) -> Message:
|
|
95
96
|
"""Create a new message with specified parameters.
|
|
96
97
|
|
|
@@ -110,10 +111,10 @@ class Driver:
|
|
|
110
111
|
group_id : str
|
|
111
112
|
The ID of the group to which this message is associated. In some settings,
|
|
112
113
|
this is used as the FL round.
|
|
113
|
-
ttl :
|
|
114
|
+
ttl : float (default: common.DEFAULT_TTL)
|
|
114
115
|
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
115
|
-
this message to receiving a reply. It specifies the duration for
|
|
116
|
-
message and its potential reply are considered valid.
|
|
116
|
+
this message to receiving a reply. It specifies in seconds the duration for
|
|
117
|
+
which the message and its potential reply are considered valid.
|
|
117
118
|
|
|
118
119
|
Returns
|
|
119
120
|
-------
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Driver API servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import time
|
|
18
19
|
from logging import DEBUG, INFO
|
|
19
20
|
from typing import List, Optional, Set
|
|
20
21
|
from uuid import UUID
|
|
@@ -72,6 +73,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
72
73
|
"""Push a set of TaskIns."""
|
|
73
74
|
log(DEBUG, "DriverServicer.PushTaskIns")
|
|
74
75
|
|
|
76
|
+
# Set pushed_at (timestamp in seconds)
|
|
77
|
+
pushed_at = time.time()
|
|
78
|
+
for task_ins in request.task_ins_list:
|
|
79
|
+
task_ins.task.pushed_at = pushed_at
|
|
80
|
+
|
|
75
81
|
# Validate request
|
|
76
82
|
_raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
|
|
77
83
|
for task_ins in request.task_ins_list:
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Fleet API gRPC request-response servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from logging import INFO
|
|
18
|
+
from logging import DEBUG, INFO
|
|
19
19
|
|
|
20
20
|
import grpc
|
|
21
21
|
|
|
@@ -26,6 +26,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
26
26
|
CreateNodeResponse,
|
|
27
27
|
DeleteNodeRequest,
|
|
28
28
|
DeleteNodeResponse,
|
|
29
|
+
PingRequest,
|
|
30
|
+
PingResponse,
|
|
29
31
|
PullTaskInsRequest,
|
|
30
32
|
PullTaskInsResponse,
|
|
31
33
|
PushTaskResRequest,
|
|
@@ -61,6 +63,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
61
63
|
state=self.state_factory.state(),
|
|
62
64
|
)
|
|
63
65
|
|
|
66
|
+
def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
|
|
67
|
+
"""."""
|
|
68
|
+
log(DEBUG, "FleetServicer.Ping")
|
|
69
|
+
return message_handler.ping(
|
|
70
|
+
request=request,
|
|
71
|
+
state=self.state_factory.state(),
|
|
72
|
+
)
|
|
73
|
+
|
|
64
74
|
def PullTaskIns(
|
|
65
75
|
self, request: PullTaskInsRequest, context: grpc.ServicerContext
|
|
66
76
|
) -> PullTaskInsResponse:
|