flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
|
@@ -23,26 +23,27 @@ from logging import ERROR, WARNING
|
|
|
23
23
|
from typing import Optional
|
|
24
24
|
from uuid import UUID, uuid4
|
|
25
25
|
|
|
26
|
-
from flwr.common import Context, log, now
|
|
26
|
+
from flwr.common import Context, Message, log, now
|
|
27
27
|
from flwr.common.constant import (
|
|
28
28
|
MESSAGE_TTL_TOLERANCE,
|
|
29
29
|
NODE_ID_NUM_BYTES,
|
|
30
|
+
PING_PATIENCE,
|
|
30
31
|
RUN_ID_NUM_BYTES,
|
|
31
32
|
SUPERLINK_NODE_ID,
|
|
32
33
|
Status,
|
|
33
34
|
)
|
|
34
|
-
from flwr.common.record import
|
|
35
|
+
from flwr.common.record import ConfigRecord
|
|
35
36
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
36
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
37
37
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
38
|
-
from flwr.server.utils import
|
|
38
|
+
from flwr.server.utils import validate_message
|
|
39
39
|
|
|
40
40
|
from .utils import (
|
|
41
|
+
check_node_availability_for_in_message,
|
|
41
42
|
generate_rand_int_from_bytes,
|
|
42
43
|
has_valid_sub_status,
|
|
43
44
|
is_valid_transition,
|
|
44
|
-
|
|
45
|
-
|
|
45
|
+
verify_found_message_replies,
|
|
46
|
+
verify_message_ids,
|
|
46
47
|
)
|
|
47
48
|
|
|
48
49
|
|
|
@@ -68,228 +69,258 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
68
69
|
# Map run_id to RunRecord
|
|
69
70
|
self.run_ids: dict[int, RunRecord] = {}
|
|
70
71
|
self.contexts: dict[int, Context] = {}
|
|
71
|
-
self.federation_options: dict[int,
|
|
72
|
-
self.
|
|
73
|
-
self.
|
|
74
|
-
self.
|
|
72
|
+
self.federation_options: dict[int, ConfigRecord] = {}
|
|
73
|
+
self.message_ins_store: dict[UUID, Message] = {}
|
|
74
|
+
self.message_res_store: dict[UUID, Message] = {}
|
|
75
|
+
self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
|
|
75
76
|
|
|
76
77
|
self.node_public_keys: set[bytes] = set()
|
|
77
78
|
|
|
78
79
|
self.lock = threading.RLock()
|
|
79
80
|
|
|
80
|
-
def
|
|
81
|
-
"""Store one
|
|
82
|
-
# Validate
|
|
83
|
-
errors =
|
|
81
|
+
def store_message_ins(self, message: Message) -> Optional[UUID]:
|
|
82
|
+
"""Store one Message."""
|
|
83
|
+
# Validate message
|
|
84
|
+
errors = validate_message(message, is_reply_message=False)
|
|
84
85
|
if any(errors):
|
|
85
86
|
log(ERROR, errors)
|
|
86
87
|
return None
|
|
87
88
|
# Validate run_id
|
|
88
|
-
if
|
|
89
|
-
log(ERROR, "Invalid run ID for
|
|
89
|
+
if message.metadata.run_id not in self.run_ids:
|
|
90
|
+
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
90
91
|
return None
|
|
91
92
|
# Validate source node ID
|
|
92
|
-
if
|
|
93
|
+
if message.metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
93
94
|
log(
|
|
94
95
|
ERROR,
|
|
95
|
-
"Invalid source node ID for
|
|
96
|
-
|
|
96
|
+
"Invalid source node ID for Message: %s",
|
|
97
|
+
message.metadata.src_node_id,
|
|
97
98
|
)
|
|
98
99
|
return None
|
|
99
100
|
# Validate destination node ID
|
|
100
|
-
if
|
|
101
|
+
if message.metadata.dst_node_id not in self.node_ids:
|
|
101
102
|
log(
|
|
102
103
|
ERROR,
|
|
103
|
-
"Invalid destination node ID for
|
|
104
|
-
|
|
104
|
+
"Invalid destination node ID for Message: %s",
|
|
105
|
+
message.metadata.dst_node_id,
|
|
105
106
|
)
|
|
106
107
|
return None
|
|
107
108
|
|
|
108
|
-
# Create
|
|
109
|
-
|
|
109
|
+
# Create message_id
|
|
110
|
+
message_id = uuid4()
|
|
110
111
|
|
|
111
|
-
# Store
|
|
112
|
-
|
|
112
|
+
# Store Message
|
|
113
|
+
# pylint: disable-next=W0212
|
|
114
|
+
message.metadata._message_id = str(message_id) # type: ignore
|
|
113
115
|
with self.lock:
|
|
114
|
-
self.
|
|
116
|
+
self.message_ins_store[message_id] = message
|
|
115
117
|
|
|
116
|
-
# Return the new
|
|
117
|
-
return
|
|
118
|
+
# Return the new message_id
|
|
119
|
+
return message_id
|
|
118
120
|
|
|
119
|
-
def
|
|
120
|
-
"""Get all
|
|
121
|
+
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
|
|
122
|
+
"""Get all Messages that have not been delivered yet."""
|
|
121
123
|
if limit is not None and limit < 1:
|
|
122
124
|
raise AssertionError("`limit` must be >= 1")
|
|
123
125
|
|
|
124
|
-
# Find
|
|
125
|
-
|
|
126
|
+
# Find Message for node_id that were not delivered yet
|
|
127
|
+
message_ins_list: list[Message] = []
|
|
126
128
|
current_time = time.time()
|
|
127
129
|
with self.lock:
|
|
128
|
-
for _,
|
|
130
|
+
for _, msg_ins in self.message_ins_store.items():
|
|
129
131
|
if (
|
|
130
|
-
|
|
131
|
-
and
|
|
132
|
-
and
|
|
132
|
+
msg_ins.metadata.dst_node_id == node_id
|
|
133
|
+
and msg_ins.metadata.delivered_at == ""
|
|
134
|
+
and msg_ins.metadata.created_at + msg_ins.metadata.ttl
|
|
135
|
+
> current_time
|
|
133
136
|
):
|
|
134
|
-
|
|
135
|
-
if limit and len(
|
|
137
|
+
message_ins_list.append(msg_ins)
|
|
138
|
+
if limit and len(message_ins_list) == limit:
|
|
136
139
|
break
|
|
137
140
|
|
|
138
141
|
# Mark all of them as delivered
|
|
139
142
|
delivered_at = now().isoformat()
|
|
140
|
-
for
|
|
141
|
-
|
|
143
|
+
for msg_ins in message_ins_list:
|
|
144
|
+
msg_ins.metadata.delivered_at = delivered_at
|
|
142
145
|
|
|
143
|
-
# Return
|
|
144
|
-
return
|
|
146
|
+
# Return list of messages
|
|
147
|
+
return message_ins_list
|
|
145
148
|
|
|
146
149
|
# pylint: disable=R0911
|
|
147
|
-
def
|
|
148
|
-
"""Store one
|
|
149
|
-
# Validate
|
|
150
|
-
errors =
|
|
150
|
+
def store_message_res(self, message: Message) -> Optional[UUID]:
|
|
151
|
+
"""Store one Message."""
|
|
152
|
+
# Validate message
|
|
153
|
+
errors = validate_message(message, is_reply_message=True)
|
|
151
154
|
if any(errors):
|
|
152
155
|
log(ERROR, errors)
|
|
153
156
|
return None
|
|
154
157
|
|
|
158
|
+
res_metadata = message.metadata
|
|
155
159
|
with self.lock:
|
|
156
|
-
# Check if the
|
|
157
|
-
|
|
158
|
-
|
|
160
|
+
# Check if the Message it is replying to exists and is valid
|
|
161
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
|
162
|
+
msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
|
|
159
163
|
|
|
160
|
-
# Ensure that
|
|
164
|
+
# Ensure that dst_node_id of original Message matches the src_node_id of
|
|
165
|
+
# reply Message.
|
|
161
166
|
if (
|
|
162
|
-
|
|
163
|
-
and
|
|
164
|
-
and
|
|
167
|
+
msg_ins
|
|
168
|
+
and message
|
|
169
|
+
and msg_ins.metadata.dst_node_id != res_metadata.src_node_id
|
|
165
170
|
):
|
|
166
171
|
return None
|
|
167
172
|
|
|
168
|
-
if
|
|
169
|
-
log(
|
|
173
|
+
if msg_ins is None:
|
|
174
|
+
log(
|
|
175
|
+
ERROR,
|
|
176
|
+
"Message with ID %s does not exist.",
|
|
177
|
+
msg_ins_id,
|
|
178
|
+
)
|
|
170
179
|
return None
|
|
171
180
|
|
|
172
|
-
|
|
181
|
+
ins_metadata = msg_ins.metadata
|
|
182
|
+
if ins_metadata.created_at + ins_metadata.ttl <= time.time():
|
|
173
183
|
log(
|
|
174
184
|
ERROR,
|
|
175
|
-
"Failed to store
|
|
176
|
-
|
|
185
|
+
"Failed to store Message: the message it is replying to "
|
|
186
|
+
"(with ID %s) has expired",
|
|
187
|
+
msg_ins_id,
|
|
177
188
|
)
|
|
178
189
|
return None
|
|
179
190
|
|
|
180
|
-
# Fail if the
|
|
181
|
-
# expiration time of the
|
|
182
|
-
# Condition:
|
|
183
|
-
#
|
|
191
|
+
# Fail if the Message TTL exceeds the
|
|
192
|
+
# expiration time of the Message it replies to.
|
|
193
|
+
# Condition: ins_metadata.created_at + ins_metadata.ttl ≥
|
|
194
|
+
# res_metadata.created_at + res_metadata.ttl
|
|
184
195
|
# A small tolerance is introduced to account
|
|
185
196
|
# for floating-point precision issues.
|
|
186
197
|
max_allowed_ttl = (
|
|
187
|
-
|
|
198
|
+
ins_metadata.created_at + ins_metadata.ttl - res_metadata.created_at
|
|
188
199
|
)
|
|
189
|
-
if
|
|
190
|
-
|
|
200
|
+
if res_metadata.ttl and (
|
|
201
|
+
res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
|
|
191
202
|
):
|
|
192
203
|
log(
|
|
193
204
|
WARNING,
|
|
194
|
-
"Received
|
|
195
|
-
"
|
|
196
|
-
|
|
205
|
+
"Received Message with TTL %.2f exceeding the allowed maximum "
|
|
206
|
+
"TTL %.2f.",
|
|
207
|
+
res_metadata.ttl,
|
|
197
208
|
max_allowed_ttl,
|
|
198
209
|
)
|
|
199
210
|
return None
|
|
200
211
|
|
|
201
212
|
# Validate run_id
|
|
202
|
-
if
|
|
203
|
-
log(ERROR, "`run_id` is invalid")
|
|
213
|
+
if res_metadata.run_id != ins_metadata.run_id:
|
|
214
|
+
log(ERROR, "`metadata.run_id` is invalid")
|
|
204
215
|
return None
|
|
205
216
|
|
|
206
|
-
# Create
|
|
207
|
-
|
|
217
|
+
# Create message_id
|
|
218
|
+
message_id = uuid4()
|
|
208
219
|
|
|
209
|
-
# Store
|
|
210
|
-
|
|
220
|
+
# Store Message
|
|
221
|
+
# pylint: disable-next=W0212
|
|
222
|
+
message.metadata._message_id = str(message_id) # type: ignore
|
|
211
223
|
with self.lock:
|
|
212
|
-
self.
|
|
213
|
-
self.
|
|
224
|
+
self.message_res_store[message_id] = message
|
|
225
|
+
self.message_ins_id_to_message_res_id[UUID(msg_ins_id)] = message_id
|
|
214
226
|
|
|
215
|
-
# Return the new
|
|
216
|
-
return
|
|
227
|
+
# Return the new message_id
|
|
228
|
+
return message_id
|
|
217
229
|
|
|
218
|
-
def
|
|
219
|
-
"""Get
|
|
220
|
-
ret: dict[UUID,
|
|
230
|
+
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
|
|
231
|
+
"""Get reply Messages for the given Message IDs."""
|
|
232
|
+
ret: dict[UUID, Message] = {}
|
|
221
233
|
|
|
222
234
|
with self.lock:
|
|
223
235
|
current = time.time()
|
|
224
236
|
|
|
225
|
-
# Verify
|
|
226
|
-
ret =
|
|
227
|
-
|
|
228
|
-
|
|
237
|
+
# Verify Message IDs
|
|
238
|
+
ret = verify_message_ids(
|
|
239
|
+
inquired_message_ids=message_ids,
|
|
240
|
+
found_message_ins_dict=self.message_ins_store,
|
|
241
|
+
current_time=current,
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
# Check node availability
|
|
245
|
+
dst_node_ids = {
|
|
246
|
+
self.message_ins_store[message_id].metadata.dst_node_id
|
|
247
|
+
for message_id in message_ids
|
|
248
|
+
}
|
|
249
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
|
250
|
+
inquired_in_message_ids=message_ids,
|
|
251
|
+
found_in_message_dict=self.message_ins_store,
|
|
252
|
+
node_id_to_online_until={
|
|
253
|
+
node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
|
|
254
|
+
},
|
|
229
255
|
current_time=current,
|
|
230
256
|
)
|
|
257
|
+
ret.update(tmp_ret_dict)
|
|
231
258
|
|
|
232
|
-
# Find all
|
|
233
|
-
|
|
234
|
-
for
|
|
235
|
-
# If
|
|
236
|
-
if
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
259
|
+
# Find all reply Messages
|
|
260
|
+
message_res_found: list[Message] = []
|
|
261
|
+
for message_id in message_ids:
|
|
262
|
+
# If Message exists and is not delivered, add it to the list
|
|
263
|
+
if message_res_id := self.message_ins_id_to_message_res_id.get(
|
|
264
|
+
message_id
|
|
265
|
+
):
|
|
266
|
+
message_res = self.message_res_store[message_res_id]
|
|
267
|
+
if message_res.metadata.delivered_at == "":
|
|
268
|
+
message_res_found.append(message_res)
|
|
269
|
+
tmp_ret_dict = verify_found_message_replies(
|
|
270
|
+
inquired_message_ids=message_ids,
|
|
271
|
+
found_message_ins_dict=self.message_ins_store,
|
|
272
|
+
found_message_res_list=message_res_found,
|
|
244
273
|
current_time=current,
|
|
245
274
|
)
|
|
246
275
|
ret.update(tmp_ret_dict)
|
|
247
276
|
|
|
248
|
-
# Mark existing
|
|
277
|
+
# Mark existing reply Messages to be returned as delivered
|
|
249
278
|
delivered_at = now().isoformat()
|
|
250
|
-
for
|
|
251
|
-
|
|
279
|
+
for message_res in message_res_found:
|
|
280
|
+
message_res.metadata.delivered_at = delivered_at
|
|
252
281
|
|
|
253
282
|
return list(ret.values())
|
|
254
283
|
|
|
255
|
-
def
|
|
256
|
-
"""Delete
|
|
257
|
-
if not
|
|
284
|
+
def delete_messages(self, message_ins_ids: set[UUID]) -> None:
|
|
285
|
+
"""Delete a Message and its reply based on provided Message IDs."""
|
|
286
|
+
if not message_ins_ids:
|
|
258
287
|
return
|
|
259
288
|
|
|
260
289
|
with self.lock:
|
|
261
|
-
for
|
|
262
|
-
# Delete
|
|
263
|
-
if
|
|
264
|
-
del self.
|
|
265
|
-
# Delete
|
|
266
|
-
if
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
290
|
+
for message_id in message_ins_ids:
|
|
291
|
+
# Delete Messages
|
|
292
|
+
if message_id in self.message_ins_store:
|
|
293
|
+
del self.message_ins_store[message_id]
|
|
294
|
+
# Delete Message replies
|
|
295
|
+
if message_id in self.message_ins_id_to_message_res_id:
|
|
296
|
+
message_res_id = self.message_ins_id_to_message_res_id.pop(
|
|
297
|
+
message_id
|
|
298
|
+
)
|
|
299
|
+
del self.message_res_store[message_res_id]
|
|
300
|
+
|
|
301
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
302
|
+
"""Get all instruction Message IDs for the given run_id."""
|
|
303
|
+
message_id_list: set[UUID] = set()
|
|
273
304
|
with self.lock:
|
|
274
|
-
for
|
|
275
|
-
if
|
|
276
|
-
|
|
305
|
+
for message_id, message in self.message_ins_store.items():
|
|
306
|
+
if message.metadata.run_id == run_id:
|
|
307
|
+
message_id_list.add(message_id)
|
|
277
308
|
|
|
278
|
-
return
|
|
309
|
+
return message_id_list
|
|
279
310
|
|
|
280
|
-
def
|
|
281
|
-
"""Calculate the number of
|
|
311
|
+
def num_message_ins(self) -> int:
|
|
312
|
+
"""Calculate the number of instruction Messages in store.
|
|
282
313
|
|
|
283
|
-
This includes delivered but not yet deleted
|
|
314
|
+
This includes delivered but not yet deleted.
|
|
284
315
|
"""
|
|
285
|
-
return len(self.
|
|
316
|
+
return len(self.message_ins_store)
|
|
286
317
|
|
|
287
|
-
def
|
|
288
|
-
"""Calculate the number of
|
|
318
|
+
def num_message_res(self) -> int:
|
|
319
|
+
"""Calculate the number of reply Messages in store.
|
|
289
320
|
|
|
290
|
-
This includes delivered but not yet deleted
|
|
321
|
+
This includes delivered but not yet deleted.
|
|
291
322
|
"""
|
|
292
|
-
return len(self.
|
|
323
|
+
return len(self.message_res_store)
|
|
293
324
|
|
|
294
325
|
def create_node(self, ping_interval: float) -> int:
|
|
295
326
|
"""Create, store in the link state, and return `node_id`."""
|
|
@@ -303,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
303
334
|
log(ERROR, "Unexpected node registration failure.")
|
|
304
335
|
return 0
|
|
305
336
|
|
|
337
|
+
# Mark the node online util time.time() + ping_interval
|
|
306
338
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
307
339
|
return node_id
|
|
308
340
|
|
|
@@ -367,7 +399,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
367
399
|
fab_version: Optional[str],
|
|
368
400
|
fab_hash: Optional[str],
|
|
369
401
|
override_config: UserConfig,
|
|
370
|
-
federation_options:
|
|
402
|
+
federation_options: ConfigRecord,
|
|
371
403
|
) -> int:
|
|
372
404
|
"""Create a new run for the specified `fab_hash`."""
|
|
373
405
|
# Sample a random int64 as run_id
|
|
@@ -496,7 +528,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
496
528
|
|
|
497
529
|
return pending_run_id
|
|
498
530
|
|
|
499
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
|
531
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
|
500
532
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
501
533
|
with self.lock:
|
|
502
534
|
if run_id not in self.run_ids:
|
|
@@ -505,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
505
537
|
return self.federation_options[run_id]
|
|
506
538
|
|
|
507
539
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
508
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
|
540
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
|
541
|
+
|
|
542
|
+
It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
|
|
543
|
+
marking the node as offline, where PING_PATIENCE = 2 in default.
|
|
544
|
+
"""
|
|
509
545
|
with self.lock:
|
|
510
546
|
if node_id in self.node_ids:
|
|
511
|
-
self.node_ids[node_id] = (
|
|
547
|
+
self.node_ids[node_id] = (
|
|
548
|
+
time.time() + PING_PATIENCE * ping_interval,
|
|
549
|
+
ping_interval,
|
|
550
|
+
)
|
|
512
551
|
return True
|
|
513
552
|
return False
|
|
514
553
|
|