flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
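The renames in the list above (configsrecord.py → configrecord.py, metricsrecord.py → metricrecord.py, recordset.py and parametersrecord.py removed, recorddict.py and arrayrecord.py added) suggest that the public record types were renamed between 1.15.2 and 1.17.0. A minimal sketch of what that would look like in user code, assuming the new classes are re-exported from flwr.common the way the old ones were:

```python
# Hypothetical before/after implied by the file moves above (not an official
# migration guide):
#
#   1.15.x                 1.17.x
#   RecordSet         ->   RecordDict
#   ConfigsRecord     ->   ConfigRecord
#   MetricsRecord     ->   MetricRecord
#   ParametersRecord  ->   ArrayRecord   (arrayrecord.py is new in this diff)
from flwr.common import ConfigRecord, MetricRecord, RecordDict

records = RecordDict()
records["train-config"] = ConfigRecord({"lr": 0.01, "local-epochs": 2})
records["train-metrics"] = MetricRecord({"loss": 0.37})
```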
flwr/server/superlink/linkstate/sqlite_linkstate.py
(Removed 1.15.2 lines below appear only as far as the source view captured them; several were truncated to nothing.)

@@ -26,29 +26,37 @@ from logging import DEBUG, ERROR, WARNING
 from typing import Any, Optional, Union, cast
 from uuid import UUID, uuid4

-from flwr.common import Context, log, now
+from flwr.common import Context, Message, Metadata, log, now
 from flwr.common.constant import (
     MESSAGE_TTL_TOLERANCE,
     NODE_ID_NUM_BYTES,
+    PING_PATIENCE,
     RUN_ID_NUM_BYTES,
     SUPERLINK_NODE_ID,
     Status,
 )
-from flwr.common.
+from flwr.common.message import make_message
+from flwr.common.record import ConfigRecord
+from flwr.common.serde import (
+    error_from_proto,
+    error_to_proto,
+    recorddict_from_proto,
+    recorddict_to_proto,
+)
 from flwr.common.typing import Run, RunStatus, UserConfig

 # pylint: disable=E0611
-from flwr.proto.
-from flwr.proto.
-from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
+from flwr.proto.error_pb2 import Error as ProtoError
+from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict

 # pylint: enable=E0611
-from flwr.server.utils.validator import
+from flwr.server.utils.validator import validate_message

 from .linkstate import LinkState
 from .utils import (
-
-
+    check_node_availability_for_in_message,
+    configrecord_from_bytes,
+    configrecord_to_bytes,
     context_from_bytes,
     context_to_bytes,
     convert_sint64_to_uint64,
@@ -58,8 +66,8 @@ from .utils import (
     generate_rand_int_from_bytes,
     has_valid_sub_status,
     is_valid_transition,
-
-
+    verify_found_message_replies,
+    verify_message_ids,
 )

 SQL_CREATE_TABLE_NODE = """
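The import changes above show the storage layer dropping the protobuf TaskIns/TaskRes pair in favour of flwr.common.Message, whose payload is either a RecordDict (content) or an Error. A rough field mapping, inferred from the schema and helper changes further down; Message.has_content() appears verbatim in this diff, the rest is illustrative:

```python
# Inferred mapping (illustrative):
#   TaskIns/TaskRes (1.15)        Message.metadata (1.17)
#   task.producer.node_id   ->    src_node_id
#   task.consumer.node_id   ->    dst_node_id
#   task.ancestry           ->    reply_to_message_id
#   task.task_type          ->    message_type
#   task.recordset          ->    message.content (RecordDict) or message.error (Error)
from flwr.common import Message


def payload_kind(message: Message) -> str:
    """Return which of the two mutually exclusive payloads a Message carries."""
    return "content" if message.has_content() else "error"
```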
@@ -117,36 +125,39 @@ CREATE TABLE IF NOT EXISTS context(
 );
 """

-
-CREATE TABLE IF NOT EXISTS
-
+SQL_CREATE_TABLE_MESSAGE_INS = """
+CREATE TABLE IF NOT EXISTS message_ins(
+    message_id TEXT UNIQUE,
     group_id TEXT,
     run_id INTEGER,
-
-
+    src_node_id INTEGER,
+    dst_node_id INTEGER,
+    reply_to_message_id TEXT,
     created_at REAL,
     delivered_at TEXT,
     ttl REAL,
-
-
-
+    message_type TEXT,
+    content BLOB NULL,
+    error BLOB NULL,
     FOREIGN KEY(run_id) REFERENCES run(run_id)
 );
 """

-
-
-
+
+SQL_CREATE_TABLE_MESSAGE_RES = """
+CREATE TABLE IF NOT EXISTS message_res(
+    message_id TEXT UNIQUE,
     group_id TEXT,
     run_id INTEGER,
-
-
+    src_node_id INTEGER,
+    dst_node_id INTEGER,
+    reply_to_message_id TEXT,
     created_at REAL,
     delivered_at TEXT,
     ttl REAL,
-
-
-
+    message_type TEXT,
+    content BLOB NULL,
+    error BLOB NULL,
     FOREIGN KEY(run_id) REFERENCES run(run_id)
 );
 """
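To see the new schema in isolation, the DDL above can be exercised in a throwaway in-memory SQLite database. The columns are copied from the hunk; the FOREIGN KEY on run_id is dropped because the run table is not part of this sketch, and the SELECT mirrors the undelivered-message query used later in get_message_ins():

```python
import sqlite3
import time
import uuid

# message_ins DDL from the hunk above, minus the FOREIGN KEY on run_id
DDL = """
CREATE TABLE IF NOT EXISTS message_ins(
    message_id TEXT UNIQUE,
    group_id TEXT,
    run_id INTEGER,
    src_node_id INTEGER,
    dst_node_id INTEGER,
    reply_to_message_id TEXT,
    created_at REAL,
    delivered_at TEXT,
    ttl REAL,
    message_type TEXT,
    content BLOB NULL,
    error BLOB NULL
);
"""

conn = sqlite3.connect(":memory:")
conn.execute(DDL)
conn.execute(
    "INSERT INTO message_ins VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
    (str(uuid.uuid4()), "", 42, 1, 7, "", time.time(), "", 43200.0, "train", None, None),
)

# Undelivered, unexpired messages for node 7 (same predicate as get_message_ins)
rows = conn.execute(
    "SELECT message_id FROM message_ins "
    "WHERE dst_node_id == ? AND delivered_at = '' "
    "AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)",
    (7,),
).fetchall()
print(rows)  # one row: the message just inserted
```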
@@ -196,8 +207,8 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         cur.execute(SQL_CREATE_TABLE_RUN)
         cur.execute(SQL_CREATE_TABLE_LOGS)
         cur.execute(SQL_CREATE_TABLE_CONTEXT)
-        cur.execute(
-        cur.execute(
+        cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
+        cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
         cur.execute(SQL_CREATE_TABLE_NODE)
         cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
         cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
@@ -239,88 +250,62 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         return result

-    def
-        """Store one
-
-
-
-        Stores the value of the task_ins in the link state and, if successful,
-        returns the task_id (UUID) of the task_ins. If, for any reason, storing
-        the task_ins fails, `None` is returned.
-
-        Constraints
-        -----------
-
-        `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
-        """
-        # Validate task
-        errors = validate_task_ins_or_res(task_ins)
+    def store_message_ins(self, message: Message) -> Optional[UUID]:
+        """Store one Message."""
+        # Validate message
+        errors = validate_message(message=message, is_reply_message=False)
         if any(errors):
             log(ERROR, errors)
             return None
-        # Create
-
+        # Create message_id
+        message_id = uuid4()

-        # Store
-
-
+        # Store Message
+        # pylint: disable-next=W0212
+        message.metadata._message_id = str(message_id)  # type: ignore
+        data = (message_to_dict(message),)

         # Convert values from uint64 to sint64 for SQLite
         convert_uint64_values_in_dict_to_sint64(
-            data[0], ["run_id", "
+            data[0], ["run_id", "src_node_id", "dst_node_id"]
         )

         # Validate run_id
         query = "SELECT run_id FROM run WHERE run_id = ?;"
         if not self.query(query, (data[0]["run_id"],)):
-            log(ERROR, "Invalid run ID for
+            log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
             return None
+
         # Validate source node ID
-        if
+        if message.metadata.src_node_id != SUPERLINK_NODE_ID:
             log(
                 ERROR,
-                "Invalid source node ID for
-
+                "Invalid source node ID for Message: %s",
+                message.metadata.src_node_id,
             )
             return None
+
         # Validate destination node ID
         query = "SELECT node_id FROM node WHERE node_id = ?;"
-        if not self.query(query, (data[0]["
+        if not self.query(query, (data[0]["dst_node_id"],)):
             log(
                 ERROR,
-                "Invalid destination node ID for
-
+                "Invalid destination node ID for Message: %s",
+                message.metadata.dst_node_id,
             )
             return None

         columns = ", ".join([f":{key}" for key in data[0]])
-        query = f"INSERT INTO
+        query = f"INSERT INTO message_ins VALUES({columns});"

         # Only invalid run_id can trigger IntegrityError.
         # This may need to be changed in the future version with more integrity checks.
         self.query(query, data)

-        return
-
-    def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
-        """Get undelivered TaskIns for one node.
-
-        Usually, the Fleet API calls this for Nodes planning to work on one or more
-        TaskIns.
+        return message_id

-
-
-        Retrieve all TaskIns where
-
-        1. the `task_ins.task.consumer.node_id` equals `node_id` AND
-        2. the `task_ins.task.delivered_at` equals `""`.
-
-        `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
-        the result.
-
-        If `limit` is not `None`, return, at most, `limit` number of `task_ins`. If
-        `limit` is set, it has to be greater than zero.
-        """
+    def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
+        """Get all Messages that have not been delivered yet."""
         if limit is not None and limit < 1:
             raise AssertionError("`limit` must be >= 1")

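The convert_uint64_*/convert_sint64_* calls above exist because SQLite stores INTEGER as a signed 64-bit value while Flower run and node IDs are unsigned 64-bit. A minimal re-implementation of that two's-complement mapping, for illustration only (the real helpers come from flwr.server.superlink.linkstate.utils, as the imports at the top of the file show):

```python
def convert_uint64_to_sint64(value: int) -> int:
    """Map a uint64 into SQLite's signed 64-bit range (two's complement)."""
    return value - 2**64 if value >= 2**63 else value


def convert_sint64_to_uint64(value: int) -> int:
    """Undo the mapping when reading back from SQLite."""
    return value + 2**64 if value < 0 else value


node_id = 0xFEED_FACE_CAFE_BEEF              # a uint64 that does not fit into sint64
stored = convert_uint64_to_sint64(node_id)   # negative, but safe to store
assert convert_sint64_to_uint64(stored) == node_id
```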
@@ -333,11 +318,11 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         # Convert the uint64 value to sint64 for SQLite
         data["node_id"] = convert_uint64_to_sint64(node_id)

-        # Retrieve all
+        # Retrieve all Messages for node_id
         query = """
-            SELECT
-            FROM
-            WHERE
+            SELECT message_id
+            FROM message_ins
+            WHERE dst_node_id == :node_id
             AND delivered_at = ""
             AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
         """
@@ -352,20 +337,20 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         if rows:
             # Prepare query
-
-            placeholders: str = ",".join([f":id_{i}" for i in range(len(
+            message_ids = [row["message_id"] for row in rows]
+            placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
             query = f"""
-                UPDATE
+                UPDATE message_ins
                 SET delivered_at = :delivered_at
-                WHERE
+                WHERE message_id IN ({placeholders})
                 RETURNING *;
             """

             # Prepare data for query
             delivered_at = now().isoformat()
             data = {"delivered_at": delivered_at}
-            for index,
-                data[f"id_{index}"] = str(
+            for index, msg_id in enumerate(message_ids):
+                data[f"id_{index}"] = str(msg_id)

             # Run query
             rows = self.query(query, data)
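Marking fetched messages as delivered relies on UPDATE ... RETURNING, which SQLite supports only from version 3.35. A standalone sketch of the same pattern on a toy table, using the named placeholders the code above builds (:delivered_at, :id_0, ...):

```python
import sqlite3
from datetime import datetime, timezone

assert sqlite3.sqlite_version_info >= (3, 35), "RETURNING needs SQLite >= 3.35"

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("CREATE TABLE message_ins(message_id TEXT, delivered_at TEXT)")
conn.execute("INSERT INTO message_ins VALUES ('id-0', '')")

rows = conn.execute(
    "UPDATE message_ins SET delivered_at = :delivered_at "
    "WHERE message_id IN (:id_0) RETURNING *;",
    {"delivered_at": datetime.now(timezone.utc).isoformat(), "id_0": "id-0"},
).fetchall()
print(rows[0]["message_id"], rows[0]["delivered_at"])
```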
@@ -373,86 +358,80 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         for row in rows:
             # Convert values from sint64 to uint64
             convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "
+                row, ["run_id", "src_node_id", "dst_node_id"]
             )

-        result = [
+        result = [dict_to_message(row) for row in rows]

         return result

-    def
-        """Store one
-
-
-
-        Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
-        the `task_res`. If storing the `task_res` fails, `None` is returned.
-
-        Constraints
-        -----------
-        `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
-        """
-        # Validate task
-        errors = validate_task_ins_or_res(task_res)
+    def store_message_res(self, message: Message) -> Optional[UUID]:
+        """Store one Message."""
+        # Validate message
+        errors = validate_message(message=message, is_reply_message=True)
         if any(errors):
             log(ERROR, errors)
             return None

-
-
-
-
-        task_ins = self.get_valid_task_ins(task_ins_id)
-        if task_ins is None:
+        res_metadata = message.metadata
+        msg_ins_id = res_metadata.reply_to_message_id
+        msg_ins = self.get_valid_message_ins(msg_ins_id)
+        if msg_ins is None:
             log(
                 ERROR,
-                "Failed to store
-                "
-
+                "Failed to store Message reply: "
+                "The message it replies to with message_id %s does not exist or "
+                "has expired.",
+                msg_ins_id,
             )
             return None

-        # Ensure that the
+        # Ensure that the dst_node_id of the original message matches the src_node_id of
+        # reply being processed.
         if (
-
-            and
-            and convert_sint64_to_uint64(
-            !=
+            msg_ins
+            and message
+            and convert_sint64_to_uint64(msg_ins["dst_node_id"])
+            != res_metadata.src_node_id
         ):
             return None

-        # Fail if the
-        # expiration time of the
-        # Condition:
-        #
+        # Fail if the Message TTL exceeds the
+        # expiration time of the Message it replies to.
+        # Condition: ins_metadata.created_at + ins_metadata.ttl ≥
+        #            res_metadata.created_at + res_metadata.ttl
         # A small tolerance is introduced to account
         # for floating-point precision issues.
         max_allowed_ttl = (
-
+            msg_ins["created_at"] + msg_ins["ttl"] - res_metadata.created_at
         )
-        if
-
+        if res_metadata.ttl and (
+            res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
         ):
             log(
                 WARNING,
-                "Received
-                "
-
+                "Received Message with TTL %.2f exceeding the allowed maximum "
+                "TTL %.2f.",
+                res_metadata.ttl,
                 max_allowed_ttl,
             )
             return None

-        #
-
-
+        # Create message_id
+        message_id = uuid4()
+
+        # Store Message
+        # pylint: disable-next=W0212
+        message.metadata._message_id = str(message_id)  # type: ignore
+        data = (message_to_dict(message),)

         # Convert values from uint64 to sint64 for SQLite
         convert_uint64_values_in_dict_to_sint64(
-            data[0], ["run_id", "
+            data[0], ["run_id", "src_node_id", "dst_node_id"]
         )

         columns = ", ".join([f":{key}" for key in data[0]])
-        query = f"INSERT INTO
+        query = f"INSERT INTO message_res VALUES({columns});"

         # Only invalid run_id can trigger IntegrityError.
         # This may need to be changed in the future version with more integrity checks.
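The TTL guard above enforces that a reply cannot stay valid past the expiry of the message it answers, give or take a small tolerance. With concrete numbers (the value of MESSAGE_TTL_TOLERANCE is defined in flwr.common.constant; the figure below is only an assumption for the example):

```python
MESSAGE_TTL_TOLERANCE = 1e-1  # assumed for the example; the real value lives in flwr.common.constant

ins_created_at, ins_ttl = 1_000.0, 60.0  # instruction expires at t = 1060
res_created_at, res_ttl = 1_020.0, 50.0  # reply would expire at t = 1070

max_allowed_ttl = ins_created_at + ins_ttl - res_created_at  # 40.0
if res_ttl and (res_ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE):
    # 50.0 - 40.0 = 10.0 > tolerance, so store_message_res() warns and returns None
    print(f"rejected: reply TTL {res_ttl:.2f} exceeds allowed maximum {max_allowed_ttl:.2f}")
```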
@@ -462,124 +441,149 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
             log(ERROR, "`run` is invalid")
             return None

-        return
+        return message_id

-
-
-
-        ret: dict[UUID,
+    def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
+        """Get reply Messages for the given Message IDs."""
+        # pylint: disable-msg=too-many-locals
+        ret: dict[UUID, Message] = {}

-        # Verify
+        # Verify Message IDs
         current = time.time()
         query = f"""
             SELECT *
-            FROM
-            WHERE
+            FROM message_ins
+            WHERE message_id IN ({",".join(["?"] * len(message_ids))});
         """
-        rows = self.query(query, tuple(str(
-
+        rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
+        found_message_ins_dict: dict[UUID, Message] = {}
         for row in rows:
             convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "
+                row, ["run_id", "src_node_id", "dst_node_id"]
             )
-
+            found_message_ins_dict[UUID(row["message_id"])] = dict_to_message(row)
+
+        ret = verify_message_ids(
+            inquired_message_ids=message_ids,
+            found_message_ins_dict=found_message_ins_dict,
+            current_time=current,
+        )

-
-
-
+        # Check node availability
+        dst_node_ids: set[int] = set()
+        for message_id in message_ids:
+            in_message = found_message_ins_dict[message_id]
+            sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
+            dst_node_ids.add(sint_node_id)
+        query = f"""
+            SELECT node_id, online_until
+            FROM node
+            WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
+        """
+        rows = self.query(query, tuple(dst_node_ids))
+        tmp_ret_dict = check_node_availability_for_in_message(
+            inquired_in_message_ids=message_ids,
+            found_in_message_dict=found_message_ins_dict,
+            node_id_to_online_until={
+                convert_sint64_to_uint64(row["node_id"]): row["online_until"]
+                for row in rows
+            },
             current_time=current,
         )
+        ret.update(tmp_ret_dict)

-        # Find all
+        # Find all reply Messages
         query = f"""
             SELECT *
-            FROM
-            WHERE
+            FROM message_res
+            WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
             AND delivered_at = "";
         """
-        rows = self.query(query, tuple(str(
+        rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
         for row in rows:
             convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "
+                row, ["run_id", "src_node_id", "dst_node_id"]
             )
-        tmp_ret_dict =
-
-
-
+        tmp_ret_dict = verify_found_message_replies(
+            inquired_message_ids=message_ids,
+            found_message_ins_dict=found_message_ins_dict,
+            found_message_res_list=[dict_to_message(row) for row in rows],
             current_time=current,
         )
         ret.update(tmp_ret_dict)

-        # Mark existing
+        # Mark existing reply Messages to be returned as delivered
         delivered_at = now().isoformat()
-        for
-
-
+        for message_res in ret.values():
+            message_res.metadata.delivered_at = delivered_at
+        message_res_ids = [
+            message_res.metadata.message_id for message_res in ret.values()
+        ]
         query = f"""
-            UPDATE
+            UPDATE message_res
             SET delivered_at = ?
-            WHERE
+            WHERE message_id IN ({",".join(["?"] * len(message_res_ids))});
         """
-        data: list[Any] = [delivered_at] +
+        data: list[Any] = [delivered_at] + message_res_ids
         self.query(query, data)

         return list(ret.values())

-    def
-        """Calculate the number of
+    def num_message_ins(self) -> int:
+        """Calculate the number of instruction Messages in store.

-        This includes delivered but not yet deleted
+        This includes delivered but not yet deleted.
         """
-        query = "SELECT count(*) AS num FROM
+        query = "SELECT count(*) AS num FROM message_ins;"
         rows = self.query(query)
         result = rows[0]
         num = cast(int, result["num"])
         return num

-    def
-        """Calculate the number of
+    def num_message_res(self) -> int:
+        """Calculate the number of reply Messages in store.

-        This includes delivered but not yet deleted
+        This includes delivered but not yet deleted.
         """
-        query = "SELECT count(*) AS num FROM
+        query = "SELECT count(*) AS num FROM message_res;"
         rows = self.query(query)
         result: dict[str, int] = rows[0]
         return result["num"]

-    def
-        """Delete
-        if not
+    def delete_messages(self, message_ins_ids: set[UUID]) -> None:
+        """Delete a Message and its reply based on provided Message IDs."""
+        if not message_ins_ids:
             return
         if self.conn is None:
             raise AttributeError("LinkState not initialized")

-        placeholders = ",".join(["?"] * len(
-        data = tuple(str(
+        placeholders = ",".join(["?"] * len(message_ins_ids))
+        data = tuple(str(message_id) for message_id in message_ins_ids)

-        # Delete
+        # Delete Message
         query_1 = f"""
-            DELETE FROM
-            WHERE
+            DELETE FROM message_ins
+            WHERE message_id IN ({placeholders});
         """

-        # Delete
+        # Delete reply Message
         query_2 = f"""
-            DELETE FROM
-            WHERE
+            DELETE FROM message_res
+            WHERE reply_to_message_id IN ({placeholders});
         """

         with self.conn:
             self.conn.execute(query_1, data)
             self.conn.execute(query_2, data)

-    def
-        """Get all
+    def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
+        """Get all instruction Message IDs for the given run_id."""
         if self.conn is None:
             raise AttributeError("LinkState not initialized")

         query = """
-            SELECT
-            FROM
+            SELECT message_id
+            FROM message_ins
             WHERE run_id = :run_id;
         """

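Note how delete_messages() above keeps instructions and replies in lockstep: the same set of IDs is matched against message_id in message_ins and against reply_to_message_id in message_res, and both DELETEs run inside one `with self.conn:` transaction. The pairing in isolation:

```python
# Same ID set drives both DELETEs (placeholder IDs, for illustration only).
message_ins_ids = {"id-1", "id-2"}
placeholders = ",".join(["?"] * len(message_ins_ids))
data = tuple(message_ins_ids)

query_ins = f"DELETE FROM message_ins WHERE message_id IN ({placeholders});"
query_res = f"DELETE FROM message_res WHERE reply_to_message_id IN ({placeholders});"
```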
@@ -589,7 +593,7 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         with self.conn:
             rows = self.conn.execute(query, data).fetchall()

-        return {UUID(row["
+        return {UUID(row["message_id"]) for row in rows}

     def create_node(self, ping_interval: float) -> int:
         """Create, store in the link state, and return `node_id`."""
@@ -607,6 +611,7 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
             "VALUES (?, ?, ?, ?)"
         )

+        # Mark the node online util time.time() + ping_interval
         try:
             self.query(
                 query,
@@ -722,7 +727,7 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         fab_version: Optional[str],
         fab_hash: Optional[str],
         override_config: UserConfig,
-        federation_options:
+        federation_options: ConfigRecord,
     ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
@@ -748,7 +753,7 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
             fab_version,
             fab_hash,
             override_config_json,
-
+            configrecord_to_bytes(federation_options),
         ]
         data += [
             now().isoformat(),
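Federation options are now typed as a ConfigRecord and persisted as bytes through the helpers imported at the top of the file. A sketch of that round trip; the helper names and module come from this diff, the example values are made up:

```python
from flwr.common import ConfigRecord
from flwr.server.superlink.linkstate.utils import (
    configrecord_from_bytes,
    configrecord_to_bytes,
)

options = ConfigRecord({"num-supernodes": 10})
blob = configrecord_to_bytes(options)      # what create_run() writes to the run table
restored = configrecord_from_bytes(blob)   # what get_federation_options() returns
assert restored["num-supernodes"] == 10
```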
@@ -906,7 +911,7 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         return pending_run_id

-    def get_federation_options(self, run_id: int) -> Optional[
+    def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
         """Retrieve the federation options for the specified `run_id`."""
         # Convert the uint64 value to sint64 for SQLite
         sint64_run_id = convert_uint64_to_sint64(run_id)
@@ -919,10 +924,14 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
             return None

         row = rows[0]
-        return
+        return configrecord_from_bytes(row["federation_options"])

     def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
-        """Acknowledge a ping received from a node, serving as a heartbeat.
+        """Acknowledge a ping received from a node, serving as a heartbeat.
+
+        It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
+        marking the node as offline, where PING_PATIENCE = 2 in default.
+        """
         sint64_node_id = convert_uint64_to_sint64(node_id)

         # Check if the node exists in the `node` table
@@ -932,7 +941,14 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         # Update `online_until` and `ping_interval` for the given `node_id`
         query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
-        self.query(
+        self.query(
+            query,
+            (
+                time.time() + PING_PATIENCE * ping_interval,
+                ping_interval,
+                sint64_node_id,
+            ),
+        )
         return True

     def get_serverapp_context(self, run_id: int) -> Optional[Context]:
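Concretely, acknowledge_ping() now extends a node's liveness window to PING_PATIENCE * ping_interval, so with the default PING_PATIENCE of 2 (per the docstring above) a node survives exactly one missed ping:

```python
import time

PING_PATIENCE = 2      # default, per the docstring in the hunk above
ping_interval = 30.0   # seconds, chosen for illustration

online_until = time.time() + PING_PATIENCE * ping_interval
# A node pinging every 30 s is only considered offline after ~60 s of silence,
# i.e. after one missed ping plus most of a second interval.
```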
@@ -1001,32 +1017,32 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
         return "".join(row["log"] for row in rows), latest_timestamp

-    def
-        """Check if the
+    def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
+        """Check if the Message exists and is valid (not expired).

-        Return
+        Return Message if valid.
         """
         query = """
             SELECT *
-            FROM
-            WHERE
+            FROM message_ins
+            WHERE message_id = :message_id
         """
-        data = {"
+        data = {"message_id": message_id}
         rows = self.query(query, data)
         if not rows:
-            #
+            # Message does not exist
            return None

-
-        created_at =
-        ttl =
+        message_ins = rows[0]
+        created_at = message_ins["created_at"]
+        ttl = message_ins["ttl"]
         current_time = time.time()

-        # Check if
+        # Check if Message is expired
         if ttl is not None and created_at + ttl <= current_time:
             return None

-        return
+        return message_ins


 def dict_factory(
@@ -1041,94 +1057,46 @@ def dict_factory(
     return dict(zip(fields, row))


-def
-    """Transform
-    result = {
-        "task_id": task_msg.task_id,
-        "group_id": task_msg.group_id,
-        "run_id": task_msg.run_id,
-        "producer_node_id": task_msg.task.producer.node_id,
-        "consumer_node_id": task_msg.task.consumer.node_id,
-        "created_at": task_msg.task.created_at,
-        "delivered_at": task_msg.task.delivered_at,
-        "ttl": task_msg.task.ttl,
-        "ancestry": ",".join(task_msg.task.ancestry),
-        "task_type": task_msg.task.task_type,
-        "recordset": task_msg.task.recordset.SerializeToString(),
-    }
-    return result
-
-
-def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
-    """Transform TaskRes to dict."""
+def message_to_dict(message: Message) -> dict[str, Any]:
+    """Transform Message to dict."""
     result = {
-        "
-        "group_id":
-        "run_id":
-        "
-        "
-        "
-        "
-        "
-        "
-        "
-        "
+        "message_id": message.metadata.message_id,
+        "group_id": message.metadata.group_id,
+        "run_id": message.metadata.run_id,
+        "src_node_id": message.metadata.src_node_id,
+        "dst_node_id": message.metadata.dst_node_id,
+        "reply_to_message_id": message.metadata.reply_to_message_id,
+        "created_at": message.metadata.created_at,
+        "delivered_at": message.metadata.delivered_at,
+        "ttl": message.metadata.ttl,
+        "message_type": message.metadata.message_type,
+        "content": None,
+        "error": None,
     }
-    return result
-

-
-
-
-
+    if message.has_content():
+        result["content"] = recorddict_to_proto(message.content).SerializeToString()
+    else:
+        result["error"] = error_to_proto(message.error).SerializeToString()

-    result = TaskIns(
-        task_id=task_dict["task_id"],
-        group_id=task_dict["group_id"],
-        run_id=task_dict["run_id"],
-        task=Task(
-            producer=Node(
-                node_id=task_dict["producer_node_id"],
-            ),
-            consumer=Node(
-                node_id=task_dict["consumer_node_id"],
-            ),
-            created_at=task_dict["created_at"],
-            delivered_at=task_dict["delivered_at"],
-            ttl=task_dict["ttl"],
-            ancestry=task_dict["ancestry"].split(","),
-            task_type=task_dict["task_type"],
-            recordset=recordset,
-        ),
-    )
     return result


-def
-    """
-
-
+def dict_to_message(message_dict: dict[str, Any]) -> Message:
+    """Transform dict to Message."""
+    content, error = None, None
+    if (b_content := message_dict.pop("content")) is not None:
+        content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
+    if (b_error := message_dict.pop("error")) is not None:
+        error = error_from_proto(ProtoError.FromString(b_error))

-
-
-
-        run_id=task_dict["run_id"],
-        task=Task(
-            producer=Node(
-                node_id=task_dict["producer_node_id"],
-            ),
-            consumer=Node(
-                node_id=task_dict["consumer_node_id"],
-            ),
-            created_at=task_dict["created_at"],
-            delivered_at=task_dict["delivered_at"],
-            ttl=task_dict["ttl"],
-            ancestry=task_dict["ancestry"].split(","),
-            task_type=task_dict["task_type"],
-            recordset=recordset,
-        ),
+    # Metadata constructor doesn't allow passing created_at. We set it later
+    metadata = Metadata(
+        **{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
     )
-
+    msg = make_message(metadata=metadata, content=content, error=error)
+    msg.metadata.delivered_at = message_dict["delivered_at"]
+    return msg


 def determine_run_status(row: dict[str, Any]) -> str: