flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
@@ -23,26 +23,27 @@ from logging import ERROR, WARNING
23
23
  from typing import Optional
24
24
  from uuid import UUID, uuid4
25
25
 
26
- from flwr.common import Context, log, now
26
+ from flwr.common import Context, Message, log, now
27
27
  from flwr.common.constant import (
28
28
  MESSAGE_TTL_TOLERANCE,
29
29
  NODE_ID_NUM_BYTES,
30
+ PING_PATIENCE,
30
31
  RUN_ID_NUM_BYTES,
31
32
  SUPERLINK_NODE_ID,
32
33
  Status,
33
34
  )
34
- from flwr.common.record import ConfigsRecord
35
+ from flwr.common.record import ConfigRecord
35
36
  from flwr.common.typing import Run, RunStatus, UserConfig
36
- from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
37
37
  from flwr.server.superlink.linkstate.linkstate import LinkState
38
- from flwr.server.utils import validate_task_ins_or_res
38
+ from flwr.server.utils import validate_message
39
39
 
40
40
  from .utils import (
41
+ check_node_availability_for_in_message,
41
42
  generate_rand_int_from_bytes,
42
43
  has_valid_sub_status,
43
44
  is_valid_transition,
44
- verify_found_taskres,
45
- verify_taskins_ids,
45
+ verify_found_message_replies,
46
+ verify_message_ids,
46
47
  )
47
48
 
48
49
 
@@ -68,228 +69,258 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
68
69
  # Map run_id to RunRecord
69
70
  self.run_ids: dict[int, RunRecord] = {}
70
71
  self.contexts: dict[int, Context] = {}
71
- self.federation_options: dict[int, ConfigsRecord] = {}
72
- self.task_ins_store: dict[UUID, TaskIns] = {}
73
- self.task_res_store: dict[UUID, TaskRes] = {}
74
- self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
72
+ self.federation_options: dict[int, ConfigRecord] = {}
73
+ self.message_ins_store: dict[UUID, Message] = {}
74
+ self.message_res_store: dict[UUID, Message] = {}
75
+ self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
75
76
 
76
77
  self.node_public_keys: set[bytes] = set()
77
78
 
78
79
  self.lock = threading.RLock()
79
80
 
80
- def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
81
- """Store one TaskIns."""
82
- # Validate task
83
- errors = validate_task_ins_or_res(task_ins)
81
+ def store_message_ins(self, message: Message) -> Optional[UUID]:
82
+ """Store one Message."""
83
+ # Validate message
84
+ errors = validate_message(message, is_reply_message=False)
84
85
  if any(errors):
85
86
  log(ERROR, errors)
86
87
  return None
87
88
  # Validate run_id
88
- if task_ins.run_id not in self.run_ids:
89
- log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
89
+ if message.metadata.run_id not in self.run_ids:
90
+ log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
90
91
  return None
91
92
  # Validate source node ID
92
- if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
93
+ if message.metadata.src_node_id != SUPERLINK_NODE_ID:
93
94
  log(
94
95
  ERROR,
95
- "Invalid source node ID for TaskIns: %s",
96
- task_ins.task.producer.node_id,
96
+ "Invalid source node ID for Message: %s",
97
+ message.metadata.src_node_id,
97
98
  )
98
99
  return None
99
100
  # Validate destination node ID
100
- if task_ins.task.consumer.node_id not in self.node_ids:
101
+ if message.metadata.dst_node_id not in self.node_ids:
101
102
  log(
102
103
  ERROR,
103
- "Invalid destination node ID for TaskIns: %s",
104
- task_ins.task.consumer.node_id,
104
+ "Invalid destination node ID for Message: %s",
105
+ message.metadata.dst_node_id,
105
106
  )
106
107
  return None
107
108
 
108
- # Create task_id
109
- task_id = uuid4()
109
+ # Create message_id
110
+ message_id = uuid4()
110
111
 
111
- # Store TaskIns
112
- task_ins.task_id = str(task_id)
112
+ # Store Message
113
+ # pylint: disable-next=W0212
114
+ message.metadata._message_id = str(message_id) # type: ignore
113
115
  with self.lock:
114
- self.task_ins_store[task_id] = task_ins
116
+ self.message_ins_store[message_id] = message
115
117
 
116
- # Return the new task_id
117
- return task_id
118
+ # Return the new message_id
119
+ return message_id
118
120
 
119
- def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
120
- """Get all TaskIns that have not been delivered yet."""
121
+ def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
122
+ """Get all Messages that have not been delivered yet."""
121
123
  if limit is not None and limit < 1:
122
124
  raise AssertionError("`limit` must be >= 1")
123
125
 
124
- # Find TaskIns for node_id that were not delivered yet
125
- task_ins_list: list[TaskIns] = []
126
+ # Find Message for node_id that were not delivered yet
127
+ message_ins_list: list[Message] = []
126
128
  current_time = time.time()
127
129
  with self.lock:
128
- for _, task_ins in self.task_ins_store.items():
130
+ for _, msg_ins in self.message_ins_store.items():
129
131
  if (
130
- task_ins.task.consumer.node_id == node_id
131
- and task_ins.task.delivered_at == ""
132
- and task_ins.task.created_at + task_ins.task.ttl > current_time
132
+ msg_ins.metadata.dst_node_id == node_id
133
+ and msg_ins.metadata.delivered_at == ""
134
+ and msg_ins.metadata.created_at + msg_ins.metadata.ttl
135
+ > current_time
133
136
  ):
134
- task_ins_list.append(task_ins)
135
- if limit and len(task_ins_list) == limit:
137
+ message_ins_list.append(msg_ins)
138
+ if limit and len(message_ins_list) == limit:
136
139
  break
137
140
 
138
141
  # Mark all of them as delivered
139
142
  delivered_at = now().isoformat()
140
- for task_ins in task_ins_list:
141
- task_ins.task.delivered_at = delivered_at
143
+ for msg_ins in message_ins_list:
144
+ msg_ins.metadata.delivered_at = delivered_at
142
145
 
143
- # Return TaskIns
144
- return task_ins_list
146
+ # Return list of messages
147
+ return message_ins_list
145
148
 
146
149
  # pylint: disable=R0911
147
- def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
148
- """Store one TaskRes."""
149
- # Validate task
150
- errors = validate_task_ins_or_res(task_res)
150
+ def store_message_res(self, message: Message) -> Optional[UUID]:
151
+ """Store one Message."""
152
+ # Validate message
153
+ errors = validate_message(message, is_reply_message=True)
151
154
  if any(errors):
152
155
  log(ERROR, errors)
153
156
  return None
154
157
 
158
+ res_metadata = message.metadata
155
159
  with self.lock:
156
- # Check if the TaskIns it is replying to exists and is valid
157
- task_ins_id = task_res.task.ancestry[0]
158
- task_ins = self.task_ins_store.get(UUID(task_ins_id))
160
+ # Check if the Message it is replying to exists and is valid
161
+ msg_ins_id = res_metadata.reply_to_message_id
162
+ msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
159
163
 
160
- # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
164
+ # Ensure that dst_node_id of original Message matches the src_node_id of
165
+ # reply Message.
161
166
  if (
162
- task_ins
163
- and task_res
164
- and task_ins.task.consumer.node_id != task_res.task.producer.node_id
167
+ msg_ins
168
+ and message
169
+ and msg_ins.metadata.dst_node_id != res_metadata.src_node_id
165
170
  ):
166
171
  return None
167
172
 
168
- if task_ins is None:
169
- log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id)
173
+ if msg_ins is None:
174
+ log(
175
+ ERROR,
176
+ "Message with ID %s does not exist.",
177
+ msg_ins_id,
178
+ )
170
179
  return None
171
180
 
172
- if task_ins.task.created_at + task_ins.task.ttl <= time.time():
181
+ ins_metadata = msg_ins.metadata
182
+ if ins_metadata.created_at + ins_metadata.ttl <= time.time():
173
183
  log(
174
184
  ERROR,
175
- "Failed to store TaskRes: TaskIns with task_id %s has expired.",
176
- task_ins_id,
185
+ "Failed to store Message: the message it is replying to "
186
+ "(with ID %s) has expired",
187
+ msg_ins_id,
177
188
  )
178
189
  return None
179
190
 
180
- # Fail if the TaskRes TTL exceeds the
181
- # expiration time of the TaskIns it replies to.
182
- # Condition: TaskIns.created_at + TaskIns.ttl ≥
183
- # TaskRes.created_at + TaskRes.ttl
191
+ # Fail if the Message TTL exceeds the
192
+ # expiration time of the Message it replies to.
193
+ # Condition: ins_metadata.created_at + ins_metadata.ttl ≥
194
+ # res_metadata.created_at + res_metadata.ttl
184
195
  # A small tolerance is introduced to account
185
196
  # for floating-point precision issues.
186
197
  max_allowed_ttl = (
187
- task_ins.task.created_at + task_ins.task.ttl - task_res.task.created_at
198
+ ins_metadata.created_at + ins_metadata.ttl - res_metadata.created_at
188
199
  )
189
- if task_res.task.ttl and (
190
- task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
200
+ if res_metadata.ttl and (
201
+ res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
191
202
  ):
192
203
  log(
193
204
  WARNING,
194
- "Received TaskRes with TTL %.2f "
195
- "exceeding the allowed maximum TTL %.2f.",
196
- task_res.task.ttl,
205
+ "Received Message with TTL %.2f exceeding the allowed maximum "
206
+ "TTL %.2f.",
207
+ res_metadata.ttl,
197
208
  max_allowed_ttl,
198
209
  )
199
210
  return None
200
211
 
201
212
  # Validate run_id
202
- if task_res.run_id not in self.run_ids:
203
- log(ERROR, "`run_id` is invalid")
213
+ if res_metadata.run_id != ins_metadata.run_id:
214
+ log(ERROR, "`metadata.run_id` is invalid")
204
215
  return None
205
216
 
206
- # Create task_id
207
- task_id = uuid4()
217
+ # Create message_id
218
+ message_id = uuid4()
208
219
 
209
- # Store TaskRes
210
- task_res.task_id = str(task_id)
220
+ # Store Message
221
+ # pylint: disable-next=W0212
222
+ message.metadata._message_id = str(message_id) # type: ignore
211
223
  with self.lock:
212
- self.task_res_store[task_id] = task_res
213
- self.task_ins_id_to_task_res_id[UUID(task_ins_id)] = task_id
224
+ self.message_res_store[message_id] = message
225
+ self.message_ins_id_to_message_res_id[UUID(msg_ins_id)] = message_id
214
226
 
215
- # Return the new task_id
216
- return task_id
227
+ # Return the new message_id
228
+ return message_id
217
229
 
218
- def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
219
- """Get TaskRes for the given TaskIns IDs."""
220
- ret: dict[UUID, TaskRes] = {}
230
+ def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
231
+ """Get reply Messages for the given Message IDs."""
232
+ ret: dict[UUID, Message] = {}
221
233
 
222
234
  with self.lock:
223
235
  current = time.time()
224
236
 
225
- # Verify TaskIns IDs
226
- ret = verify_taskins_ids(
227
- inquired_taskins_ids=task_ids,
228
- found_taskins_dict=self.task_ins_store,
237
+ # Verify Message IDs
238
+ ret = verify_message_ids(
239
+ inquired_message_ids=message_ids,
240
+ found_message_ins_dict=self.message_ins_store,
241
+ current_time=current,
242
+ )
243
+
244
+ # Check node availability
245
+ dst_node_ids = {
246
+ self.message_ins_store[message_id].metadata.dst_node_id
247
+ for message_id in message_ids
248
+ }
249
+ tmp_ret_dict = check_node_availability_for_in_message(
250
+ inquired_in_message_ids=message_ids,
251
+ found_in_message_dict=self.message_ins_store,
252
+ node_id_to_online_until={
253
+ node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
254
+ },
229
255
  current_time=current,
230
256
  )
257
+ ret.update(tmp_ret_dict)
231
258
 
232
- # Find all TaskRes
233
- task_res_found: list[TaskRes] = []
234
- for task_id in task_ids:
235
- # If TaskRes exists and is not delivered, add it to the list
236
- if task_res_id := self.task_ins_id_to_task_res_id.get(task_id):
237
- task_res = self.task_res_store[task_res_id]
238
- if task_res.task.delivered_at == "":
239
- task_res_found.append(task_res)
240
- tmp_ret_dict = verify_found_taskres(
241
- inquired_taskins_ids=task_ids,
242
- found_taskins_dict=self.task_ins_store,
243
- found_taskres_list=task_res_found,
259
+ # Find all reply Messages
260
+ message_res_found: list[Message] = []
261
+ for message_id in message_ids:
262
+ # If Message exists and is not delivered, add it to the list
263
+ if message_res_id := self.message_ins_id_to_message_res_id.get(
264
+ message_id
265
+ ):
266
+ message_res = self.message_res_store[message_res_id]
267
+ if message_res.metadata.delivered_at == "":
268
+ message_res_found.append(message_res)
269
+ tmp_ret_dict = verify_found_message_replies(
270
+ inquired_message_ids=message_ids,
271
+ found_message_ins_dict=self.message_ins_store,
272
+ found_message_res_list=message_res_found,
244
273
  current_time=current,
245
274
  )
246
275
  ret.update(tmp_ret_dict)
247
276
 
248
- # Mark existing TaskRes to be returned as delivered
277
+ # Mark existing reply Messages to be returned as delivered
249
278
  delivered_at = now().isoformat()
250
- for task_res in task_res_found:
251
- task_res.task.delivered_at = delivered_at
279
+ for message_res in message_res_found:
280
+ message_res.metadata.delivered_at = delivered_at
252
281
 
253
282
  return list(ret.values())
254
283
 
255
- def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
256
- """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
257
- if not task_ins_ids:
284
+ def delete_messages(self, message_ins_ids: set[UUID]) -> None:
285
+ """Delete a Message and its reply based on provided Message IDs."""
286
+ if not message_ins_ids:
258
287
  return
259
288
 
260
289
  with self.lock:
261
- for task_id in task_ins_ids:
262
- # Delete TaskIns
263
- if task_id in self.task_ins_store:
264
- del self.task_ins_store[task_id]
265
- # Delete TaskRes
266
- if task_id in self.task_ins_id_to_task_res_id:
267
- task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
268
- del self.task_res_store[task_res_id]
269
-
270
- def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
271
- """Get all TaskIns IDs for the given run_id."""
272
- task_id_list: set[UUID] = set()
290
+ for message_id in message_ins_ids:
291
+ # Delete Messages
292
+ if message_id in self.message_ins_store:
293
+ del self.message_ins_store[message_id]
294
+ # Delete Message replies
295
+ if message_id in self.message_ins_id_to_message_res_id:
296
+ message_res_id = self.message_ins_id_to_message_res_id.pop(
297
+ message_id
298
+ )
299
+ del self.message_res_store[message_res_id]
300
+
301
+ def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
302
+ """Get all instruction Message IDs for the given run_id."""
303
+ message_id_list: set[UUID] = set()
273
304
  with self.lock:
274
- for task_id, task_ins in self.task_ins_store.items():
275
- if task_ins.run_id == run_id:
276
- task_id_list.add(task_id)
305
+ for message_id, message in self.message_ins_store.items():
306
+ if message.metadata.run_id == run_id:
307
+ message_id_list.add(message_id)
277
308
 
278
- return task_id_list
309
+ return message_id_list
279
310
 
280
- def num_task_ins(self) -> int:
281
- """Calculate the number of task_ins in store.
311
+ def num_message_ins(self) -> int:
312
+ """Calculate the number of instruction Messages in store.
282
313
 
283
- This includes delivered but not yet deleted task_ins.
314
+ This includes delivered but not yet deleted.
284
315
  """
285
- return len(self.task_ins_store)
316
+ return len(self.message_ins_store)
286
317
 
287
- def num_task_res(self) -> int:
288
- """Calculate the number of task_res in store.
318
+ def num_message_res(self) -> int:
319
+ """Calculate the number of reply Messages in store.
289
320
 
290
- This includes delivered but not yet deleted task_res.
321
+ This includes delivered but not yet deleted.
291
322
  """
292
- return len(self.task_res_store)
323
+ return len(self.message_res_store)
293
324
 
294
325
  def create_node(self, ping_interval: float) -> int:
295
326
  """Create, store in the link state, and return `node_id`."""
@@ -303,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
303
334
  log(ERROR, "Unexpected node registration failure.")
304
335
  return 0
305
336
 
337
+ # Mark the node online util time.time() + ping_interval
306
338
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
307
339
  return node_id
308
340
 
@@ -367,7 +399,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
367
399
  fab_version: Optional[str],
368
400
  fab_hash: Optional[str],
369
401
  override_config: UserConfig,
370
- federation_options: ConfigsRecord,
402
+ federation_options: ConfigRecord,
371
403
  ) -> int:
372
404
  """Create a new run for the specified `fab_hash`."""
373
405
  # Sample a random int64 as run_id
@@ -496,7 +528,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
496
528
 
497
529
  return pending_run_id
498
530
 
499
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
531
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
500
532
  """Retrieve the federation options for the specified `run_id`."""
501
533
  with self.lock:
502
534
  if run_id not in self.run_ids:
@@ -505,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
505
537
  return self.federation_options[run_id]
506
538
 
507
539
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
508
- """Acknowledge a ping received from a node, serving as a heartbeat."""
540
+ """Acknowledge a ping received from a node, serving as a heartbeat.
541
+
542
+ It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
543
+ marking the node as offline, where PING_PATIENCE = 2 in default.
544
+ """
509
545
  with self.lock:
510
546
  if node_id in self.node_ids:
511
- self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
547
+ self.node_ids[node_id] = (
548
+ time.time() + PING_PATIENCE * ping_interval,
549
+ ping_interval,
550
+ )
512
551
  return True
513
552
  return False
514
553