flwr 1.15.2__py3-none-any.whl → 1.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/client/client_app.py +147 -36
- flwr/client/clientapp/app.py +4 -0
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/constant.py +16 -0
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/message.py +18 -7
- flwr/common/object_ref.py +0 -10
- flwr/common/record/conversion_utils.py +8 -17
- flwr/common/record/parametersrecord.py +151 -16
- flwr/common/record/recordset.py +95 -88
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/serde.py +8 -126
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +36 -0
- flwr/server/app.py +18 -2
- flwr/server/compat/app.py +4 -1
- flwr/server/compat/app_utils.py +10 -2
- flwr/server/compat/driver_client_proxy.py +2 -2
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +10 -1
- flwr/server/driver/inmemory_driver.py +17 -20
- flwr/server/run_serverapp.py +2 -13
- flwr/server/server_app.py +93 -20
- flwr/server/superlink/driver/serverappio_servicer.py +25 -27
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +32 -35
- flwr/server/superlink/linkstate/in_memory_linkstate.py +140 -126
- flwr/server/superlink/linkstate/linkstate.py +47 -60
- flwr/server/superlink/linkstate/sqlite_linkstate.py +210 -276
- flwr/server/superlink/linkstate/utils.py +91 -119
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +4 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
- flwr/superexec/app.py +0 -14
- flwr/superexec/exec_servicer.py +4 -4
- flwr/superexec/exec_user_auth_interceptor.py +5 -3
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/METADATA +4 -4
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/RECORD +63 -66
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/WHEEL +0 -0
|
@@ -26,7 +26,7 @@ from logging import DEBUG, ERROR, WARNING
|
|
|
26
26
|
from typing import Any, Optional, Union, cast
|
|
27
27
|
from uuid import UUID, uuid4
|
|
28
28
|
|
|
29
|
-
from flwr.common import Context, log, now
|
|
29
|
+
from flwr.common import Context, Message, Metadata, log, now
|
|
30
30
|
from flwr.common.constant import (
|
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
|
32
32
|
NODE_ID_NUM_BYTES,
|
|
@@ -35,15 +35,20 @@ from flwr.common.constant import (
|
|
|
35
35
|
Status,
|
|
36
36
|
)
|
|
37
37
|
from flwr.common.record import ConfigsRecord
|
|
38
|
+
from flwr.common.serde import (
|
|
39
|
+
error_from_proto,
|
|
40
|
+
error_to_proto,
|
|
41
|
+
recordset_from_proto,
|
|
42
|
+
recordset_to_proto,
|
|
43
|
+
)
|
|
38
44
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
39
45
|
|
|
40
46
|
# pylint: disable=E0611
|
|
41
|
-
from flwr.proto.
|
|
47
|
+
from flwr.proto.error_pb2 import Error as ProtoError
|
|
42
48
|
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
43
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
44
49
|
|
|
45
50
|
# pylint: enable=E0611
|
|
46
|
-
from flwr.server.utils.validator import
|
|
51
|
+
from flwr.server.utils.validator import validate_message
|
|
47
52
|
|
|
48
53
|
from .linkstate import LinkState
|
|
49
54
|
from .utils import (
|
|
@@ -58,8 +63,8 @@ from .utils import (
|
|
|
58
63
|
generate_rand_int_from_bytes,
|
|
59
64
|
has_valid_sub_status,
|
|
60
65
|
is_valid_transition,
|
|
61
|
-
|
|
62
|
-
|
|
66
|
+
verify_found_message_replies,
|
|
67
|
+
verify_message_ids,
|
|
63
68
|
)
|
|
64
69
|
|
|
65
70
|
SQL_CREATE_TABLE_NODE = """
|
|
@@ -117,36 +122,39 @@ CREATE TABLE IF NOT EXISTS context(
|
|
|
117
122
|
);
|
|
118
123
|
"""
|
|
119
124
|
|
|
120
|
-
|
|
121
|
-
CREATE TABLE IF NOT EXISTS
|
|
122
|
-
|
|
125
|
+
SQL_CREATE_TABLE_MESSAGE_INS = """
|
|
126
|
+
CREATE TABLE IF NOT EXISTS message_ins(
|
|
127
|
+
message_id TEXT UNIQUE,
|
|
123
128
|
group_id TEXT,
|
|
124
129
|
run_id INTEGER,
|
|
125
|
-
|
|
126
|
-
|
|
130
|
+
src_node_id INTEGER,
|
|
131
|
+
dst_node_id INTEGER,
|
|
132
|
+
reply_to_message TEXT,
|
|
127
133
|
created_at REAL,
|
|
128
134
|
delivered_at TEXT,
|
|
129
135
|
ttl REAL,
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
136
|
+
message_type TEXT,
|
|
137
|
+
content BLOB NULL,
|
|
138
|
+
error BLOB NULL,
|
|
133
139
|
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
134
140
|
);
|
|
135
141
|
"""
|
|
136
142
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
143
|
+
|
|
144
|
+
SQL_CREATE_TABLE_MESSAGE_RES = """
|
|
145
|
+
CREATE TABLE IF NOT EXISTS message_res(
|
|
146
|
+
message_id TEXT UNIQUE,
|
|
140
147
|
group_id TEXT,
|
|
141
148
|
run_id INTEGER,
|
|
142
|
-
|
|
143
|
-
|
|
149
|
+
src_node_id INTEGER,
|
|
150
|
+
dst_node_id INTEGER,
|
|
151
|
+
reply_to_message TEXT,
|
|
144
152
|
created_at REAL,
|
|
145
153
|
delivered_at TEXT,
|
|
146
154
|
ttl REAL,
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
155
|
+
message_type TEXT,
|
|
156
|
+
content BLOB NULL,
|
|
157
|
+
error BLOB NULL,
|
|
150
158
|
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
151
159
|
);
|
|
152
160
|
"""
|
|
@@ -196,8 +204,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
196
204
|
cur.execute(SQL_CREATE_TABLE_RUN)
|
|
197
205
|
cur.execute(SQL_CREATE_TABLE_LOGS)
|
|
198
206
|
cur.execute(SQL_CREATE_TABLE_CONTEXT)
|
|
199
|
-
cur.execute(
|
|
200
|
-
cur.execute(
|
|
207
|
+
cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
|
|
208
|
+
cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
|
|
201
209
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
202
210
|
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
203
211
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
@@ -239,88 +247,62 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
239
247
|
|
|
240
248
|
return result
|
|
241
249
|
|
|
242
|
-
def
|
|
243
|
-
"""Store one
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
Stores the value of the task_ins in the link state and, if successful,
|
|
248
|
-
returns the task_id (UUID) of the task_ins. If, for any reason, storing
|
|
249
|
-
the task_ins fails, `None` is returned.
|
|
250
|
-
|
|
251
|
-
Constraints
|
|
252
|
-
-----------
|
|
253
|
-
|
|
254
|
-
`task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
|
255
|
-
"""
|
|
256
|
-
# Validate task
|
|
257
|
-
errors = validate_task_ins_or_res(task_ins)
|
|
250
|
+
def store_message_ins(self, message: Message) -> Optional[UUID]:
|
|
251
|
+
"""Store one Message."""
|
|
252
|
+
# Validate message
|
|
253
|
+
errors = validate_message(message=message, is_reply_message=False)
|
|
258
254
|
if any(errors):
|
|
259
255
|
log(ERROR, errors)
|
|
260
256
|
return None
|
|
261
|
-
# Create
|
|
262
|
-
|
|
257
|
+
# Create message_id
|
|
258
|
+
message_id = uuid4()
|
|
263
259
|
|
|
264
|
-
# Store
|
|
265
|
-
|
|
266
|
-
|
|
260
|
+
# Store Message
|
|
261
|
+
# pylint: disable-next=W0212
|
|
262
|
+
message.metadata._message_id = str(message_id) # type: ignore
|
|
263
|
+
data = (message_to_dict(message),)
|
|
267
264
|
|
|
268
265
|
# Convert values from uint64 to sint64 for SQLite
|
|
269
266
|
convert_uint64_values_in_dict_to_sint64(
|
|
270
|
-
data[0], ["run_id", "
|
|
267
|
+
data[0], ["run_id", "src_node_id", "dst_node_id"]
|
|
271
268
|
)
|
|
272
269
|
|
|
273
270
|
# Validate run_id
|
|
274
271
|
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
275
272
|
if not self.query(query, (data[0]["run_id"],)):
|
|
276
|
-
log(ERROR, "Invalid run ID for
|
|
273
|
+
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
277
274
|
return None
|
|
275
|
+
|
|
278
276
|
# Validate source node ID
|
|
279
|
-
if
|
|
277
|
+
if message.metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
280
278
|
log(
|
|
281
279
|
ERROR,
|
|
282
|
-
"Invalid source node ID for
|
|
283
|
-
|
|
280
|
+
"Invalid source node ID for Message: %s",
|
|
281
|
+
message.metadata.src_node_id,
|
|
284
282
|
)
|
|
285
283
|
return None
|
|
284
|
+
|
|
286
285
|
# Validate destination node ID
|
|
287
286
|
query = "SELECT node_id FROM node WHERE node_id = ?;"
|
|
288
|
-
if not self.query(query, (data[0]["
|
|
287
|
+
if not self.query(query, (data[0]["dst_node_id"],)):
|
|
289
288
|
log(
|
|
290
289
|
ERROR,
|
|
291
|
-
"Invalid destination node ID for
|
|
292
|
-
|
|
290
|
+
"Invalid destination node ID for Message: %s",
|
|
291
|
+
message.metadata.dst_node_id,
|
|
293
292
|
)
|
|
294
293
|
return None
|
|
295
294
|
|
|
296
295
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
297
|
-
query = f"INSERT INTO
|
|
296
|
+
query = f"INSERT INTO message_ins VALUES({columns});"
|
|
298
297
|
|
|
299
298
|
# Only invalid run_id can trigger IntegrityError.
|
|
300
299
|
# This may need to be changed in the future version with more integrity checks.
|
|
301
300
|
self.query(query, data)
|
|
302
301
|
|
|
303
|
-
return
|
|
304
|
-
|
|
305
|
-
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
|
306
|
-
"""Get undelivered TaskIns for one node.
|
|
307
|
-
|
|
308
|
-
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
|
309
|
-
TaskIns.
|
|
310
|
-
|
|
311
|
-
Constraints
|
|
312
|
-
-----------
|
|
313
|
-
Retrieve all TaskIns where
|
|
314
|
-
|
|
315
|
-
1. the `task_ins.task.consumer.node_id` equals `node_id` AND
|
|
316
|
-
2. the `task_ins.task.delivered_at` equals `""`.
|
|
317
|
-
|
|
318
|
-
`delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
|
|
319
|
-
the result.
|
|
302
|
+
return message_id
|
|
320
303
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
"""
|
|
304
|
+
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
|
|
305
|
+
"""Get all Messages that have not been delivered yet."""
|
|
324
306
|
if limit is not None and limit < 1:
|
|
325
307
|
raise AssertionError("`limit` must be >= 1")
|
|
326
308
|
|
|
@@ -333,11 +315,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
333
315
|
# Convert the uint64 value to sint64 for SQLite
|
|
334
316
|
data["node_id"] = convert_uint64_to_sint64(node_id)
|
|
335
317
|
|
|
336
|
-
# Retrieve all
|
|
318
|
+
# Retrieve all Messages for node_id
|
|
337
319
|
query = """
|
|
338
|
-
SELECT
|
|
339
|
-
FROM
|
|
340
|
-
WHERE
|
|
320
|
+
SELECT message_id
|
|
321
|
+
FROM message_ins
|
|
322
|
+
WHERE dst_node_id == :node_id
|
|
341
323
|
AND delivered_at = ""
|
|
342
324
|
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
343
325
|
"""
|
|
@@ -352,20 +334,20 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
352
334
|
|
|
353
335
|
if rows:
|
|
354
336
|
# Prepare query
|
|
355
|
-
|
|
356
|
-
placeholders: str = ",".join([f":id_{i}" for i in range(len(
|
|
337
|
+
message_ids = [row["message_id"] for row in rows]
|
|
338
|
+
placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
|
|
357
339
|
query = f"""
|
|
358
|
-
UPDATE
|
|
340
|
+
UPDATE message_ins
|
|
359
341
|
SET delivered_at = :delivered_at
|
|
360
|
-
WHERE
|
|
342
|
+
WHERE message_id IN ({placeholders})
|
|
361
343
|
RETURNING *;
|
|
362
344
|
"""
|
|
363
345
|
|
|
364
346
|
# Prepare data for query
|
|
365
347
|
delivered_at = now().isoformat()
|
|
366
348
|
data = {"delivered_at": delivered_at}
|
|
367
|
-
for index,
|
|
368
|
-
data[f"id_{index}"] = str(
|
|
349
|
+
for index, msg_id in enumerate(message_ids):
|
|
350
|
+
data[f"id_{index}"] = str(msg_id)
|
|
369
351
|
|
|
370
352
|
# Run query
|
|
371
353
|
rows = self.query(query, data)
|
|
@@ -373,86 +355,80 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
373
355
|
for row in rows:
|
|
374
356
|
# Convert values from sint64 to uint64
|
|
375
357
|
convert_sint64_values_in_dict_to_uint64(
|
|
376
|
-
row, ["run_id", "
|
|
358
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
377
359
|
)
|
|
378
360
|
|
|
379
|
-
result = [
|
|
361
|
+
result = [dict_to_message(row) for row in rows]
|
|
380
362
|
|
|
381
363
|
return result
|
|
382
364
|
|
|
383
|
-
def
|
|
384
|
-
"""Store one
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
|
|
389
|
-
the `task_res`. If storing the `task_res` fails, `None` is returned.
|
|
390
|
-
|
|
391
|
-
Constraints
|
|
392
|
-
-----------
|
|
393
|
-
`task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
|
394
|
-
"""
|
|
395
|
-
# Validate task
|
|
396
|
-
errors = validate_task_ins_or_res(task_res)
|
|
365
|
+
def store_message_res(self, message: Message) -> Optional[UUID]:
|
|
366
|
+
"""Store one Message."""
|
|
367
|
+
# Validate message
|
|
368
|
+
errors = validate_message(message=message, is_reply_message=True)
|
|
397
369
|
if any(errors):
|
|
398
370
|
log(ERROR, errors)
|
|
399
371
|
return None
|
|
400
372
|
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
task_ins = self.get_valid_task_ins(task_ins_id)
|
|
406
|
-
if task_ins is None:
|
|
373
|
+
res_metadata = message.metadata
|
|
374
|
+
msg_ins_id = res_metadata.reply_to_message
|
|
375
|
+
msg_ins = self.get_valid_message_ins(msg_ins_id)
|
|
376
|
+
if msg_ins is None:
|
|
407
377
|
log(
|
|
408
378
|
ERROR,
|
|
409
|
-
"Failed to store
|
|
410
|
-
"
|
|
411
|
-
|
|
379
|
+
"Failed to store Message reply: "
|
|
380
|
+
"The message it replies to with message_id %s does not exist or "
|
|
381
|
+
"has expired.",
|
|
382
|
+
msg_ins_id,
|
|
412
383
|
)
|
|
413
384
|
return None
|
|
414
385
|
|
|
415
|
-
# Ensure that the
|
|
386
|
+
# Ensure that the dst_node_id of the original message matches the src_node_id of
|
|
387
|
+
# reply being processed.
|
|
416
388
|
if (
|
|
417
|
-
|
|
418
|
-
and
|
|
419
|
-
and convert_sint64_to_uint64(
|
|
420
|
-
!=
|
|
389
|
+
msg_ins
|
|
390
|
+
and message
|
|
391
|
+
and convert_sint64_to_uint64(msg_ins["dst_node_id"])
|
|
392
|
+
!= res_metadata.src_node_id
|
|
421
393
|
):
|
|
422
394
|
return None
|
|
423
395
|
|
|
424
|
-
# Fail if the
|
|
425
|
-
# expiration time of the
|
|
426
|
-
# Condition:
|
|
427
|
-
#
|
|
396
|
+
# Fail if the Message TTL exceeds the
|
|
397
|
+
# expiration time of the Message it replies to.
|
|
398
|
+
# Condition: ins_metadata.created_at + ins_metadata.ttl ≥
|
|
399
|
+
# res_metadata.created_at + res_metadata.ttl
|
|
428
400
|
# A small tolerance is introduced to account
|
|
429
401
|
# for floating-point precision issues.
|
|
430
402
|
max_allowed_ttl = (
|
|
431
|
-
|
|
403
|
+
msg_ins["created_at"] + msg_ins["ttl"] - res_metadata.created_at
|
|
432
404
|
)
|
|
433
|
-
if
|
|
434
|
-
|
|
405
|
+
if res_metadata.ttl and (
|
|
406
|
+
res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
|
|
435
407
|
):
|
|
436
408
|
log(
|
|
437
409
|
WARNING,
|
|
438
|
-
"Received
|
|
439
|
-
"
|
|
440
|
-
|
|
410
|
+
"Received Message with TTL %.2f exceeding the allowed maximum "
|
|
411
|
+
"TTL %.2f.",
|
|
412
|
+
res_metadata.ttl,
|
|
441
413
|
max_allowed_ttl,
|
|
442
414
|
)
|
|
443
415
|
return None
|
|
444
416
|
|
|
445
|
-
#
|
|
446
|
-
|
|
447
|
-
|
|
417
|
+
# Create message_id
|
|
418
|
+
message_id = uuid4()
|
|
419
|
+
|
|
420
|
+
# Store Message
|
|
421
|
+
# pylint: disable-next=W0212
|
|
422
|
+
message.metadata._message_id = str(message_id) # type: ignore
|
|
423
|
+
data = (message_to_dict(message),)
|
|
448
424
|
|
|
449
425
|
# Convert values from uint64 to sint64 for SQLite
|
|
450
426
|
convert_uint64_values_in_dict_to_sint64(
|
|
451
|
-
data[0], ["run_id", "
|
|
427
|
+
data[0], ["run_id", "src_node_id", "dst_node_id"]
|
|
452
428
|
)
|
|
453
429
|
|
|
454
430
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
455
|
-
query = f"INSERT INTO
|
|
431
|
+
query = f"INSERT INTO message_res VALUES({columns});"
|
|
456
432
|
|
|
457
433
|
# Only invalid run_id can trigger IntegrityError.
|
|
458
434
|
# This may need to be changed in the future version with more integrity checks.
|
|
@@ -462,124 +438,125 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
462
438
|
log(ERROR, "`run` is invalid")
|
|
463
439
|
return None
|
|
464
440
|
|
|
465
|
-
return
|
|
441
|
+
return message_id
|
|
466
442
|
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
ret: dict[UUID, TaskRes] = {}
|
|
443
|
+
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
|
|
444
|
+
"""Get reply Messages for the given Message IDs."""
|
|
445
|
+
ret: dict[UUID, Message] = {}
|
|
471
446
|
|
|
472
|
-
# Verify
|
|
447
|
+
# Verify Message IDs
|
|
473
448
|
current = time.time()
|
|
474
449
|
query = f"""
|
|
475
450
|
SELECT *
|
|
476
|
-
FROM
|
|
477
|
-
WHERE
|
|
451
|
+
FROM message_ins
|
|
452
|
+
WHERE message_id IN ({",".join(["?"] * len(message_ids))});
|
|
478
453
|
"""
|
|
479
|
-
rows = self.query(query, tuple(str(
|
|
480
|
-
|
|
454
|
+
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
|
455
|
+
found_message_ins_dict: dict[UUID, Message] = {}
|
|
481
456
|
for row in rows:
|
|
482
457
|
convert_sint64_values_in_dict_to_uint64(
|
|
483
|
-
row, ["run_id", "
|
|
458
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
484
459
|
)
|
|
485
|
-
|
|
460
|
+
found_message_ins_dict[UUID(row["message_id"])] = dict_to_message(row)
|
|
486
461
|
|
|
487
|
-
ret =
|
|
488
|
-
|
|
489
|
-
|
|
462
|
+
ret = verify_message_ids(
|
|
463
|
+
inquired_message_ids=message_ids,
|
|
464
|
+
found_message_ins_dict=found_message_ins_dict,
|
|
490
465
|
current_time=current,
|
|
491
466
|
)
|
|
492
467
|
|
|
493
|
-
# Find all
|
|
468
|
+
# Find all reply Messages
|
|
494
469
|
query = f"""
|
|
495
470
|
SELECT *
|
|
496
|
-
FROM
|
|
497
|
-
WHERE
|
|
471
|
+
FROM message_res
|
|
472
|
+
WHERE reply_to_message IN ({",".join(["?"] * len(message_ids))})
|
|
498
473
|
AND delivered_at = "";
|
|
499
474
|
"""
|
|
500
|
-
rows = self.query(query, tuple(str(
|
|
475
|
+
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
|
501
476
|
for row in rows:
|
|
502
477
|
convert_sint64_values_in_dict_to_uint64(
|
|
503
|
-
row, ["run_id", "
|
|
478
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
504
479
|
)
|
|
505
|
-
tmp_ret_dict =
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
480
|
+
tmp_ret_dict = verify_found_message_replies(
|
|
481
|
+
inquired_message_ids=message_ids,
|
|
482
|
+
found_message_ins_dict=found_message_ins_dict,
|
|
483
|
+
found_message_res_list=[dict_to_message(row) for row in rows],
|
|
509
484
|
current_time=current,
|
|
510
485
|
)
|
|
511
486
|
ret.update(tmp_ret_dict)
|
|
512
487
|
|
|
513
|
-
# Mark existing
|
|
488
|
+
# Mark existing reply Messages to be returned as delivered
|
|
514
489
|
delivered_at = now().isoformat()
|
|
515
|
-
for
|
|
516
|
-
|
|
517
|
-
|
|
490
|
+
for message_res in ret.values():
|
|
491
|
+
message_res.metadata.delivered_at = delivered_at
|
|
492
|
+
message_res_ids = [
|
|
493
|
+
message_res.metadata.message_id for message_res in ret.values()
|
|
494
|
+
]
|
|
518
495
|
query = f"""
|
|
519
|
-
UPDATE
|
|
496
|
+
UPDATE message_res
|
|
520
497
|
SET delivered_at = ?
|
|
521
|
-
WHERE
|
|
498
|
+
WHERE message_id IN ({",".join(["?"] * len(message_res_ids))});
|
|
522
499
|
"""
|
|
523
|
-
data: list[Any] = [delivered_at] +
|
|
500
|
+
data: list[Any] = [delivered_at] + message_res_ids
|
|
524
501
|
self.query(query, data)
|
|
525
502
|
|
|
526
503
|
return list(ret.values())
|
|
527
504
|
|
|
528
|
-
def
|
|
529
|
-
"""Calculate the number of
|
|
505
|
+
def num_message_ins(self) -> int:
|
|
506
|
+
"""Calculate the number of instruction Messages in store.
|
|
530
507
|
|
|
531
|
-
This includes delivered but not yet deleted
|
|
508
|
+
This includes delivered but not yet deleted.
|
|
532
509
|
"""
|
|
533
|
-
query = "SELECT count(*) AS num FROM
|
|
510
|
+
query = "SELECT count(*) AS num FROM message_ins;"
|
|
534
511
|
rows = self.query(query)
|
|
535
512
|
result = rows[0]
|
|
536
513
|
num = cast(int, result["num"])
|
|
537
514
|
return num
|
|
538
515
|
|
|
539
|
-
def
|
|
540
|
-
"""Calculate the number of
|
|
516
|
+
def num_message_res(self) -> int:
|
|
517
|
+
"""Calculate the number of reply Messages in store.
|
|
541
518
|
|
|
542
|
-
This includes delivered but not yet deleted
|
|
519
|
+
This includes delivered but not yet deleted.
|
|
543
520
|
"""
|
|
544
|
-
query = "SELECT count(*) AS num FROM
|
|
521
|
+
query = "SELECT count(*) AS num FROM message_res;"
|
|
545
522
|
rows = self.query(query)
|
|
546
523
|
result: dict[str, int] = rows[0]
|
|
547
524
|
return result["num"]
|
|
548
525
|
|
|
549
|
-
def
|
|
550
|
-
"""Delete
|
|
551
|
-
if not
|
|
526
|
+
def delete_messages(self, message_ins_ids: set[UUID]) -> None:
|
|
527
|
+
"""Delete a Message and its reply based on provided Message IDs."""
|
|
528
|
+
if not message_ins_ids:
|
|
552
529
|
return
|
|
553
530
|
if self.conn is None:
|
|
554
531
|
raise AttributeError("LinkState not initialized")
|
|
555
532
|
|
|
556
|
-
placeholders = ",".join(["?"] * len(
|
|
557
|
-
data = tuple(str(
|
|
533
|
+
placeholders = ",".join(["?"] * len(message_ins_ids))
|
|
534
|
+
data = tuple(str(message_id) for message_id in message_ins_ids)
|
|
558
535
|
|
|
559
|
-
# Delete
|
|
536
|
+
# Delete Message
|
|
560
537
|
query_1 = f"""
|
|
561
|
-
DELETE FROM
|
|
562
|
-
WHERE
|
|
538
|
+
DELETE FROM message_ins
|
|
539
|
+
WHERE message_id IN ({placeholders});
|
|
563
540
|
"""
|
|
564
541
|
|
|
565
|
-
# Delete
|
|
542
|
+
# Delete reply Message
|
|
566
543
|
query_2 = f"""
|
|
567
|
-
DELETE FROM
|
|
568
|
-
WHERE
|
|
544
|
+
DELETE FROM message_res
|
|
545
|
+
WHERE reply_to_message IN ({placeholders});
|
|
569
546
|
"""
|
|
570
547
|
|
|
571
548
|
with self.conn:
|
|
572
549
|
self.conn.execute(query_1, data)
|
|
573
550
|
self.conn.execute(query_2, data)
|
|
574
551
|
|
|
575
|
-
def
|
|
576
|
-
"""Get all
|
|
552
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
553
|
+
"""Get all instruction Message IDs for the given run_id."""
|
|
577
554
|
if self.conn is None:
|
|
578
555
|
raise AttributeError("LinkState not initialized")
|
|
579
556
|
|
|
580
557
|
query = """
|
|
581
|
-
SELECT
|
|
582
|
-
FROM
|
|
558
|
+
SELECT message_id
|
|
559
|
+
FROM message_ins
|
|
583
560
|
WHERE run_id = :run_id;
|
|
584
561
|
"""
|
|
585
562
|
|
|
@@ -589,7 +566,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
589
566
|
with self.conn:
|
|
590
567
|
rows = self.conn.execute(query, data).fetchall()
|
|
591
568
|
|
|
592
|
-
return {UUID(row["
|
|
569
|
+
return {UUID(row["message_id"]) for row in rows}
|
|
593
570
|
|
|
594
571
|
def create_node(self, ping_interval: float) -> int:
|
|
595
572
|
"""Create, store in the link state, and return `node_id`."""
|
|
@@ -1001,32 +978,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1001
978
|
latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
|
|
1002
979
|
return "".join(row["log"] for row in rows), latest_timestamp
|
|
1003
980
|
|
|
1004
|
-
def
|
|
1005
|
-
"""Check if the
|
|
981
|
+
def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
|
|
982
|
+
"""Check if the Message exists and is valid (not expired).
|
|
1006
983
|
|
|
1007
|
-
Return
|
|
984
|
+
Return Message if valid.
|
|
1008
985
|
"""
|
|
1009
986
|
query = """
|
|
1010
987
|
SELECT *
|
|
1011
|
-
FROM
|
|
1012
|
-
WHERE
|
|
988
|
+
FROM message_ins
|
|
989
|
+
WHERE message_id = :message_id
|
|
1013
990
|
"""
|
|
1014
|
-
data = {"
|
|
991
|
+
data = {"message_id": message_id}
|
|
1015
992
|
rows = self.query(query, data)
|
|
1016
993
|
if not rows:
|
|
1017
|
-
#
|
|
994
|
+
# Message does not exist
|
|
1018
995
|
return None
|
|
1019
996
|
|
|
1020
|
-
|
|
1021
|
-
created_at =
|
|
1022
|
-
ttl =
|
|
997
|
+
message_ins = rows[0]
|
|
998
|
+
created_at = message_ins["created_at"]
|
|
999
|
+
ttl = message_ins["ttl"]
|
|
1023
1000
|
current_time = time.time()
|
|
1024
1001
|
|
|
1025
|
-
# Check if
|
|
1002
|
+
# Check if Message is expired
|
|
1026
1003
|
if ttl is not None and created_at + ttl <= current_time:
|
|
1027
1004
|
return None
|
|
1028
1005
|
|
|
1029
|
-
return
|
|
1006
|
+
return message_ins
|
|
1030
1007
|
|
|
1031
1008
|
|
|
1032
1009
|
def dict_factory(
|
|
@@ -1041,94 +1018,51 @@ def dict_factory(
|
|
|
1041
1018
|
return dict(zip(fields, row))
|
|
1042
1019
|
|
|
1043
1020
|
|
|
1044
|
-
def
|
|
1045
|
-
"""Transform
|
|
1021
|
+
def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1022
|
+
"""Transform Message to dict."""
|
|
1046
1023
|
result = {
|
|
1047
|
-
"
|
|
1048
|
-
"group_id":
|
|
1049
|
-
"run_id":
|
|
1050
|
-
"
|
|
1051
|
-
"
|
|
1052
|
-
"
|
|
1053
|
-
"
|
|
1054
|
-
"
|
|
1055
|
-
"
|
|
1056
|
-
"
|
|
1057
|
-
"
|
|
1024
|
+
"message_id": message.metadata.message_id,
|
|
1025
|
+
"group_id": message.metadata.group_id,
|
|
1026
|
+
"run_id": message.metadata.run_id,
|
|
1027
|
+
"src_node_id": message.metadata.src_node_id,
|
|
1028
|
+
"dst_node_id": message.metadata.dst_node_id,
|
|
1029
|
+
"reply_to_message": message.metadata.reply_to_message,
|
|
1030
|
+
"created_at": message.metadata.created_at,
|
|
1031
|
+
"delivered_at": message.metadata.delivered_at,
|
|
1032
|
+
"ttl": message.metadata.ttl,
|
|
1033
|
+
"message_type": message.metadata.message_type,
|
|
1034
|
+
"content": None,
|
|
1035
|
+
"error": None,
|
|
1058
1036
|
}
|
|
1059
|
-
return result
|
|
1060
1037
|
|
|
1038
|
+
if message.has_content():
|
|
1039
|
+
result["content"] = recordset_to_proto(message.content).SerializeToString()
|
|
1040
|
+
else:
|
|
1041
|
+
result["error"] = error_to_proto(message.error).SerializeToString()
|
|
1061
1042
|
|
|
1062
|
-
def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
1063
|
-
"""Transform TaskRes to dict."""
|
|
1064
|
-
result = {
|
|
1065
|
-
"task_id": task_msg.task_id,
|
|
1066
|
-
"group_id": task_msg.group_id,
|
|
1067
|
-
"run_id": task_msg.run_id,
|
|
1068
|
-
"producer_node_id": task_msg.task.producer.node_id,
|
|
1069
|
-
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
1070
|
-
"created_at": task_msg.task.created_at,
|
|
1071
|
-
"delivered_at": task_msg.task.delivered_at,
|
|
1072
|
-
"ttl": task_msg.task.ttl,
|
|
1073
|
-
"ancestry": ",".join(task_msg.task.ancestry),
|
|
1074
|
-
"task_type": task_msg.task.task_type,
|
|
1075
|
-
"recordset": task_msg.task.recordset.SerializeToString(),
|
|
1076
|
-
}
|
|
1077
1043
|
return result
|
|
1078
1044
|
|
|
1079
1045
|
|
|
1080
|
-
def
|
|
1081
|
-
"""
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
)
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
),
|
|
1096
|
-
created_at=task_dict["created_at"],
|
|
1097
|
-
delivered_at=task_dict["delivered_at"],
|
|
1098
|
-
ttl=task_dict["ttl"],
|
|
1099
|
-
ancestry=task_dict["ancestry"].split(","),
|
|
1100
|
-
task_type=task_dict["task_type"],
|
|
1101
|
-
recordset=recordset,
|
|
1102
|
-
),
|
|
1103
|
-
)
|
|
1104
|
-
return result
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
1108
|
-
"""Turn task_dict into protobuf message."""
|
|
1109
|
-
recordset = ProtoRecordSet()
|
|
1110
|
-
recordset.ParseFromString(task_dict["recordset"])
|
|
1111
|
-
|
|
1112
|
-
result = TaskRes(
|
|
1113
|
-
task_id=task_dict["task_id"],
|
|
1114
|
-
group_id=task_dict["group_id"],
|
|
1115
|
-
run_id=task_dict["run_id"],
|
|
1116
|
-
task=Task(
|
|
1117
|
-
producer=Node(
|
|
1118
|
-
node_id=task_dict["producer_node_id"],
|
|
1119
|
-
),
|
|
1120
|
-
consumer=Node(
|
|
1121
|
-
node_id=task_dict["consumer_node_id"],
|
|
1122
|
-
),
|
|
1123
|
-
created_at=task_dict["created_at"],
|
|
1124
|
-
delivered_at=task_dict["delivered_at"],
|
|
1125
|
-
ttl=task_dict["ttl"],
|
|
1126
|
-
ancestry=task_dict["ancestry"].split(","),
|
|
1127
|
-
task_type=task_dict["task_type"],
|
|
1128
|
-
recordset=recordset,
|
|
1129
|
-
),
|
|
1046
|
+
def dict_to_message(message_dict: dict[str, Any]) -> Message:
|
|
1047
|
+
"""Transform dict to Message."""
|
|
1048
|
+
content, error = None, None
|
|
1049
|
+
if (b_content := message_dict.pop("content")) is not None:
|
|
1050
|
+
content = recordset_from_proto(ProtoRecordSet.FromString(b_content))
|
|
1051
|
+
if (b_error := message_dict.pop("error")) is not None:
|
|
1052
|
+
error = error_from_proto(ProtoError.FromString(b_error))
|
|
1053
|
+
|
|
1054
|
+
# Metadata constructor doesn't allow passing created_at. We set it later
|
|
1055
|
+
metadata = Metadata(
|
|
1056
|
+
**{
|
|
1057
|
+
k: v
|
|
1058
|
+
for k, v in message_dict.items()
|
|
1059
|
+
if k not in ["created_at", "delivered_at"]
|
|
1060
|
+
}
|
|
1130
1061
|
)
|
|
1131
|
-
|
|
1062
|
+
msg = Message(metadata=metadata, content=content, error=error)
|
|
1063
|
+
msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
|
|
1064
|
+
msg.metadata.delivered_at = message_dict["delivered_at"]
|
|
1065
|
+
return msg
|
|
1132
1066
|
|
|
1133
1067
|
|
|
1134
1068
|
def determine_run_status(row: dict[str, Any]) -> str:
|