flwr-nightly 1.15.0.dev20250104__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/config_utils.py +23 -146
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +17 -2
- flwr/cli/login/login.py +20 -5
- flwr/cli/ls.py +10 -2
- flwr/cli/run/run.py +20 -10
- flwr/cli/stop.py +9 -1
- flwr/cli/utils.py +4 -4
- flwr/client/app.py +36 -48
- flwr/client/clientapp/app.py +4 -6
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +0 -6
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +34 -24
- flwr/client/grpc_rere_client/grpc_adapter.py +16 -0
- flwr/client/rest_client/connection.py +34 -26
- flwr/client/supernode/app.py +14 -20
- flwr/common/auth_plugin/auth_plugin.py +34 -23
- flwr/common/config.py +152 -15
- flwr/common/constant.py +11 -8
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +161 -3
- flwr/common/logger.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/serde.py +6 -4
- flwr/common/typing.py +20 -0
- flwr/proto/clientappio_pb2.py +13 -3
- flwr/proto/clientappio_pb2_grpc.py +63 -12
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/exec_pb2.py +27 -29
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/exec_pb2_grpc.py +105 -24
- flwr/proto/fab_pb2.py +13 -3
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fleet_pb2.py +54 -31
- flwr/proto/fleet_pb2.pyi +84 -0
- flwr/proto/fleet_pb2_grpc.py +207 -28
- flwr/proto/fleet_pb2_grpc.pyi +26 -0
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/recordset_pb2.py +18 -8
- flwr/proto/recordset_pb2_grpc.py +20 -0
- flwr/proto/run_pb2.py +16 -6
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/serverappio_pb2.py +32 -14
- flwr/proto/serverappio_pb2.pyi +56 -0
- flwr/proto/serverappio_pb2_grpc.py +261 -44
- flwr/proto/serverappio_pb2_grpc.pyi +20 -0
- flwr/proto/simulationio_pb2.py +13 -3
- flwr/proto/simulationio_pb2_grpc.py +105 -24
- flwr/proto/task_pb2.py +13 -3
- flwr/proto/task_pb2_grpc.py +20 -0
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/server/app.py +87 -38
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +5 -2
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +5 -6
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +132 -14
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +38 -0
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -168
- flwr/server/superlink/fleet/message_handler/message_handler.py +66 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -3
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +40 -48
- flwr/server/superlink/linkstate/linkstate.py +15 -22
- flwr/server/superlink/linkstate/sqlite_linkstate.py +80 -99
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +7 -7
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +98 -94
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
@@ -28,6 +28,7 @@ from flwr.common.constant import (
|
|
28
28
|
MESSAGE_TTL_TOLERANCE,
|
29
29
|
NODE_ID_NUM_BYTES,
|
30
30
|
RUN_ID_NUM_BYTES,
|
31
|
+
SUPERLINK_NODE_ID,
|
31
32
|
Status,
|
32
33
|
)
|
33
34
|
from flwr.common.record import ConfigsRecord
|
@@ -62,6 +63,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
62
63
|
# Map node_id to (online_until, ping_interval)
|
63
64
|
self.node_ids: dict[int, tuple[float, float]] = {}
|
64
65
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
66
|
+
self.node_id_to_public_key: dict[int, bytes] = {}
|
65
67
|
|
66
68
|
# Map run_id to RunRecord
|
67
69
|
self.run_ids: dict[int, RunRecord] = {}
|
@@ -89,7 +91,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
89
91
|
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
90
92
|
return None
|
91
93
|
# Validate source node ID
|
92
|
-
if task_ins.task.producer.node_id !=
|
94
|
+
if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
|
93
95
|
log(
|
94
96
|
ERROR,
|
95
97
|
"Invalid source node ID for TaskIns: %s",
|
@@ -97,14 +99,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
97
99
|
)
|
98
100
|
return None
|
99
101
|
# Validate destination node ID
|
100
|
-
if
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
return None
|
102
|
+
if task_ins.task.consumer.node_id not in self.node_ids:
|
103
|
+
log(
|
104
|
+
ERROR,
|
105
|
+
"Invalid destination node ID for TaskIns: %s",
|
106
|
+
task_ins.task.consumer.node_id,
|
107
|
+
)
|
108
|
+
return None
|
108
109
|
|
109
110
|
# Create task_id
|
110
111
|
task_id = uuid4()
|
@@ -117,9 +118,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
117
118
|
# Return the new task_id
|
118
119
|
return task_id
|
119
120
|
|
120
|
-
def get_task_ins(
|
121
|
-
self, node_id: Optional[int], limit: Optional[int]
|
122
|
-
) -> list[TaskIns]:
|
121
|
+
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
123
122
|
"""Get all TaskIns that have not been delivered yet."""
|
124
123
|
if limit is not None and limit < 1:
|
125
124
|
raise AssertionError("`limit` must be >= 1")
|
@@ -129,17 +128,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
129
128
|
current_time = time.time()
|
130
129
|
with self.lock:
|
131
130
|
for _, task_ins in self.task_ins_store.items():
|
132
|
-
# pylint: disable=too-many-boolean-expressions
|
133
131
|
if (
|
134
|
-
node_id
|
135
|
-
and task_ins.task.consumer.anonymous is False
|
136
|
-
and task_ins.task.consumer.node_id == node_id
|
137
|
-
and task_ins.task.delivered_at == ""
|
138
|
-
and task_ins.task.created_at + task_ins.task.ttl > current_time
|
139
|
-
) or (
|
140
|
-
node_id is None # Anonymous
|
141
|
-
and task_ins.task.consumer.anonymous is True
|
142
|
-
and task_ins.task.consumer.node_id == 0
|
132
|
+
task_ins.task.consumer.node_id == node_id
|
143
133
|
and task_ins.task.delivered_at == ""
|
144
134
|
and task_ins.task.created_at + task_ins.task.ttl > current_time
|
145
135
|
):
|
@@ -173,9 +163,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
173
163
|
if (
|
174
164
|
task_ins
|
175
165
|
and task_res
|
176
|
-
and not (
|
177
|
-
task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
|
178
|
-
)
|
179
166
|
and task_ins.task.consumer.node_id != task_res.task.producer.node_id
|
180
167
|
):
|
181
168
|
return None
|
@@ -306,45 +293,30 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
306
293
|
"""
|
307
294
|
return len(self.task_res_store)
|
308
295
|
|
309
|
-
def create_node(
|
310
|
-
self, ping_interval: float, public_key: Optional[bytes] = None
|
311
|
-
) -> int:
|
296
|
+
def create_node(self, ping_interval: float) -> int:
|
312
297
|
"""Create, store in the link state, and return `node_id`."""
|
313
298
|
# Sample a random int64 as node_id
|
314
|
-
node_id = generate_rand_int_from_bytes(
|
299
|
+
node_id = generate_rand_int_from_bytes(
|
300
|
+
NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
|
301
|
+
)
|
315
302
|
|
316
303
|
with self.lock:
|
317
304
|
if node_id in self.node_ids:
|
318
305
|
log(ERROR, "Unexpected node registration failure.")
|
319
306
|
return 0
|
320
307
|
|
321
|
-
if public_key is not None:
|
322
|
-
if (
|
323
|
-
public_key in self.public_key_to_node_id
|
324
|
-
or node_id in self.public_key_to_node_id.values()
|
325
|
-
):
|
326
|
-
log(ERROR, "Unexpected node registration failure.")
|
327
|
-
return 0
|
328
|
-
|
329
|
-
self.public_key_to_node_id[public_key] = node_id
|
330
|
-
|
331
308
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
332
309
|
return node_id
|
333
310
|
|
334
|
-
def delete_node(self, node_id: int
|
311
|
+
def delete_node(self, node_id: int) -> None:
|
335
312
|
"""Delete a node."""
|
336
313
|
with self.lock:
|
337
314
|
if node_id not in self.node_ids:
|
338
315
|
raise ValueError(f"Node {node_id} not found")
|
339
316
|
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
or node_id not in self.public_key_to_node_id.values()
|
344
|
-
):
|
345
|
-
raise ValueError("Public key or node_id not found")
|
346
|
-
|
347
|
-
del self.public_key_to_node_id[public_key]
|
317
|
+
# Remove node ID <> public key mappings
|
318
|
+
if pk := self.node_id_to_public_key.pop(node_id, None):
|
319
|
+
del self.public_key_to_node_id[pk]
|
348
320
|
|
349
321
|
del self.node_ids[node_id]
|
350
322
|
|
@@ -366,6 +338,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
366
338
|
if online_until > current_time
|
367
339
|
}
|
368
340
|
|
341
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
342
|
+
"""Set `public_key` for the specified `node_id`."""
|
343
|
+
with self.lock:
|
344
|
+
if node_id not in self.node_ids:
|
345
|
+
raise ValueError(f"Node {node_id} not found")
|
346
|
+
|
347
|
+
if public_key in self.public_key_to_node_id:
|
348
|
+
raise ValueError("Public key already in use")
|
349
|
+
|
350
|
+
self.public_key_to_node_id[public_key] = node_id
|
351
|
+
self.node_id_to_public_key[node_id] = public_key
|
352
|
+
|
353
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
354
|
+
"""Get `public_key` for the specified `node_id`."""
|
355
|
+
with self.lock:
|
356
|
+
if node_id not in self.node_ids:
|
357
|
+
raise ValueError(f"Node {node_id} not found")
|
358
|
+
|
359
|
+
return self.node_id_to_public_key.get(node_id)
|
360
|
+
|
369
361
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
370
362
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
371
363
|
return self.public_key_to_node_id.get(node_public_key)
|
@@ -40,20 +40,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
40
40
|
|
41
41
|
Constraints
|
42
42
|
-----------
|
43
|
-
|
44
|
-
`task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
|
45
|
-
|
46
|
-
If `task_ins.task.consumer.anonymous` is `False`, then
|
47
|
-
`task_ins.task.consumer.node_id` MUST be set (not 0)
|
43
|
+
`task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
48
44
|
|
49
45
|
If `task_ins.run_id` is invalid, then
|
50
46
|
storing the `task_ins` MUST fail.
|
51
47
|
"""
|
52
48
|
|
53
49
|
@abc.abstractmethod
|
54
|
-
def get_task_ins(
|
55
|
-
self, node_id: Optional[int], limit: Optional[int]
|
56
|
-
) -> list[TaskIns]:
|
50
|
+
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
57
51
|
"""Get TaskIns optionally filtered by node_id.
|
58
52
|
|
59
53
|
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
@@ -61,15 +55,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
61
55
|
|
62
56
|
Constraints
|
63
57
|
-----------
|
64
|
-
|
58
|
+
Retrieve all TaskIns where
|
65
59
|
|
66
60
|
1. the `task_ins.task.consumer.node_id` equals `node_id` AND
|
67
|
-
2. the `task_ins.task.
|
68
|
-
3. the `task_ins.task.delivered_at` equals `""`.
|
61
|
+
2. the `task_ins.task.delivered_at` equals `""`.
|
69
62
|
|
70
|
-
If `node_id` is `None`, retrieve all TaskIns where the
|
71
|
-
`task_ins.task.consumer.node_id` equals `0` and
|
72
|
-
`task_ins.task.consumer.anonymous` is set to `True`.
|
73
63
|
|
74
64
|
If `delivered_at` MUST BE set (not `""`) otherwise the TaskIns MUST not be in
|
75
65
|
the result.
|
@@ -89,11 +79,8 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
89
79
|
|
90
80
|
Constraints
|
91
81
|
-----------
|
92
|
-
If `task_res.task.consumer.anonymous` is `True`, then
|
93
|
-
`task_res.task.consumer.node_id` MUST NOT be set (equal 0).
|
94
82
|
|
95
|
-
|
96
|
-
`task_res.task.consumer.node_id` MUST be set (not 0)
|
83
|
+
`task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
97
84
|
|
98
85
|
If `task_res.run_id` is invalid, then
|
99
86
|
storing the `task_res` MUST fail.
|
@@ -154,13 +141,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
154
141
|
"""Get all TaskIns IDs for the given run_id."""
|
155
142
|
|
156
143
|
@abc.abstractmethod
|
157
|
-
def create_node(
|
158
|
-
self, ping_interval: float, public_key: Optional[bytes] = None
|
159
|
-
) -> int:
|
144
|
+
def create_node(self, ping_interval: float) -> int:
|
160
145
|
"""Create, store in the link state, and return `node_id`."""
|
161
146
|
|
162
147
|
@abc.abstractmethod
|
163
|
-
def delete_node(self, node_id: int
|
148
|
+
def delete_node(self, node_id: int) -> None:
|
164
149
|
"""Remove `node_id` from the link state."""
|
165
150
|
|
166
151
|
@abc.abstractmethod
|
@@ -173,6 +158,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
173
158
|
an empty `Set` MUST be returned.
|
174
159
|
"""
|
175
160
|
|
161
|
+
@abc.abstractmethod
|
162
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
163
|
+
"""Set `public_key` for the specified `node_id`."""
|
164
|
+
|
165
|
+
@abc.abstractmethod
|
166
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
167
|
+
"""Get `public_key` for the specified `node_id`."""
|
168
|
+
|
176
169
|
@abc.abstractmethod
|
177
170
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
178
171
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
@@ -31,6 +31,7 @@ from flwr.common.constant import (
|
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
32
32
|
NODE_ID_NUM_BYTES,
|
33
33
|
RUN_ID_NUM_BYTES,
|
34
|
+
SUPERLINK_NODE_ID,
|
34
35
|
Status,
|
35
36
|
)
|
36
37
|
from flwr.common.record import ConfigsRecord
|
@@ -72,14 +73,14 @@ CREATE TABLE IF NOT EXISTS node(
|
|
72
73
|
|
73
74
|
SQL_CREATE_TABLE_CREDENTIAL = """
|
74
75
|
CREATE TABLE IF NOT EXISTS credential(
|
75
|
-
private_key
|
76
|
-
public_key
|
76
|
+
private_key BLOB PRIMARY KEY,
|
77
|
+
public_key BLOB
|
77
78
|
);
|
78
79
|
"""
|
79
80
|
|
80
81
|
SQL_CREATE_TABLE_PUBLIC_KEY = """
|
81
82
|
CREATE TABLE IF NOT EXISTS public_key(
|
82
|
-
public_key
|
83
|
+
public_key BLOB PRIMARY KEY
|
83
84
|
);
|
84
85
|
"""
|
85
86
|
|
@@ -128,9 +129,7 @@ CREATE TABLE IF NOT EXISTS task_ins(
|
|
128
129
|
task_id TEXT UNIQUE,
|
129
130
|
group_id TEXT,
|
130
131
|
run_id INTEGER,
|
131
|
-
producer_anonymous BOOLEAN,
|
132
132
|
producer_node_id INTEGER,
|
133
|
-
consumer_anonymous BOOLEAN,
|
134
133
|
consumer_node_id INTEGER,
|
135
134
|
created_at REAL,
|
136
135
|
delivered_at TEXT,
|
@@ -148,9 +147,7 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
148
147
|
task_id TEXT UNIQUE,
|
149
148
|
group_id TEXT,
|
150
149
|
run_id INTEGER,
|
151
|
-
producer_anonymous BOOLEAN,
|
152
150
|
producer_node_id INTEGER,
|
153
|
-
consumer_anonymous BOOLEAN,
|
154
151
|
consumer_node_id INTEGER,
|
155
152
|
created_at REAL,
|
156
153
|
delivered_at TEXT,
|
@@ -263,11 +260,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
263
260
|
|
264
261
|
Constraints
|
265
262
|
-----------
|
266
|
-
If `task_ins.task.consumer.anonymous` is `True`, then
|
267
|
-
`task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
|
268
263
|
|
269
|
-
|
270
|
-
`task_ins.task.consumer.node_id` MUST be set (not 0)
|
264
|
+
`task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
271
265
|
"""
|
272
266
|
# Validate task
|
273
267
|
errors = validate_task_ins_or_res(task_ins)
|
@@ -292,7 +286,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
292
286
|
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
293
287
|
return None
|
294
288
|
# Validate source node ID
|
295
|
-
if task_ins.task.producer.node_id !=
|
289
|
+
if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
|
296
290
|
log(
|
297
291
|
ERROR,
|
298
292
|
"Invalid source node ID for TaskIns: %s",
|
@@ -301,14 +295,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
301
295
|
return None
|
302
296
|
# Validate destination node ID
|
303
297
|
query = "SELECT node_id FROM node WHERE node_id = ?;"
|
304
|
-
if not
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
return None
|
298
|
+
if not self.query(query, (data[0]["consumer_node_id"],)):
|
299
|
+
log(
|
300
|
+
ERROR,
|
301
|
+
"Invalid destination node ID for TaskIns: %s",
|
302
|
+
task_ins.task.consumer.node_id,
|
303
|
+
)
|
304
|
+
return None
|
312
305
|
|
313
306
|
columns = ", ".join([f":{key}" for key in data[0]])
|
314
307
|
query = f"INSERT INTO task_ins VALUES({columns});"
|
@@ -319,25 +312,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
319
312
|
|
320
313
|
return task_id
|
321
314
|
|
322
|
-
def get_task_ins(
|
323
|
-
|
324
|
-
) -> list[TaskIns]:
|
325
|
-
"""Get undelivered TaskIns for one node (either anonymous or with ID).
|
315
|
+
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
316
|
+
"""Get undelivered TaskIns for one node.
|
326
317
|
|
327
318
|
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
328
319
|
TaskIns.
|
329
320
|
|
330
321
|
Constraints
|
331
322
|
-----------
|
332
|
-
|
323
|
+
Retrieve all TaskIns where
|
333
324
|
|
334
325
|
1. the `task_ins.task.consumer.node_id` equals `node_id` AND
|
335
|
-
2. the `task_ins.task.
|
336
|
-
3. the `task_ins.task.delivered_at` equals `""`.
|
337
|
-
|
338
|
-
If `node_id` is `None`, retrieve all TaskIns where the
|
339
|
-
`task_ins.task.consumer.node_id` equals `0` and
|
340
|
-
`task_ins.task.consumer.anonymous` is set to `True`.
|
326
|
+
2. the `task_ins.task.delivered_at` equals `""`.
|
341
327
|
|
342
328
|
`delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
|
343
329
|
the result.
|
@@ -348,38 +334,23 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
348
334
|
if limit is not None and limit < 1:
|
349
335
|
raise AssertionError("`limit` must be >= 1")
|
350
336
|
|
351
|
-
if node_id ==
|
352
|
-
msg =
|
353
|
-
"`node_id` must be >= 1"
|
354
|
-
"\n\n For requesting anonymous tasks use `node_id` equal `None`"
|
355
|
-
)
|
337
|
+
if node_id == SUPERLINK_NODE_ID:
|
338
|
+
msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
|
356
339
|
raise AssertionError(msg)
|
357
340
|
|
358
341
|
data: dict[str, Union[str, int]] = {}
|
359
342
|
|
360
|
-
|
361
|
-
|
362
|
-
query = """
|
363
|
-
SELECT task_id
|
364
|
-
FROM task_ins
|
365
|
-
WHERE consumer_anonymous == 1
|
366
|
-
AND consumer_node_id == 0
|
367
|
-
AND delivered_at = ""
|
368
|
-
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
369
|
-
"""
|
370
|
-
else:
|
371
|
-
# Convert the uint64 value to sint64 for SQLite
|
372
|
-
data["node_id"] = convert_uint64_to_sint64(node_id)
|
343
|
+
# Convert the uint64 value to sint64 for SQLite
|
344
|
+
data["node_id"] = convert_uint64_to_sint64(node_id)
|
373
345
|
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
"""
|
346
|
+
# Retrieve all TaskIns for node_id
|
347
|
+
query = """
|
348
|
+
SELECT task_id
|
349
|
+
FROM task_ins
|
350
|
+
WHERE consumer_node_id == :node_id
|
351
|
+
AND delivered_at = ""
|
352
|
+
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
353
|
+
"""
|
383
354
|
|
384
355
|
if limit is not None:
|
385
356
|
query += " LIMIT :limit"
|
@@ -429,11 +400,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
429
400
|
|
430
401
|
Constraints
|
431
402
|
-----------
|
432
|
-
|
433
|
-
`task_res.task.consumer.node_id` MUST NOT be set (equal 0).
|
434
|
-
|
435
|
-
If `task_res.task.consumer.anonymous` is `False`, then
|
436
|
-
`task_res.task.consumer.node_id` MUST be set (not 0)
|
403
|
+
`task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
437
404
|
"""
|
438
405
|
# Validate task
|
439
406
|
errors = validate_task_ins_or_res(task_res)
|
@@ -459,7 +426,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
459
426
|
if (
|
460
427
|
task_ins
|
461
428
|
and task_res
|
462
|
-
and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
|
463
429
|
and convert_sint64_to_uint64(task_ins["consumer_node_id"])
|
464
430
|
!= task_res.task.producer.node_id
|
465
431
|
):
|
@@ -635,23 +601,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
635
601
|
|
636
602
|
return {UUID(row["task_id"]) for row in rows}
|
637
603
|
|
638
|
-
def create_node(
|
639
|
-
self, ping_interval: float, public_key: Optional[bytes] = None
|
640
|
-
) -> int:
|
604
|
+
def create_node(self, ping_interval: float) -> int:
|
641
605
|
"""Create, store in the link state, and return `node_id`."""
|
642
606
|
# Sample a random uint64 as node_id
|
643
|
-
uint64_node_id = generate_rand_int_from_bytes(
|
607
|
+
uint64_node_id = generate_rand_int_from_bytes(
|
608
|
+
NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
|
609
|
+
)
|
644
610
|
|
645
611
|
# Convert the uint64 value to sint64 for SQLite
|
646
612
|
sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
|
647
613
|
|
648
|
-
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
649
|
-
row = self.query(query, {"public_key": public_key})
|
650
|
-
|
651
|
-
if len(row) > 0:
|
652
|
-
log(ERROR, "Unexpected node registration failure.")
|
653
|
-
return 0
|
654
|
-
|
655
614
|
query = (
|
656
615
|
"INSERT INTO node "
|
657
616
|
"(node_id, online_until, ping_interval, public_key) "
|
@@ -665,7 +624,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
665
624
|
sint64_node_id,
|
666
625
|
time.time() + ping_interval,
|
667
626
|
ping_interval,
|
668
|
-
|
627
|
+
b"", # Initialize with an empty public key
|
669
628
|
),
|
670
629
|
)
|
671
630
|
except sqlite3.IntegrityError:
|
@@ -675,7 +634,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
675
634
|
# Note: we need to return the uint64 value of the node_id
|
676
635
|
return uint64_node_id
|
677
636
|
|
678
|
-
def delete_node(self, node_id: int
|
637
|
+
def delete_node(self, node_id: int) -> None:
|
679
638
|
"""Delete a node."""
|
680
639
|
# Convert the uint64 value to sint64 for SQLite
|
681
640
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
@@ -683,10 +642,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
683
642
|
query = "DELETE FROM node WHERE node_id = ?"
|
684
643
|
params = (sint64_node_id,)
|
685
644
|
|
686
|
-
if public_key is not None:
|
687
|
-
query += " AND public_key = ?"
|
688
|
-
params += (public_key,) # type: ignore
|
689
|
-
|
690
645
|
if self.conn is None:
|
691
646
|
raise AttributeError("LinkState is not initialized.")
|
692
647
|
|
@@ -694,7 +649,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
694
649
|
with self.conn:
|
695
650
|
rows = self.conn.execute(query, params)
|
696
651
|
if rows.rowcount < 1:
|
697
|
-
raise ValueError("
|
652
|
+
raise ValueError(f"Node {node_id} not found")
|
698
653
|
except KeyError as exc:
|
699
654
|
log(ERROR, {"query": query, "data": params, "exception": exc})
|
700
655
|
|
@@ -722,6 +677,41 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
722
677
|
result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
|
723
678
|
return result
|
724
679
|
|
680
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
681
|
+
"""Set `public_key` for the specified `node_id`."""
|
682
|
+
# Convert the uint64 value to sint64 for SQLite
|
683
|
+
sint64_node_id = convert_uint64_to_sint64(node_id)
|
684
|
+
|
685
|
+
# Check if the node exists in the `node` table
|
686
|
+
query = "SELECT 1 FROM node WHERE node_id = ?"
|
687
|
+
if not self.query(query, (sint64_node_id,)):
|
688
|
+
raise ValueError(f"Node {node_id} not found")
|
689
|
+
|
690
|
+
# Check if the public key is already in use in the `node` table
|
691
|
+
query = "SELECT 1 FROM node WHERE public_key = ?"
|
692
|
+
if self.query(query, (public_key,)):
|
693
|
+
raise ValueError("Public key already in use")
|
694
|
+
|
695
|
+
# Update the `node` table to set the public key for the given node ID
|
696
|
+
query = "UPDATE node SET public_key = ? WHERE node_id = ?"
|
697
|
+
self.query(query, (public_key, sint64_node_id))
|
698
|
+
|
699
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
700
|
+
"""Get `public_key` for the specified `node_id`."""
|
701
|
+
# Convert the uint64 value to sint64 for SQLite
|
702
|
+
sint64_node_id = convert_uint64_to_sint64(node_id)
|
703
|
+
|
704
|
+
# Query the public key for the given node_id
|
705
|
+
query = "SELECT public_key FROM node WHERE node_id = ?"
|
706
|
+
rows = self.query(query, (sint64_node_id,))
|
707
|
+
|
708
|
+
# If no result is found, return None
|
709
|
+
if not rows:
|
710
|
+
raise ValueError(f"Node {node_id} not found")
|
711
|
+
|
712
|
+
# Return the public key if it is not empty, otherwise return None
|
713
|
+
return rows[0]["public_key"] or None
|
714
|
+
|
725
715
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
726
716
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
727
717
|
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
@@ -982,17 +972,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
982
972
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
983
973
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
984
974
|
|
985
|
-
#
|
986
|
-
query = "
|
987
|
-
|
988
|
-
self.query(
|
989
|
-
query, (time.time() + ping_interval, ping_interval, sint64_node_id)
|
990
|
-
)
|
991
|
-
return True
|
992
|
-
except sqlite3.IntegrityError:
|
993
|
-
log(ERROR, "`node_id` does not exist.")
|
975
|
+
# Check if the node exists in the `node` table
|
976
|
+
query = "SELECT 1 FROM node WHERE node_id = ?"
|
977
|
+
if not self.query(query, (sint64_node_id,)):
|
994
978
|
return False
|
995
979
|
|
980
|
+
# Update `online_until` and `ping_interval` for the given `node_id`
|
981
|
+
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
|
982
|
+
self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
|
983
|
+
return True
|
984
|
+
|
996
985
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
997
986
|
"""Get the context for the specified `run_id`."""
|
998
987
|
# Retrieve context if any
|
@@ -1105,9 +1094,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
|
|
1105
1094
|
"task_id": task_msg.task_id,
|
1106
1095
|
"group_id": task_msg.group_id,
|
1107
1096
|
"run_id": task_msg.run_id,
|
1108
|
-
"producer_anonymous": task_msg.task.producer.anonymous,
|
1109
1097
|
"producer_node_id": task_msg.task.producer.node_id,
|
1110
|
-
"consumer_anonymous": task_msg.task.consumer.anonymous,
|
1111
1098
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
1112
1099
|
"created_at": task_msg.task.created_at,
|
1113
1100
|
"delivered_at": task_msg.task.delivered_at,
|
@@ -1126,9 +1113,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
1126
1113
|
"task_id": task_msg.task_id,
|
1127
1114
|
"group_id": task_msg.group_id,
|
1128
1115
|
"run_id": task_msg.run_id,
|
1129
|
-
"producer_anonymous": task_msg.task.producer.anonymous,
|
1130
1116
|
"producer_node_id": task_msg.task.producer.node_id,
|
1131
|
-
"consumer_anonymous": task_msg.task.consumer.anonymous,
|
1132
1117
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
1133
1118
|
"created_at": task_msg.task.created_at,
|
1134
1119
|
"delivered_at": task_msg.task.delivered_at,
|
@@ -1153,11 +1138,9 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
1153
1138
|
task=Task(
|
1154
1139
|
producer=Node(
|
1155
1140
|
node_id=task_dict["producer_node_id"],
|
1156
|
-
anonymous=task_dict["producer_anonymous"],
|
1157
1141
|
),
|
1158
1142
|
consumer=Node(
|
1159
1143
|
node_id=task_dict["consumer_node_id"],
|
1160
|
-
anonymous=task_dict["consumer_anonymous"],
|
1161
1144
|
),
|
1162
1145
|
created_at=task_dict["created_at"],
|
1163
1146
|
delivered_at=task_dict["delivered_at"],
|
@@ -1183,11 +1166,9 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
1183
1166
|
task=Task(
|
1184
1167
|
producer=Node(
|
1185
1168
|
node_id=task_dict["producer_node_id"],
|
1186
|
-
anonymous=task_dict["producer_anonymous"],
|
1187
1169
|
),
|
1188
1170
|
consumer=Node(
|
1189
1171
|
node_id=task_dict["consumer_node_id"],
|
1190
|
-
anonymous=task_dict["consumer_anonymous"],
|
1191
1172
|
),
|
1192
1173
|
created_at=task_dict["created_at"],
|
1193
1174
|
delivered_at=task_dict["delivered_at"],
|
@@ -21,7 +21,7 @@ from typing import Optional, Union
|
|
21
21
|
from uuid import UUID, uuid4
|
22
22
|
|
23
23
|
from flwr.common import ConfigsRecord, Context, log, now, serde
|
24
|
-
from flwr.common.constant import ErrorCode, Status, SubStatus
|
24
|
+
from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
|
25
25
|
from flwr.common.typing import RunStatus
|
26
26
|
|
27
27
|
# pylint: disable=E0611
|
@@ -60,9 +60,19 @@ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
60
60
|
)
|
61
61
|
|
62
62
|
|
63
|
-
def generate_rand_int_from_bytes(
|
64
|
-
|
65
|
-
|
63
|
+
def generate_rand_int_from_bytes(
|
64
|
+
num_bytes: int, exclude: Optional[list[int]] = None
|
65
|
+
) -> int:
|
66
|
+
"""Generate a random unsigned integer from `num_bytes` bytes.
|
67
|
+
|
68
|
+
If `exclude` is set, this function guarantees such number is not returned.
|
69
|
+
"""
|
70
|
+
num = int.from_bytes(urandom(num_bytes), "little", signed=False)
|
71
|
+
|
72
|
+
if exclude:
|
73
|
+
while num in exclude:
|
74
|
+
num = int.from_bytes(urandom(num_bytes), "little", signed=False)
|
75
|
+
return num
|
66
76
|
|
67
77
|
|
68
78
|
def convert_uint64_to_sint64(u: int) -> int:
|
@@ -246,8 +256,8 @@ def create_taskres_for_unavailable_taskins(taskins_id: Union[str, UUID]) -> Task
|
|
246
256
|
run_id=0, # Unknown run ID
|
247
257
|
task=Task(
|
248
258
|
# This function is only called by SuperLink, and thus it's the producer.
|
249
|
-
producer=Node(node_id=
|
250
|
-
consumer=Node(node_id=
|
259
|
+
producer=Node(node_id=SUPERLINK_NODE_ID),
|
260
|
+
consumer=Node(node_id=SUPERLINK_NODE_ID),
|
251
261
|
created_at=current_time,
|
252
262
|
ttl=0,
|
253
263
|
ancestry=[str(taskins_id)],
|
@@ -285,8 +295,8 @@ def create_taskres_for_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
285
295
|
run_id=ref_taskins.run_id,
|
286
296
|
task=Task(
|
287
297
|
# This function is only called by SuperLink, and thus it's the producer.
|
288
|
-
producer=Node(node_id=
|
289
|
-
consumer=Node(node_id=
|
298
|
+
producer=Node(node_id=SUPERLINK_NODE_ID),
|
299
|
+
consumer=Node(node_id=SUPERLINK_NODE_ID),
|
290
300
|
created_at=current_time,
|
291
301
|
ttl=ttl,
|
292
302
|
ancestry=[ref_taskins.task_id],
|