flwr-nightly 1.16.0.dev20250305__py3-none-any.whl → 1.16.0.dev20250307__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/rest_client/connection.py +4 -6
- flwr/common/message.py +7 -7
- flwr/common/record/recordset.py +4 -12
- flwr/common/serde.py +8 -126
- flwr/server/compat/driver_client_proxy.py +2 -2
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +1 -1
- flwr/server/driver/inmemory_driver.py +17 -20
- flwr/server/superlink/driver/serverappio_servicer.py +18 -23
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +32 -35
- flwr/server/superlink/linkstate/in_memory_linkstate.py +1 -221
- flwr/server/superlink/linkstate/linkstate.py +0 -113
- flwr/server/superlink/linkstate/sqlite_linkstate.py +2 -511
- flwr/server/superlink/linkstate/utils.py +2 -179
- flwr/server/utils/__init__.py +0 -2
- flwr/server/utils/validator.py +0 -88
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
- flwr/superexec/exec_servicer.py +3 -3
- {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/METADATA +1 -1
- {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/RECORD +27 -32
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.16.0.dev20250305.dist-info → flwr_nightly-1.16.0.dev20250307.dist-info}/entry_points.txt +0 -0
@@ -29,6 +29,7 @@ from typing import Callable, Optional
|
|
29
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
30
30
|
from flwr.client.clientapp.utils import get_load_client_app_fn
|
31
31
|
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
32
|
+
from flwr.common import Message
|
32
33
|
from flwr.common.constant import (
|
33
34
|
NUM_PARTITIONS_KEY,
|
34
35
|
PARTITION_ID_KEY,
|
@@ -37,9 +38,7 @@ from flwr.common.constant import (
|
|
37
38
|
)
|
38
39
|
from flwr.common.logger import log
|
39
40
|
from flwr.common.message import Error
|
40
|
-
from flwr.common.serde import message_from_taskins, message_to_taskres
|
41
41
|
from flwr.common.typing import Run
|
42
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
43
42
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
44
43
|
|
45
44
|
from .backend import Backend, error_messages_backends, supported_backends
|
@@ -87,33 +86,33 @@ def _register_node_info_stores(
|
|
87
86
|
|
88
87
|
# pylint: disable=too-many-arguments,too-many-locals
|
89
88
|
def worker(
|
90
|
-
|
91
|
-
|
89
|
+
messageins_queue: Queue[Message],
|
90
|
+
messageres_queue: Queue[Message],
|
92
91
|
node_info_store: dict[int, DeprecatedRunInfoStore],
|
93
92
|
backend: Backend,
|
94
93
|
f_stop: threading.Event,
|
95
94
|
) -> None:
|
96
|
-
"""
|
95
|
+
"""Process messages from the queue, execute them, update context, and enqueue
|
96
|
+
replies."""
|
97
97
|
while not f_stop.is_set():
|
98
98
|
out_mssg = None
|
99
99
|
try:
|
100
100
|
# Fetch from queue with timeout. We use a timeout so
|
101
101
|
# the stopping event can be evaluated even when the queue is empty.
|
102
|
-
|
103
|
-
node_id =
|
102
|
+
message: Message = messageins_queue.get(timeout=1.0)
|
103
|
+
node_id = message.metadata.dst_node_id
|
104
104
|
|
105
105
|
# Retrieve context
|
106
|
-
context = node_info_store[node_id].retrieve_context(
|
107
|
-
|
108
|
-
|
109
|
-
message = message_from_taskins(task_ins)
|
106
|
+
context = node_info_store[node_id].retrieve_context(
|
107
|
+
run_id=message.metadata.run_id
|
108
|
+
)
|
110
109
|
|
111
110
|
# Let backend process message
|
112
111
|
out_mssg, updated_context = backend.process_message(message, context)
|
113
112
|
|
114
113
|
# Update Context
|
115
114
|
node_info_store[node_id].update_context(
|
116
|
-
|
115
|
+
message.metadata.run_id, context=updated_context
|
117
116
|
)
|
118
117
|
except Empty:
|
119
118
|
# An exception raised if queue.get times out
|
@@ -137,35 +136,33 @@ def worker(
|
|
137
136
|
|
138
137
|
finally:
|
139
138
|
if out_mssg:
|
140
|
-
#
|
141
|
-
|
142
|
-
# Store TaskRes in state
|
143
|
-
taskres_queue.put(task_res)
|
139
|
+
# Store reply Messages in state
|
140
|
+
messageres_queue.put(out_mssg)
|
144
141
|
|
145
142
|
|
146
|
-
def
|
143
|
+
def add_messages_to_queue(
|
147
144
|
state: LinkState,
|
148
|
-
queue:
|
145
|
+
queue: Queue[Message],
|
149
146
|
nodes_mapping: NodeToPartitionMapping,
|
150
147
|
f_stop: threading.Event,
|
151
148
|
) -> None:
|
152
|
-
"""Put
|
149
|
+
"""Put Messages in the queue from the LinkState."""
|
153
150
|
while not f_stop.is_set():
|
154
151
|
for node_id in nodes_mapping.keys():
|
155
|
-
|
156
|
-
for
|
157
|
-
queue.put(
|
152
|
+
message_ins_list = state.get_message_ins(node_id=node_id, limit=1)
|
153
|
+
for msg in message_ins_list:
|
154
|
+
queue.put(msg)
|
158
155
|
sleep(0.1)
|
159
156
|
|
160
157
|
|
161
|
-
def
|
162
|
-
state: LinkState, queue:
|
158
|
+
def put_message_into_state(
|
159
|
+
state: LinkState, queue: Queue[Message], f_stop: threading.Event
|
163
160
|
) -> None:
|
164
|
-
"""
|
161
|
+
"""Store reply Messages into the LinkState from the queue."""
|
165
162
|
while not f_stop.is_set():
|
166
163
|
try:
|
167
|
-
|
168
|
-
state.
|
164
|
+
message_reply = queue.get(timeout=1.0)
|
165
|
+
state.store_message_res(message_reply)
|
169
166
|
except Empty:
|
170
167
|
# queue is empty when timeout was triggered
|
171
168
|
pass
|
@@ -181,8 +178,8 @@ def run_api(
|
|
181
178
|
f_stop: threading.Event,
|
182
179
|
) -> None:
|
183
180
|
"""Run the VCE."""
|
184
|
-
|
185
|
-
|
181
|
+
messageins_queue: Queue[Message] = Queue()
|
182
|
+
messageres_queue: Queue[Message] = Queue()
|
186
183
|
|
187
184
|
try:
|
188
185
|
|
@@ -196,10 +193,10 @@ def run_api(
|
|
196
193
|
state = state_factory.state()
|
197
194
|
|
198
195
|
extractor_th = threading.Thread(
|
199
|
-
target=
|
196
|
+
target=add_messages_to_queue,
|
200
197
|
args=(
|
201
198
|
state,
|
202
|
-
|
199
|
+
messageins_queue,
|
203
200
|
nodes_mapping,
|
204
201
|
f_stop,
|
205
202
|
),
|
@@ -207,10 +204,10 @@ def run_api(
|
|
207
204
|
extractor_th.start()
|
208
205
|
|
209
206
|
injector_th = threading.Thread(
|
210
|
-
target=
|
207
|
+
target=put_message_into_state,
|
211
208
|
args=(
|
212
209
|
state,
|
213
|
-
|
210
|
+
messageres_queue,
|
214
211
|
f_stop,
|
215
212
|
),
|
216
213
|
)
|
@@ -220,8 +217,8 @@ def run_api(
|
|
220
217
|
_ = [
|
221
218
|
executor.submit(
|
222
219
|
worker,
|
223
|
-
|
224
|
-
|
220
|
+
messageins_queue,
|
221
|
+
messageres_queue,
|
225
222
|
node_info_stores,
|
226
223
|
backend,
|
227
224
|
f_stop,
|
@@ -33,18 +33,15 @@ from flwr.common.constant import (
|
|
33
33
|
)
|
34
34
|
from flwr.common.record import ConfigsRecord
|
35
35
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
36
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
37
36
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
38
|
-
from flwr.server.utils import validate_message
|
37
|
+
from flwr.server.utils import validate_message
|
39
38
|
|
40
39
|
from .utils import (
|
41
40
|
generate_rand_int_from_bytes,
|
42
41
|
has_valid_sub_status,
|
43
42
|
is_valid_transition,
|
44
43
|
verify_found_message_replies,
|
45
|
-
verify_found_taskres,
|
46
44
|
verify_message_ids,
|
47
|
-
verify_taskins_ids,
|
48
45
|
)
|
49
46
|
|
50
47
|
|
@@ -71,9 +68,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
71
68
|
self.run_ids: dict[int, RunRecord] = {}
|
72
69
|
self.contexts: dict[int, Context] = {}
|
73
70
|
self.federation_options: dict[int, ConfigsRecord] = {}
|
74
|
-
self.task_ins_store: dict[UUID, TaskIns] = {}
|
75
|
-
self.task_res_store: dict[UUID, TaskRes] = {}
|
76
|
-
self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
|
77
71
|
self.message_ins_store: dict[UUID, Message] = {}
|
78
72
|
self.message_res_store: dict[UUID, Message] = {}
|
79
73
|
self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
|
@@ -82,45 +76,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
82
76
|
|
83
77
|
self.lock = threading.RLock()
|
84
78
|
|
85
|
-
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
86
|
-
"""Store one TaskIns."""
|
87
|
-
# Validate task
|
88
|
-
errors = validate_task_ins_or_res(task_ins)
|
89
|
-
if any(errors):
|
90
|
-
log(ERROR, errors)
|
91
|
-
return None
|
92
|
-
# Validate run_id
|
93
|
-
if task_ins.run_id not in self.run_ids:
|
94
|
-
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
95
|
-
return None
|
96
|
-
# Validate source node ID
|
97
|
-
if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
|
98
|
-
log(
|
99
|
-
ERROR,
|
100
|
-
"Invalid source node ID for TaskIns: %s",
|
101
|
-
task_ins.task.producer.node_id,
|
102
|
-
)
|
103
|
-
return None
|
104
|
-
# Validate destination node ID
|
105
|
-
if task_ins.task.consumer.node_id not in self.node_ids:
|
106
|
-
log(
|
107
|
-
ERROR,
|
108
|
-
"Invalid destination node ID for TaskIns: %s",
|
109
|
-
task_ins.task.consumer.node_id,
|
110
|
-
)
|
111
|
-
return None
|
112
|
-
|
113
|
-
# Create task_id
|
114
|
-
task_id = uuid4()
|
115
|
-
|
116
|
-
# Store TaskIns
|
117
|
-
task_ins.task_id = str(task_id)
|
118
|
-
with self.lock:
|
119
|
-
self.task_ins_store[task_id] = task_ins
|
120
|
-
|
121
|
-
# Return the new task_id
|
122
|
-
return task_id
|
123
|
-
|
124
79
|
def store_message_ins(self, message: Message) -> Optional[UUID]:
|
125
80
|
"""Store one Message."""
|
126
81
|
# Validate message
|
@@ -161,33 +116,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
161
116
|
# Return the new message_id
|
162
117
|
return message_id
|
163
118
|
|
164
|
-
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
165
|
-
"""Get all TaskIns that have not been delivered yet."""
|
166
|
-
if limit is not None and limit < 1:
|
167
|
-
raise AssertionError("`limit` must be >= 1")
|
168
|
-
|
169
|
-
# Find TaskIns for node_id that were not delivered yet
|
170
|
-
task_ins_list: list[TaskIns] = []
|
171
|
-
current_time = time.time()
|
172
|
-
with self.lock:
|
173
|
-
for _, task_ins in self.task_ins_store.items():
|
174
|
-
if (
|
175
|
-
task_ins.task.consumer.node_id == node_id
|
176
|
-
and task_ins.task.delivered_at == ""
|
177
|
-
and task_ins.task.created_at + task_ins.task.ttl > current_time
|
178
|
-
):
|
179
|
-
task_ins_list.append(task_ins)
|
180
|
-
if limit and len(task_ins_list) == limit:
|
181
|
-
break
|
182
|
-
|
183
|
-
# Mark all of them as delivered
|
184
|
-
delivered_at = now().isoformat()
|
185
|
-
for task_ins in task_ins_list:
|
186
|
-
task_ins.task.delivered_at = delivered_at
|
187
|
-
|
188
|
-
# Return TaskIns
|
189
|
-
return task_ins_list
|
190
|
-
|
191
119
|
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
|
192
120
|
"""Get all Messages that have not been delivered yet."""
|
193
121
|
if limit is not None and limit < 1:
|
@@ -216,78 +144,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
216
144
|
# Return list of messages
|
217
145
|
return message_ins_list
|
218
146
|
|
219
|
-
# pylint: disable=R0911
|
220
|
-
def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
|
221
|
-
"""Store one TaskRes."""
|
222
|
-
# Validate task
|
223
|
-
errors = validate_task_ins_or_res(task_res)
|
224
|
-
if any(errors):
|
225
|
-
log(ERROR, errors)
|
226
|
-
return None
|
227
|
-
|
228
|
-
with self.lock:
|
229
|
-
# Check if the TaskIns it is replying to exists and is valid
|
230
|
-
task_ins_id = task_res.task.ancestry[0]
|
231
|
-
task_ins = self.task_ins_store.get(UUID(task_ins_id))
|
232
|
-
|
233
|
-
# Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
|
234
|
-
if (
|
235
|
-
task_ins
|
236
|
-
and task_res
|
237
|
-
and task_ins.task.consumer.node_id != task_res.task.producer.node_id
|
238
|
-
):
|
239
|
-
return None
|
240
|
-
|
241
|
-
if task_ins is None:
|
242
|
-
log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id)
|
243
|
-
return None
|
244
|
-
|
245
|
-
if task_ins.task.created_at + task_ins.task.ttl <= time.time():
|
246
|
-
log(
|
247
|
-
ERROR,
|
248
|
-
"Failed to store TaskRes: TaskIns with task_id %s has expired.",
|
249
|
-
task_ins_id,
|
250
|
-
)
|
251
|
-
return None
|
252
|
-
|
253
|
-
# Fail if the TaskRes TTL exceeds the
|
254
|
-
# expiration time of the TaskIns it replies to.
|
255
|
-
# Condition: TaskIns.created_at + TaskIns.ttl ≥
|
256
|
-
# TaskRes.created_at + TaskRes.ttl
|
257
|
-
# A small tolerance is introduced to account
|
258
|
-
# for floating-point precision issues.
|
259
|
-
max_allowed_ttl = (
|
260
|
-
task_ins.task.created_at + task_ins.task.ttl - task_res.task.created_at
|
261
|
-
)
|
262
|
-
if task_res.task.ttl and (
|
263
|
-
task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
|
264
|
-
):
|
265
|
-
log(
|
266
|
-
WARNING,
|
267
|
-
"Received TaskRes with TTL %.2f "
|
268
|
-
"exceeding the allowed maximum TTL %.2f.",
|
269
|
-
task_res.task.ttl,
|
270
|
-
max_allowed_ttl,
|
271
|
-
)
|
272
|
-
return None
|
273
|
-
|
274
|
-
# Validate run_id
|
275
|
-
if task_res.run_id not in self.run_ids:
|
276
|
-
log(ERROR, "`run_id` is invalid")
|
277
|
-
return None
|
278
|
-
|
279
|
-
# Create task_id
|
280
|
-
task_id = uuid4()
|
281
|
-
|
282
|
-
# Store TaskRes
|
283
|
-
task_res.task_id = str(task_id)
|
284
|
-
with self.lock:
|
285
|
-
self.task_res_store[task_id] = task_res
|
286
|
-
self.task_ins_id_to_task_res_id[UUID(task_ins_id)] = task_id
|
287
|
-
|
288
|
-
# Return the new task_id
|
289
|
-
return task_id
|
290
|
-
|
291
147
|
# pylint: disable=R0911
|
292
148
|
def store_message_res(self, message: Message) -> Optional[UUID]:
|
293
149
|
"""Store one Message."""
|
@@ -369,43 +225,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
369
225
|
# Return the new message_id
|
370
226
|
return message_id
|
371
227
|
|
372
|
-
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
373
|
-
"""Get TaskRes for the given TaskIns IDs."""
|
374
|
-
ret: dict[UUID, TaskRes] = {}
|
375
|
-
|
376
|
-
with self.lock:
|
377
|
-
current = time.time()
|
378
|
-
|
379
|
-
# Verify TaskIns IDs
|
380
|
-
ret = verify_taskins_ids(
|
381
|
-
inquired_taskins_ids=task_ids,
|
382
|
-
found_taskins_dict=self.task_ins_store,
|
383
|
-
current_time=current,
|
384
|
-
)
|
385
|
-
|
386
|
-
# Find all TaskRes
|
387
|
-
task_res_found: list[TaskRes] = []
|
388
|
-
for task_id in task_ids:
|
389
|
-
# If TaskRes exists and is not delivered, add it to the list
|
390
|
-
if task_res_id := self.task_ins_id_to_task_res_id.get(task_id):
|
391
|
-
task_res = self.task_res_store[task_res_id]
|
392
|
-
if task_res.task.delivered_at == "":
|
393
|
-
task_res_found.append(task_res)
|
394
|
-
tmp_ret_dict = verify_found_taskres(
|
395
|
-
inquired_taskins_ids=task_ids,
|
396
|
-
found_taskins_dict=self.task_ins_store,
|
397
|
-
found_taskres_list=task_res_found,
|
398
|
-
current_time=current,
|
399
|
-
)
|
400
|
-
ret.update(tmp_ret_dict)
|
401
|
-
|
402
|
-
# Mark existing TaskRes to be returned as delivered
|
403
|
-
delivered_at = now().isoformat()
|
404
|
-
for task_res in task_res_found:
|
405
|
-
task_res.task.delivered_at = delivered_at
|
406
|
-
|
407
|
-
return list(ret.values())
|
408
|
-
|
409
228
|
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
|
410
229
|
"""Get reply Messages for the given Message IDs."""
|
411
230
|
ret: dict[UUID, Message] = {}
|
@@ -445,21 +264,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
445
264
|
|
446
265
|
return list(ret.values())
|
447
266
|
|
448
|
-
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
449
|
-
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
450
|
-
if not task_ins_ids:
|
451
|
-
return
|
452
|
-
|
453
|
-
with self.lock:
|
454
|
-
for task_id in task_ins_ids:
|
455
|
-
# Delete TaskIns
|
456
|
-
if task_id in self.task_ins_store:
|
457
|
-
del self.task_ins_store[task_id]
|
458
|
-
# Delete TaskRes
|
459
|
-
if task_id in self.task_ins_id_to_task_res_id:
|
460
|
-
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
461
|
-
del self.task_res_store[task_res_id]
|
462
|
-
|
463
267
|
def delete_messages(self, message_ins_ids: set[UUID]) -> None:
|
464
268
|
"""Delete a Message and its reply based on provided Message IDs."""
|
465
269
|
if not message_ins_ids:
|
@@ -477,16 +281,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
477
281
|
)
|
478
282
|
del self.message_res_store[message_res_id]
|
479
283
|
|
480
|
-
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
481
|
-
"""Get all TaskIns IDs for the given run_id."""
|
482
|
-
task_id_list: set[UUID] = set()
|
483
|
-
with self.lock:
|
484
|
-
for task_id, task_ins in self.task_ins_store.items():
|
485
|
-
if task_ins.run_id == run_id:
|
486
|
-
task_id_list.add(task_id)
|
487
|
-
|
488
|
-
return task_id_list
|
489
|
-
|
490
284
|
def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
491
285
|
"""Get all instruction Message IDs for the given run_id."""
|
492
286
|
message_id_list: set[UUID] = set()
|
@@ -497,13 +291,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
497
291
|
|
498
292
|
return message_id_list
|
499
293
|
|
500
|
-
def num_task_ins(self) -> int:
|
501
|
-
"""Calculate the number of task_ins in store.
|
502
|
-
|
503
|
-
This includes delivered but not yet deleted task_ins.
|
504
|
-
"""
|
505
|
-
return len(self.task_ins_store)
|
506
|
-
|
507
294
|
def num_message_ins(self) -> int:
|
508
295
|
"""Calculate the number of instruction Messages in store.
|
509
296
|
|
@@ -511,13 +298,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
511
298
|
"""
|
512
299
|
return len(self.message_ins_store)
|
513
300
|
|
514
|
-
def num_task_res(self) -> int:
|
515
|
-
"""Calculate the number of task_res in store.
|
516
|
-
|
517
|
-
This includes delivered but not yet deleted task_res.
|
518
|
-
"""
|
519
|
-
return len(self.task_res_store)
|
520
|
-
|
521
301
|
def num_message_res(self) -> int:
|
522
302
|
"""Calculate the number of reply Messages in store.
|
523
303
|
|
@@ -22,30 +22,11 @@ from uuid import UUID
|
|
22
22
|
from flwr.common import Context, Message
|
23
23
|
from flwr.common.record import ConfigsRecord
|
24
24
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
25
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
26
25
|
|
27
26
|
|
28
27
|
class LinkState(abc.ABC): # pylint: disable=R0904
|
29
28
|
"""Abstract LinkState."""
|
30
29
|
|
31
|
-
@abc.abstractmethod
|
32
|
-
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
33
|
-
"""Store one TaskIns.
|
34
|
-
|
35
|
-
Usually, the ServerAppIo API calls this to schedule instructions.
|
36
|
-
|
37
|
-
Stores the value of the `task_ins` in the link state and, if successful,
|
38
|
-
returns the `task_id` (UUID) of the `task_ins`. If, for any reason,
|
39
|
-
storing the `task_ins` fails, `None` is returned.
|
40
|
-
|
41
|
-
Constraints
|
42
|
-
-----------
|
43
|
-
`task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
44
|
-
|
45
|
-
If `task_ins.run_id` is invalid, then
|
46
|
-
storing the `task_ins` MUST fail.
|
47
|
-
"""
|
48
|
-
|
49
30
|
@abc.abstractmethod
|
50
31
|
def store_message_ins(self, message: Message) -> Optional[UUID]:
|
51
32
|
"""Store one Message.
|
@@ -64,28 +45,6 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
64
45
|
storing the `message` MUST fail.
|
65
46
|
"""
|
66
47
|
|
67
|
-
@abc.abstractmethod
|
68
|
-
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
69
|
-
"""Get TaskIns optionally filtered by node_id.
|
70
|
-
|
71
|
-
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
72
|
-
TaskIns.
|
73
|
-
|
74
|
-
Constraints
|
75
|
-
-----------
|
76
|
-
Retrieve all TaskIns where
|
77
|
-
|
78
|
-
1. the `task_ins.task.consumer.node_id` equals `node_id` AND
|
79
|
-
2. the `task_ins.task.delivered_at` equals `""`.
|
80
|
-
|
81
|
-
|
82
|
-
If `delivered_at` MUST BE set (not `""`) otherwise the TaskIns MUST not be in
|
83
|
-
the result.
|
84
|
-
|
85
|
-
If `limit` is not `None`, return, at most, `limit` number of `task_ins`. If
|
86
|
-
`limit` is set, it has to be greater zero.
|
87
|
-
"""
|
88
|
-
|
89
48
|
@abc.abstractmethod
|
90
49
|
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
|
91
50
|
"""Get zero or more `Message` objects for the provided `node_id`.
|
@@ -101,24 +60,6 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
101
60
|
`limit` is set, it has to be greater zero.
|
102
61
|
"""
|
103
62
|
|
104
|
-
@abc.abstractmethod
|
105
|
-
def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
|
106
|
-
"""Store one TaskRes.
|
107
|
-
|
108
|
-
Usually, the Fleet API calls this for Nodes returning results.
|
109
|
-
|
110
|
-
Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
|
111
|
-
the `task_res`. If storing the `task_res` fails, `None` is returned.
|
112
|
-
|
113
|
-
Constraints
|
114
|
-
-----------
|
115
|
-
|
116
|
-
`task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
117
|
-
|
118
|
-
If `task_res.run_id` is invalid, then
|
119
|
-
storing the `task_res` MUST fail.
|
120
|
-
"""
|
121
|
-
|
122
63
|
@abc.abstractmethod
|
123
64
|
def store_message_res(self, message: Message) -> Optional[UUID]:
|
124
65
|
"""Store one Message.
|
@@ -136,31 +77,6 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
136
77
|
storing the `message` MUST fail.
|
137
78
|
"""
|
138
79
|
|
139
|
-
@abc.abstractmethod
|
140
|
-
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
141
|
-
"""Get TaskRes for the given TaskIns IDs.
|
142
|
-
|
143
|
-
This method is typically called by the ServerAppIo API to obtain
|
144
|
-
results (TaskRes) for previously scheduled instructions (TaskIns).
|
145
|
-
For each task_id provided, this method returns one of the following responses:
|
146
|
-
|
147
|
-
- An error TaskRes if the corresponding TaskIns does not exist or has expired.
|
148
|
-
- An error TaskRes if the corresponding TaskRes exists but has expired.
|
149
|
-
- The valid TaskRes if the TaskIns has a corresponding valid TaskRes.
|
150
|
-
- Nothing if the TaskIns is still valid and waiting for a TaskRes.
|
151
|
-
|
152
|
-
Parameters
|
153
|
-
----------
|
154
|
-
task_ids : set[UUID]
|
155
|
-
A set of TaskIns IDs for which to retrieve results (TaskRes).
|
156
|
-
|
157
|
-
Returns
|
158
|
-
-------
|
159
|
-
list[TaskRes]
|
160
|
-
A list of TaskRes corresponding to the given task IDs. If no
|
161
|
-
TaskRes could be found for any of the task IDs, an empty list is returned.
|
162
|
-
"""
|
163
|
-
|
164
80
|
@abc.abstractmethod
|
165
81
|
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
|
166
82
|
"""Get reply Messages for the given Message IDs.
|
@@ -188,39 +104,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
188
104
|
carrying an Error.
|
189
105
|
"""
|
190
106
|
|
191
|
-
@abc.abstractmethod
|
192
|
-
def num_task_ins(self) -> int:
|
193
|
-
"""Calculate the number of task_ins in store.
|
194
|
-
|
195
|
-
This includes delivered but not yet deleted task_ins.
|
196
|
-
"""
|
197
|
-
|
198
107
|
@abc.abstractmethod
|
199
108
|
def num_message_ins(self) -> int:
|
200
109
|
"""Calculate the number of Messages awaiting a reply."""
|
201
110
|
|
202
|
-
@abc.abstractmethod
|
203
|
-
def num_task_res(self) -> int:
|
204
|
-
"""Calculate the number of task_res in store.
|
205
|
-
|
206
|
-
This includes delivered but not yet deleted task_res.
|
207
|
-
"""
|
208
|
-
|
209
111
|
@abc.abstractmethod
|
210
112
|
def num_message_res(self) -> int:
|
211
113
|
"""Calculate the number of reply Messages in store."""
|
212
114
|
|
213
|
-
@abc.abstractmethod
|
214
|
-
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
215
|
-
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
|
216
|
-
|
217
|
-
Parameters
|
218
|
-
----------
|
219
|
-
task_ins_ids : set[UUID]
|
220
|
-
A set of TaskIns IDs. For each ID in the set, the corresponding
|
221
|
-
TaskIns and its associated TaskRes will be deleted.
|
222
|
-
"""
|
223
|
-
|
224
115
|
@abc.abstractmethod
|
225
116
|
def delete_messages(self, message_ins_ids: set[UUID]) -> None:
|
226
117
|
"""Delete a Message and its reply based on provided Message IDs.
|
@@ -232,10 +123,6 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
232
123
|
Message and its associated reply Message will be deleted.
|
233
124
|
"""
|
234
125
|
|
235
|
-
@abc.abstractmethod
|
236
|
-
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
237
|
-
"""Get all TaskIns IDs for the given run_id."""
|
238
|
-
|
239
126
|
@abc.abstractmethod
|
240
127
|
def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
241
128
|
"""Get all instruction Message IDs for the given run_id."""
|