flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.8.0.dev20240327__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/client_app.py +4 -4
- flwr/client/grpc_client/connection.py +2 -1
- flwr/client/message_handler/message_handler.py +3 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/common/__init__.py +2 -0
- flwr/common/message.py +34 -13
- flwr/common/serde.py +8 -2
- flwr/proto/fleet_pb2.py +19 -15
- flwr/proto/fleet_pb2.pyi +28 -0
- flwr/proto/fleet_pb2_grpc.py +33 -0
- flwr/proto/fleet_pb2_grpc.pyi +10 -0
- flwr/proto/task_pb2.py +6 -6
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/compat/driver_client_proxy.py +9 -1
- flwr/server/driver/driver.py +6 -5
- flwr/server/superlink/driver/driver_servicer.py +6 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +11 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +14 -0
- flwr/server/superlink/state/in_memory_state.py +38 -26
- flwr/server/superlink/state/sqlite_state.py +42 -21
- flwr/server/superlink/state/state.py +19 -0
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +4 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +5 -4
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/RECORD +30 -30
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.8.0.dev20240327.dist-info}/entry_points.txt +0 -0
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Fleet API message handlers."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import time
|
|
18
19
|
from typing import List, Optional
|
|
19
20
|
from uuid import UUID
|
|
20
21
|
|
|
@@ -23,6 +24,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
23
24
|
CreateNodeResponse,
|
|
24
25
|
DeleteNodeRequest,
|
|
25
26
|
DeleteNodeResponse,
|
|
27
|
+
PingRequest,
|
|
28
|
+
PingResponse,
|
|
26
29
|
PullTaskInsRequest,
|
|
27
30
|
PullTaskInsResponse,
|
|
28
31
|
PushTaskResRequest,
|
|
@@ -55,6 +58,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
|
|
|
55
58
|
return DeleteNodeResponse()
|
|
56
59
|
|
|
57
60
|
|
|
61
|
+
def ping(
|
|
62
|
+
request: PingRequest, # pylint: disable=unused-argument
|
|
63
|
+
state: State, # pylint: disable=unused-argument
|
|
64
|
+
) -> PingResponse:
|
|
65
|
+
"""."""
|
|
66
|
+
return PingResponse(success=True)
|
|
67
|
+
|
|
68
|
+
|
|
58
69
|
def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
|
|
59
70
|
"""Pull TaskIns handler."""
|
|
60
71
|
# Get node_id if client node is not anonymous
|
|
@@ -77,6 +88,9 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
|
|
|
77
88
|
task_res: TaskRes = request.task_res_list[0]
|
|
78
89
|
# pylint: enable=no-member
|
|
79
90
|
|
|
91
|
+
# Set pushed_at (timestamp in seconds)
|
|
92
|
+
task_res.task.pushed_at = time.time()
|
|
93
|
+
|
|
80
94
|
# Store TaskRes in State
|
|
81
95
|
task_id: Optional[UUID] = state.store_task_res(task_res=task_res)
|
|
82
96
|
|
|
@@ -17,9 +17,9 @@
|
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
19
|
import threading
|
|
20
|
-
|
|
20
|
+
import time
|
|
21
21
|
from logging import ERROR
|
|
22
|
-
from typing import Dict, List, Optional, Set
|
|
22
|
+
from typing import Dict, List, Optional, Set, Tuple
|
|
23
23
|
from uuid import UUID, uuid4
|
|
24
24
|
|
|
25
25
|
from flwr.common import log, now
|
|
@@ -32,7 +32,8 @@ class InMemoryState(State):
|
|
|
32
32
|
"""In-memory State implementation."""
|
|
33
33
|
|
|
34
34
|
def __init__(self) -> None:
|
|
35
|
-
|
|
35
|
+
# Map node_id to (online_until, ping_interval)
|
|
36
|
+
self.node_ids: Dict[int, Tuple[float, float]] = {}
|
|
36
37
|
self.run_ids: Set[int] = set()
|
|
37
38
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
38
39
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
@@ -50,15 +51,11 @@ class InMemoryState(State):
|
|
|
50
51
|
log(ERROR, "`run_id` is invalid")
|
|
51
52
|
return None
|
|
52
53
|
|
|
53
|
-
# Create task_id
|
|
54
|
+
# Create task_id
|
|
54
55
|
task_id = uuid4()
|
|
55
|
-
created_at: datetime = now()
|
|
56
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
57
56
|
|
|
58
57
|
# Store TaskIns
|
|
59
58
|
task_ins.task_id = str(task_id)
|
|
60
|
-
task_ins.task.created_at = created_at.isoformat()
|
|
61
|
-
task_ins.task.ttl = ttl.isoformat()
|
|
62
59
|
with self.lock:
|
|
63
60
|
self.task_ins_store[task_id] = task_ins
|
|
64
61
|
|
|
@@ -113,15 +110,11 @@ class InMemoryState(State):
|
|
|
113
110
|
log(ERROR, "`run_id` is invalid")
|
|
114
111
|
return None
|
|
115
112
|
|
|
116
|
-
# Create task_id
|
|
113
|
+
# Create task_id
|
|
117
114
|
task_id = uuid4()
|
|
118
|
-
created_at: datetime = now()
|
|
119
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
120
115
|
|
|
121
116
|
# Store TaskRes
|
|
122
117
|
task_res.task_id = str(task_id)
|
|
123
|
-
task_res.task.created_at = created_at.isoformat()
|
|
124
|
-
task_res.task.ttl = ttl.isoformat()
|
|
125
118
|
with self.lock:
|
|
126
119
|
self.task_res_store[task_id] = task_res
|
|
127
120
|
|
|
@@ -194,17 +187,21 @@ class InMemoryState(State):
|
|
|
194
187
|
# Sample a random int64 as node_id
|
|
195
188
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
196
189
|
|
|
197
|
-
|
|
198
|
-
self.node_ids
|
|
199
|
-
|
|
190
|
+
with self.lock:
|
|
191
|
+
if node_id not in self.node_ids:
|
|
192
|
+
# Default ping interval is 30s
|
|
193
|
+
# TODO: change 1e9 to 30s # pylint: disable=W0511
|
|
194
|
+
self.node_ids[node_id] = (time.time() + 1e9, 1e9)
|
|
195
|
+
return node_id
|
|
200
196
|
log(ERROR, "Unexpected node registration failure.")
|
|
201
197
|
return 0
|
|
202
198
|
|
|
203
199
|
def delete_node(self, node_id: int) -> None:
|
|
204
200
|
"""Delete a client node."""
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
201
|
+
with self.lock:
|
|
202
|
+
if node_id not in self.node_ids:
|
|
203
|
+
raise ValueError(f"Node {node_id} not found")
|
|
204
|
+
del self.node_ids[node_id]
|
|
208
205
|
|
|
209
206
|
def get_nodes(self, run_id: int) -> Set[int]:
|
|
210
207
|
"""Return all available client nodes.
|
|
@@ -214,17 +211,32 @@ class InMemoryState(State):
|
|
|
214
211
|
If the provided `run_id` does not exist or has no matching nodes,
|
|
215
212
|
an empty `Set` MUST be returned.
|
|
216
213
|
"""
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
214
|
+
with self.lock:
|
|
215
|
+
if run_id not in self.run_ids:
|
|
216
|
+
return set()
|
|
217
|
+
current_time = time.time()
|
|
218
|
+
return {
|
|
219
|
+
node_id
|
|
220
|
+
for node_id, (online_until, _) in self.node_ids.items()
|
|
221
|
+
if online_until > current_time
|
|
222
|
+
}
|
|
220
223
|
|
|
221
224
|
def create_run(self) -> int:
|
|
222
225
|
"""Create one run."""
|
|
223
226
|
# Sample a random int64 as run_id
|
|
224
|
-
|
|
227
|
+
with self.lock:
|
|
228
|
+
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
225
229
|
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
230
|
+
if run_id not in self.run_ids:
|
|
231
|
+
self.run_ids.add(run_id)
|
|
232
|
+
return run_id
|
|
229
233
|
log(ERROR, "Unexpected run creation failure.")
|
|
230
234
|
return 0
|
|
235
|
+
|
|
236
|
+
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
237
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
238
|
+
with self.lock:
|
|
239
|
+
if node_id in self.node_ids:
|
|
240
|
+
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
241
|
+
return True
|
|
242
|
+
return False
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import os
|
|
19
19
|
import re
|
|
20
20
|
import sqlite3
|
|
21
|
-
|
|
21
|
+
import time
|
|
22
22
|
from logging import DEBUG, ERROR
|
|
23
23
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
|
|
24
24
|
from uuid import UUID, uuid4
|
|
@@ -33,10 +33,16 @@ from .state import State
|
|
|
33
33
|
|
|
34
34
|
SQL_CREATE_TABLE_NODE = """
|
|
35
35
|
CREATE TABLE IF NOT EXISTS node(
|
|
36
|
-
node_id
|
|
36
|
+
node_id INTEGER UNIQUE,
|
|
37
|
+
online_until REAL,
|
|
38
|
+
ping_interval REAL
|
|
37
39
|
);
|
|
38
40
|
"""
|
|
39
41
|
|
|
42
|
+
SQL_CREATE_INDEX_ONLINE_UNTIL = """
|
|
43
|
+
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
44
|
+
"""
|
|
45
|
+
|
|
40
46
|
SQL_CREATE_TABLE_RUN = """
|
|
41
47
|
CREATE TABLE IF NOT EXISTS run(
|
|
42
48
|
run_id INTEGER UNIQUE
|
|
@@ -52,9 +58,10 @@ CREATE TABLE IF NOT EXISTS task_ins(
|
|
|
52
58
|
producer_node_id INTEGER,
|
|
53
59
|
consumer_anonymous BOOLEAN,
|
|
54
60
|
consumer_node_id INTEGER,
|
|
55
|
-
created_at
|
|
61
|
+
created_at REAL,
|
|
56
62
|
delivered_at TEXT,
|
|
57
|
-
|
|
63
|
+
pushed_at REAL,
|
|
64
|
+
ttl REAL,
|
|
58
65
|
ancestry TEXT,
|
|
59
66
|
task_type TEXT,
|
|
60
67
|
recordset BLOB,
|
|
@@ -72,9 +79,10 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
72
79
|
producer_node_id INTEGER,
|
|
73
80
|
consumer_anonymous BOOLEAN,
|
|
74
81
|
consumer_node_id INTEGER,
|
|
75
|
-
created_at
|
|
82
|
+
created_at REAL,
|
|
76
83
|
delivered_at TEXT,
|
|
77
|
-
|
|
84
|
+
pushed_at REAL,
|
|
85
|
+
ttl REAL,
|
|
78
86
|
ancestry TEXT,
|
|
79
87
|
task_type TEXT,
|
|
80
88
|
recordset BLOB,
|
|
@@ -82,7 +90,7 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
82
90
|
);
|
|
83
91
|
"""
|
|
84
92
|
|
|
85
|
-
DictOrTuple = Union[Tuple[Any], Dict[str, Any]]
|
|
93
|
+
DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
|
|
86
94
|
|
|
87
95
|
|
|
88
96
|
class SqliteState(State):
|
|
@@ -123,6 +131,7 @@ class SqliteState(State):
|
|
|
123
131
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
124
132
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
125
133
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
134
|
+
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
126
135
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
127
136
|
|
|
128
137
|
return res.fetchall()
|
|
@@ -185,15 +194,11 @@ class SqliteState(State):
|
|
|
185
194
|
log(ERROR, errors)
|
|
186
195
|
return None
|
|
187
196
|
|
|
188
|
-
# Create task_id
|
|
197
|
+
# Create task_id
|
|
189
198
|
task_id = uuid4()
|
|
190
|
-
created_at: datetime = now()
|
|
191
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
192
199
|
|
|
193
200
|
# Store TaskIns
|
|
194
201
|
task_ins.task_id = str(task_id)
|
|
195
|
-
task_ins.task.created_at = created_at.isoformat()
|
|
196
|
-
task_ins.task.ttl = ttl.isoformat()
|
|
197
202
|
data = (task_ins_to_dict(task_ins),)
|
|
198
203
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
199
204
|
query = f"INSERT INTO task_ins VALUES({columns});"
|
|
@@ -320,15 +325,11 @@ class SqliteState(State):
|
|
|
320
325
|
log(ERROR, errors)
|
|
321
326
|
return None
|
|
322
327
|
|
|
323
|
-
# Create task_id
|
|
328
|
+
# Create task_id
|
|
324
329
|
task_id = uuid4()
|
|
325
|
-
created_at: datetime = now()
|
|
326
|
-
ttl: datetime = created_at + timedelta(hours=24)
|
|
327
330
|
|
|
328
331
|
# Store TaskIns
|
|
329
332
|
task_res.task_id = str(task_id)
|
|
330
|
-
task_res.task.created_at = created_at.isoformat()
|
|
331
|
-
task_res.task.ttl = ttl.isoformat()
|
|
332
333
|
data = (task_res_to_dict(task_res),)
|
|
333
334
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
334
335
|
query = f"INSERT INTO task_res VALUES({columns});"
|
|
@@ -472,9 +473,14 @@ class SqliteState(State):
|
|
|
472
473
|
# Sample a random int64 as node_id
|
|
473
474
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
474
475
|
|
|
475
|
-
query =
|
|
476
|
+
query = (
|
|
477
|
+
"INSERT INTO node (node_id, online_until, ping_interval) VALUES (?, ?, ?)"
|
|
478
|
+
)
|
|
479
|
+
|
|
476
480
|
try:
|
|
477
|
-
|
|
481
|
+
# Default ping interval is 30s
|
|
482
|
+
# TODO: change 1e9 to 30s # pylint: disable=W0511
|
|
483
|
+
self.query(query, (node_id, time.time() + 1e9, 1e9))
|
|
478
484
|
except sqlite3.IntegrityError:
|
|
479
485
|
log(ERROR, "Unexpected node registration failure.")
|
|
480
486
|
return 0
|
|
@@ -499,8 +505,8 @@ class SqliteState(State):
|
|
|
499
505
|
return set()
|
|
500
506
|
|
|
501
507
|
# Get nodes
|
|
502
|
-
query = "SELECT
|
|
503
|
-
rows = self.query(query)
|
|
508
|
+
query = "SELECT node_id FROM node WHERE online_until > ?;"
|
|
509
|
+
rows = self.query(query, (time.time(),))
|
|
504
510
|
result: Set[int] = {row["node_id"] for row in rows}
|
|
505
511
|
return result
|
|
506
512
|
|
|
@@ -519,6 +525,17 @@ class SqliteState(State):
|
|
|
519
525
|
log(ERROR, "Unexpected run creation failure.")
|
|
520
526
|
return 0
|
|
521
527
|
|
|
528
|
+
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
529
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
530
|
+
# Update `online_until` and `ping_interval` for the given `node_id`
|
|
531
|
+
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
|
|
532
|
+
try:
|
|
533
|
+
self.query(query, (time.time() + ping_interval, ping_interval, node_id))
|
|
534
|
+
return True
|
|
535
|
+
except sqlite3.IntegrityError:
|
|
536
|
+
log(ERROR, "`node_id` does not exist.")
|
|
537
|
+
return False
|
|
538
|
+
|
|
522
539
|
|
|
523
540
|
def dict_factory(
|
|
524
541
|
cursor: sqlite3.Cursor,
|
|
@@ -544,6 +561,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
|
|
|
544
561
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
545
562
|
"created_at": task_msg.task.created_at,
|
|
546
563
|
"delivered_at": task_msg.task.delivered_at,
|
|
564
|
+
"pushed_at": task_msg.task.pushed_at,
|
|
547
565
|
"ttl": task_msg.task.ttl,
|
|
548
566
|
"ancestry": ",".join(task_msg.task.ancestry),
|
|
549
567
|
"task_type": task_msg.task.task_type,
|
|
@@ -564,6 +582,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
|
|
|
564
582
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
565
583
|
"created_at": task_msg.task.created_at,
|
|
566
584
|
"delivered_at": task_msg.task.delivered_at,
|
|
585
|
+
"pushed_at": task_msg.task.pushed_at,
|
|
567
586
|
"ttl": task_msg.task.ttl,
|
|
568
587
|
"ancestry": ",".join(task_msg.task.ancestry),
|
|
569
588
|
"task_type": task_msg.task.task_type,
|
|
@@ -592,6 +611,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
|
|
|
592
611
|
),
|
|
593
612
|
created_at=task_dict["created_at"],
|
|
594
613
|
delivered_at=task_dict["delivered_at"],
|
|
614
|
+
pushed_at=task_dict["pushed_at"],
|
|
595
615
|
ttl=task_dict["ttl"],
|
|
596
616
|
ancestry=task_dict["ancestry"].split(","),
|
|
597
617
|
task_type=task_dict["task_type"],
|
|
@@ -621,6 +641,7 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes:
|
|
|
621
641
|
),
|
|
622
642
|
created_at=task_dict["created_at"],
|
|
623
643
|
delivered_at=task_dict["delivered_at"],
|
|
644
|
+
pushed_at=task_dict["pushed_at"],
|
|
624
645
|
ttl=task_dict["ttl"],
|
|
625
646
|
ancestry=task_dict["ancestry"].split(","),
|
|
626
647
|
task_type=task_dict["task_type"],
|
|
@@ -152,3 +152,22 @@ class State(abc.ABC):
|
|
|
152
152
|
@abc.abstractmethod
|
|
153
153
|
def create_run(self) -> int:
|
|
154
154
|
"""Create one run."""
|
|
155
|
+
|
|
156
|
+
@abc.abstractmethod
|
|
157
|
+
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
158
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
|
159
|
+
|
|
160
|
+
Parameters
|
|
161
|
+
----------
|
|
162
|
+
node_id : int
|
|
163
|
+
The `node_id` from which the ping was received.
|
|
164
|
+
ping_interval : float
|
|
165
|
+
The interval (in seconds) from the current timestamp within which the next
|
|
166
|
+
ping from this node must be received. This acts as a hard deadline to ensure
|
|
167
|
+
an accurate assessment of the node's availability.
|
|
168
|
+
|
|
169
|
+
Returns
|
|
170
|
+
-------
|
|
171
|
+
is_acknowledged : bool
|
|
172
|
+
True if the ping is successfully acknowledged; otherwise, False.
|
|
173
|
+
"""
|
flwr/server/utils/validator.py
CHANGED
|
@@ -31,13 +31,21 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
|
|
|
31
31
|
if not tasks_ins_res.HasField("task"):
|
|
32
32
|
validation_errors.append("`task` does not set field `task`")
|
|
33
33
|
|
|
34
|
-
# Created/delivered/TTL
|
|
35
|
-
if
|
|
36
|
-
|
|
34
|
+
# Created/delivered/TTL/Pushed
|
|
35
|
+
if (
|
|
36
|
+
tasks_ins_res.task.created_at < 1711497600.0
|
|
37
|
+
): # unix timestamp of 27 March 2024 00h:00m:00s UTC
|
|
38
|
+
validation_errors.append(
|
|
39
|
+
"`created_at` must be a float that records the unix timestamp "
|
|
40
|
+
"in seconds when the message was created."
|
|
41
|
+
)
|
|
37
42
|
if tasks_ins_res.task.delivered_at != "":
|
|
38
43
|
validation_errors.append("`delivered_at` must be an empty str")
|
|
39
|
-
if tasks_ins_res.task.ttl
|
|
40
|
-
validation_errors.append("`ttl` must be
|
|
44
|
+
if tasks_ins_res.task.ttl <= 0:
|
|
45
|
+
validation_errors.append("`ttl` must be higher than zero")
|
|
46
|
+
if tasks_ins_res.task.pushed_at < 1711497600.0:
|
|
47
|
+
# unix timestamp of 27 March 2024 00h:00m:00s UTC
|
|
48
|
+
validation_errors.append("`pushed_at` is not a recent timestamp")
|
|
41
49
|
|
|
42
50
|
# TaskIns specific
|
|
43
51
|
if isinstance(tasks_ins_res, TaskIns):
|
|
@@ -66,8 +74,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
|
|
|
66
74
|
# Content check
|
|
67
75
|
if tasks_ins_res.task.task_type == "":
|
|
68
76
|
validation_errors.append("`task_type` MUST be set")
|
|
69
|
-
if not
|
|
70
|
-
|
|
77
|
+
if not (
|
|
78
|
+
tasks_ins_res.task.HasField("recordset")
|
|
79
|
+
^ tasks_ins_res.task.HasField("error")
|
|
80
|
+
):
|
|
81
|
+
validation_errors.append("Either `recordset` or `error` MUST be set")
|
|
71
82
|
|
|
72
83
|
# Ancestors
|
|
73
84
|
if len(tasks_ins_res.task.ancestry) != 0:
|
|
@@ -106,8 +117,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
|
|
|
106
117
|
# Content check
|
|
107
118
|
if tasks_ins_res.task.task_type == "":
|
|
108
119
|
validation_errors.append("`task_type` MUST be set")
|
|
109
|
-
if not
|
|
110
|
-
|
|
120
|
+
if not (
|
|
121
|
+
tasks_ins_res.task.HasField("recordset")
|
|
122
|
+
^ tasks_ins_res.task.HasField("error")
|
|
123
|
+
):
|
|
124
|
+
validation_errors.append("Either `recordset` or `error` MUST be set")
|
|
111
125
|
|
|
112
126
|
# Ancestors
|
|
113
127
|
if len(tasks_ins_res.task.ancestry) == 0:
|
|
@@ -21,7 +21,7 @@ from logging import INFO
|
|
|
21
21
|
from typing import Optional, cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recordset_compat as compat
|
|
24
|
-
from flwr.common import ConfigsRecord, Context, GetParametersIns, log
|
|
24
|
+
from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, GetParametersIns, log
|
|
25
25
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
26
26
|
|
|
27
27
|
from ..compat.app_utils import start_update_client_manager_thread
|
|
@@ -127,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
127
127
|
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
|
128
128
|
dst_node_id=random_client.node_id,
|
|
129
129
|
group_id="0",
|
|
130
|
-
ttl=
|
|
130
|
+
ttl=DEFAULT_TTL,
|
|
131
131
|
)
|
|
132
132
|
]
|
|
133
133
|
)
|
|
@@ -226,7 +226,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
226
226
|
message_type=MessageType.TRAIN,
|
|
227
227
|
dst_node_id=proxy.node_id,
|
|
228
228
|
group_id=str(current_round),
|
|
229
|
-
ttl=
|
|
229
|
+
ttl=DEFAULT_TTL,
|
|
230
230
|
)
|
|
231
231
|
for proxy, fitins in client_instructions
|
|
232
232
|
]
|
|
@@ -306,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
|
306
306
|
message_type=MessageType.EVALUATE,
|
|
307
307
|
dst_node_id=proxy.node_id,
|
|
308
308
|
group_id=str(current_round),
|
|
309
|
-
ttl=
|
|
309
|
+
ttl=DEFAULT_TTL,
|
|
310
310
|
)
|
|
311
311
|
for proxy, evalins in client_instructions
|
|
312
312
|
]
|
|
@@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union, cast
|
|
|
22
22
|
|
|
23
23
|
import flwr.common.recordset_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
25
|
+
DEFAULT_TTL,
|
|
25
26
|
ConfigsRecord,
|
|
26
27
|
Context,
|
|
27
28
|
FitRes,
|
|
@@ -373,7 +374,7 @@ class SecAggPlusWorkflow:
|
|
|
373
374
|
message_type=MessageType.TRAIN,
|
|
374
375
|
dst_node_id=nid,
|
|
375
376
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
376
|
-
ttl=
|
|
377
|
+
ttl=DEFAULT_TTL,
|
|
377
378
|
)
|
|
378
379
|
|
|
379
380
|
log(
|
|
@@ -421,7 +422,7 @@ class SecAggPlusWorkflow:
|
|
|
421
422
|
message_type=MessageType.TRAIN,
|
|
422
423
|
dst_node_id=nid,
|
|
423
424
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
424
|
-
ttl=
|
|
425
|
+
ttl=DEFAULT_TTL,
|
|
425
426
|
)
|
|
426
427
|
|
|
427
428
|
# Broadcast public keys to clients and receive secret key shares
|
|
@@ -492,7 +493,7 @@ class SecAggPlusWorkflow:
|
|
|
492
493
|
message_type=MessageType.TRAIN,
|
|
493
494
|
dst_node_id=nid,
|
|
494
495
|
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
|
|
495
|
-
ttl=
|
|
496
|
+
ttl=DEFAULT_TTL,
|
|
496
497
|
)
|
|
497
498
|
|
|
498
499
|
log(
|
|
@@ -563,7 +564,7 @@ class SecAggPlusWorkflow:
|
|
|
563
564
|
message_type=MessageType.TRAIN,
|
|
564
565
|
dst_node_id=nid,
|
|
565
566
|
group_id=str(current_round),
|
|
566
|
-
ttl=
|
|
567
|
+
ttl=DEFAULT_TTL,
|
|
567
568
|
)
|
|
568
569
|
|
|
569
570
|
log(
|
|
@@ -23,7 +23,7 @@ from flwr import common
|
|
|
23
23
|
from flwr.client import ClientFn
|
|
24
24
|
from flwr.client.client_app import ClientApp
|
|
25
25
|
from flwr.client.node_state import NodeState
|
|
26
|
-
from flwr.common import Message, Metadata, RecordSet
|
|
26
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
27
27
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
28
28
|
from flwr.common.logger import log
|
|
29
29
|
from flwr.common.recordset_compat import (
|
|
@@ -105,7 +105,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
105
105
|
src_node_id=0,
|
|
106
106
|
dst_node_id=int(self.cid),
|
|
107
107
|
reply_to_message="",
|
|
108
|
-
ttl=
|
|
108
|
+
ttl=timeout if timeout else DEFAULT_TTL,
|
|
109
109
|
message_type=message_type,
|
|
110
110
|
partition_id=int(self.cid),
|
|
111
111
|
),
|