flwr-nightly 1.16.0.dev20250304__py3-none-any.whl → 1.16.0.dev20250306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flwr/server/compat/app.py CHANGED
@@ -79,10 +79,13 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
79
79
  log(INFO, "")
80
80
 
81
81
  # Start the thread updating nodes
82
- thread, f_stop = start_update_client_manager_thread(
82
+ thread, f_stop, c_done = start_update_client_manager_thread(
83
83
  driver, initialized_server.client_manager()
84
84
  )
85
85
 
86
+ # Wait until the node registration done
87
+ c_done.wait()
88
+
86
89
  # Start training
87
90
  hist = run_fl(
88
91
  server=initialized_server,
@@ -27,7 +27,7 @@ from ..driver import Driver
27
27
  def start_update_client_manager_thread(
28
28
  driver: Driver,
29
29
  client_manager: ClientManager,
30
- ) -> tuple[threading.Thread, threading.Event]:
30
+ ) -> tuple[threading.Thread, threading.Event, threading.Event]:
31
31
  """Periodically update the nodes list in the client manager in a thread.
32
32
 
33
33
  This function starts a thread that periodically uses the associated driver to
@@ -51,26 +51,31 @@ def start_update_client_manager_thread(
51
51
  A thread that updates the ClientManager and handles the stop event.
52
52
  threading.Event
53
53
  An event that, when set, signals the thread to stop.
54
+ threading.Event
55
+ An event that, when set, signals the node registration done.
54
56
  """
55
57
  f_stop = threading.Event()
58
+ c_done = threading.Event()
56
59
  thread = threading.Thread(
57
60
  target=_update_client_manager,
58
61
  args=(
59
62
  driver,
60
63
  client_manager,
61
64
  f_stop,
65
+ c_done,
62
66
  ),
63
67
  daemon=True,
64
68
  )
65
69
  thread.start()
66
70
 
67
- return thread, f_stop
71
+ return thread, f_stop, c_done
68
72
 
69
73
 
70
74
  def _update_client_manager(
71
75
  driver: Driver,
72
76
  client_manager: ClientManager,
73
77
  f_stop: threading.Event,
78
+ c_done: threading.Event,
74
79
  ) -> None:
75
80
  """Update the nodes list in the client manager."""
76
81
  # Loop until the driver is disconnected
@@ -102,6 +107,9 @@ def _update_client_manager(
102
107
  else:
103
108
  raise RuntimeError("Could not register node.")
104
109
 
110
+ # Flag first pass for nodes registration is completed
111
+ c_done.set()
112
+
105
113
  # Sleep for 3 seconds
106
114
  if not f_stop.is_set():
107
115
  f_stop.wait(3)
@@ -85,7 +85,7 @@ class Driver(ABC):
85
85
  """
86
86
 
87
87
@abstractmethod
def get_node_ids(self) -> Iterable[int]:
    """Return the IDs of the nodes, as an iterable of integers."""
90
90
 
91
91
  @abstractmethod
@@ -183,7 +183,7 @@ class GrpcDriver(Driver):
183
183
  )
184
184
  return Message(metadata=metadata, content=content)
185
185
 
186
- def get_node_ids(self) -> list[int]:
186
+ def get_node_ids(self) -> Iterable[int]:
187
187
  """Get node IDs."""
188
188
  # Call GrpcDriverStub method
189
189
  res: GetNodesResponse = self._stub.GetNodes(
@@ -109,9 +109,9 @@ class InMemoryDriver(Driver):
109
109
  )
110
110
  return Message(metadata=metadata, content=content)
111
111
 
112
def get_node_ids(self) -> Iterable[int]:
    """Return the node IDs the link state reports for this driver's run."""
    current_run = cast(Run, self._run)
    return self.state.get_nodes(current_run.run_id)
115
115
 
116
116
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
117
117
  """Push messages to specified node IDs.
@@ -23,7 +23,7 @@ from logging import ERROR, WARNING
23
23
  from typing import Optional
24
24
  from uuid import UUID, uuid4
25
25
 
26
- from flwr.common import Context, log, now
26
+ from flwr.common import Context, Message, log, now
27
27
  from flwr.common.constant import (
28
28
  MESSAGE_TTL_TOLERANCE,
29
29
  NODE_ID_NUM_BYTES,
@@ -35,13 +35,15 @@ from flwr.common.record import ConfigsRecord
35
35
  from flwr.common.typing import Run, RunStatus, UserConfig
36
36
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
37
37
  from flwr.server.superlink.linkstate.linkstate import LinkState
38
- from flwr.server.utils import validate_task_ins_or_res
38
+ from flwr.server.utils import validate_message, validate_task_ins_or_res
39
39
 
40
40
  from .utils import (
41
41
  generate_rand_int_from_bytes,
42
42
  has_valid_sub_status,
43
43
  is_valid_transition,
44
+ verify_found_message_replies,
44
45
  verify_found_taskres,
46
+ verify_message_ids,
45
47
  verify_taskins_ids,
46
48
  )
47
49
 
@@ -72,6 +74,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
72
74
  self.task_ins_store: dict[UUID, TaskIns] = {}
73
75
  self.task_res_store: dict[UUID, TaskRes] = {}
74
76
  self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
77
+ self.message_ins_store: dict[UUID, Message] = {}
78
+ self.message_res_store: dict[UUID, Message] = {}
79
+ self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
75
80
 
76
81
  self.node_public_keys: set[bytes] = set()
77
82
 
@@ -116,6 +121,46 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
116
121
  # Return the new task_id
117
122
  return task_id
118
123
 
124
def store_message_ins(self, message: Message) -> Optional[UUID]:
    """Validate one instruction Message and store it.

    Returns the freshly generated message ID on success. Returns ``None``
    (after logging an error) when ``validate_message`` reports errors, when the
    run ID is unknown, when the source is not the SuperLink, or when the
    destination node is not registered.
    """
    validation_errors = validate_message(message, is_reply_message=False)
    if any(validation_errors):
        log(ERROR, validation_errors)
        return None

    meta = message.metadata
    if meta.run_id not in self.run_ids:
        log(ERROR, "Invalid run ID for Message: %s", meta.run_id)
        return None
    if meta.src_node_id != SUPERLINK_NODE_ID:
        log(ERROR, "Invalid source node ID for Message: %s", meta.src_node_id)
        return None
    if meta.dst_node_id not in self.node_ids:
        log(
            ERROR,
            "Invalid destination node ID for Message: %s",
            meta.dst_node_id,
        )
        return None

    new_message_id = uuid4()
    # Stamp the generated ID onto the Message (private field write).
    # pylint: disable-next=W0212
    meta._message_id = str(new_message_id)  # type: ignore
    with self.lock:
        self.message_ins_store[new_message_id] = message

    return new_message_id
163
+
119
164
  def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
120
165
  """Get all TaskIns that have not been delivered yet."""
121
166
  if limit is not None and limit < 1:
@@ -143,6 +188,34 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
143
188
  # Return TaskIns
144
189
  return task_ins_list
145
190
 
191
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
    """Get all Messages that have not been delivered yet.

    Selects instruction Messages addressed to ``node_id`` that are not yet
    delivered and whose TTL has not elapsed. Marks them as delivered and
    returns them.

    Parameters
    ----------
    node_id : int
        Destination node whose pending Messages are fetched.
    limit : Optional[int]
        Maximum number of Messages to return; ``None`` means no limit.

    Returns
    -------
    list[Message]
        The selected Messages, each with ``metadata.delivered_at`` set.

    Raises
    ------
    AssertionError
        If ``limit`` is set and smaller than 1.
    """
    if limit is not None and limit < 1:
        raise AssertionError("`limit` must be >= 1")

    # Find Messages for node_id that were not delivered yet
    message_ins_list: list[Message] = []
    current_time = time.time()
    with self.lock:
        for msg_ins in self.message_ins_store.values():
            meta = msg_ins.metadata
            if (
                meta.dst_node_id == node_id
                and meta.delivered_at == ""
                and meta.created_at + meta.ttl > current_time
            ):
                message_ins_list.append(msg_ins)
                if limit and len(message_ins_list) == limit:
                    break

        # Mark them as delivered before releasing the lock. Selection and
        # marking must be atomic; otherwise a concurrent call for the same node
        # could select the same still-undelivered Messages and deliver them twice.
        delivered_at = now().isoformat()
        for msg_ins in message_ins_list:
            msg_ins.metadata.delivered_at = delivered_at

    # Return list of messages
    return message_ins_list
218
+
146
219
  # pylint: disable=R0911
147
220
  def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
148
221
  """Store one TaskRes."""
@@ -215,6 +288,87 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
215
288
  # Return the new task_id
216
289
  return task_id
217
290
 
291
# pylint: disable=R0911
def store_message_res(self, message: Message) -> Optional[UUID]:
    """Store one reply Message.

    The reply is checked against the instruction Message it replies to
    (``metadata.reply_to_message``). If valid, it is stored under a freshly
    generated ID, which is also linked to that instruction Message.

    Validation and storage happen under a single lock acquisition. This
    prevents the instruction Message from being deleted between the checks
    and the write.

    Parameters
    ----------
    message : Message
        The reply Message to store. On success its metadata is assigned the
        generated message ID.

    Returns
    -------
    Optional[UUID]
        The new message ID. ``None`` if the reply is rejected: validation
        errors, an unknown or malformed ``reply_to_message``, a sender that is
        not the instruction's destination, an expired instruction, an excessive
        TTL, or a run ID mismatch.
    """
    # Validate message
    errors = validate_message(message, is_reply_message=True)
    if any(errors):
        log(ERROR, errors)
        return None

    res_metadata = message.metadata
    msg_ins_id = res_metadata.reply_to_message
    try:
        msg_ins_uuid = UUID(msg_ins_id)
    except ValueError:
        # A malformed ID cannot reference any stored Message
        log(ERROR, "Message with ID %s does not exist.", msg_ins_id)
        return None

    with self.lock:
        # Check if the Message it is replying to exists and is valid
        msg_ins = self.message_ins_store.get(msg_ins_uuid)
        if msg_ins is None:
            log(
                ERROR,
                "Message with ID %s does not exist.",
                msg_ins_id,
            )
            return None

        ins_metadata = msg_ins.metadata

        # Ensure that dst_node_id of original Message matches the src_node_id of
        # reply Message.
        if ins_metadata.dst_node_id != res_metadata.src_node_id:
            log(
                ERROR,
                "Invalid source node ID for reply Message: %s",
                res_metadata.src_node_id,
            )
            return None

        if ins_metadata.created_at + ins_metadata.ttl <= time.time():
            log(
                ERROR,
                "Failed to store Message: the message it is replying to "
                "(with ID %s) has expired",
                msg_ins_id,
            )
            return None

        # Fail if the Message TTL exceeds the expiration time of the Message it
        # replies to. Condition:
        #   ins.created_at + ins.ttl >= res.created_at + res.ttl
        # A small tolerance accounts for floating-point precision issues.
        max_allowed_ttl = (
            ins_metadata.created_at + ins_metadata.ttl - res_metadata.created_at
        )
        if res_metadata.ttl and (
            res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
        ):
            log(
                WARNING,
                "Received Message with TTL %.2f exceeding the allowed maximum "
                "TTL %.2f.",
                res_metadata.ttl,
                max_allowed_ttl,
            )
            return None

        # Validate run_id
        if res_metadata.run_id != ins_metadata.run_id:
            log(ERROR, "`metadata.run_id` is invalid")
            return None

        # Create message_id and store the reply, still under the same lock
        message_id = uuid4()
        # pylint: disable-next=W0212
        message.metadata._message_id = str(message_id)  # type: ignore
        self.message_res_store[message_id] = message
        self.message_ins_id_to_message_res_id[msg_ins_uuid] = message_id

    # Return the new message_id
    return message_id
371
+
218
372
  def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
219
373
  """Get TaskRes for the given TaskIns IDs."""
220
374
  ret: dict[UUID, TaskRes] = {}
@@ -252,6 +406,45 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
252
406
 
253
407
  return list(ret.values())
254
408
 
409
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
    """Get reply Messages for the given Message IDs.

    For each inquired instruction Message ID, the result contains either its
    undelivered reply Message or an error Message produced by the
    verification helpers (e.g. for unknown or expired IDs). An ID whose
    instruction is still valid but has no reply yet contributes nothing.
    Replies that are returned are marked as delivered.

    Parameters
    ----------
    message_ids : set[UUID]
        IDs of the instruction Messages whose replies are requested.

    Returns
    -------
    list[Message]
        Reply Messages and/or error Messages for the given IDs.
    """
    with self.lock:
        current = time.time()

        # Verify Message IDs (yields error Messages for unknown/expired IDs)
        ret: dict[UUID, Message] = verify_message_ids(
            inquired_message_ids=message_ids,
            found_message_ins_dict=self.message_ins_store,
            current_time=current,
        )

        # Find all reply Messages that have not been delivered yet
        message_res_found: list[Message] = []
        for message_id in message_ids:
            if message_res_id := self.message_ins_id_to_message_res_id.get(
                message_id
            ):
                message_res = self.message_res_store[message_res_id]
                if message_res.metadata.delivered_at == "":
                    message_res_found.append(message_res)

        ret.update(
            verify_found_message_replies(
                inquired_message_ids=message_ids,
                found_message_ins_dict=self.message_ins_store,
                found_message_res_list=message_res_found,
                current_time=current,
            )
        )

        # Mark the replies being returned as delivered while still holding the
        # lock, so a concurrent call cannot return the same replies again.
        delivered_at = now().isoformat()
        for message_res in message_res_found:
            message_res.metadata.delivered_at = delivered_at

    return list(ret.values())
447
+
255
448
  def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
256
449
  """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
257
450
  if not task_ins_ids:
@@ -267,6 +460,23 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
267
460
  task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
268
461
  del self.task_res_store[task_res_id]
269
462
 
463
def delete_messages(self, message_ins_ids: set[UUID]) -> None:
    """Remove instruction Messages and their linked replies by instruction ID.

    IDs that are not present in the store are ignored.
    """
    if not message_ins_ids:
        return

    with self.lock:
        for ins_id in message_ins_ids:
            # Drop the instruction Message itself, if still stored
            self.message_ins_store.pop(ins_id, None)
            # Drop the reply linked to it, if one was recorded
            linked_res_id = self.message_ins_id_to_message_res_id.pop(ins_id, None)
            if linked_res_id is not None:
                del self.message_res_store[linked_res_id]
479
+
270
480
  def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
271
481
  """Get all TaskIns IDs for the given run_id."""
272
482
  task_id_list: set[UUID] = set()
@@ -277,6 +487,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
277
487
 
278
488
  return task_id_list
279
489
 
490
def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
    """Return the IDs of every stored instruction Message belonging to `run_id`."""
    with self.lock:
        matching_ids = {
            msg_id
            for msg_id, msg in self.message_ins_store.items()
            if msg.metadata.run_id == run_id
        }
    return matching_ids
499
+
280
500
  def num_task_ins(self) -> int:
281
501
  """Calculate the number of task_ins in store.
282
502
 
@@ -284,6 +504,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
284
504
  """
285
505
  return len(self.task_ins_store)
286
506
 
507
def num_message_ins(self) -> int:
    """Return how many instruction Messages the store currently holds.

    Messages already delivered but not yet deleted are counted as well.
    """
    instruction_store = self.message_ins_store
    return len(instruction_store)
513
+
287
514
  def num_task_res(self) -> int:
288
515
  """Calculate the number of task_res in store.
289
516
 
@@ -291,6 +518,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
291
518
  """
292
519
  return len(self.task_res_store)
293
520
 
521
def num_message_res(self) -> int:
    """Return how many reply Messages the store currently holds.

    Replies already delivered but not yet deleted are counted as well.
    """
    reply_store = self.message_res_store
    return len(reply_store)
527
+
294
528
  def create_node(self, ping_interval: float) -> int:
295
529
  """Create, store in the link state, and return `node_id`."""
296
530
  # Sample a random int64 as node_id
@@ -19,7 +19,7 @@ import abc
19
19
  from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
- from flwr.common import Context
22
+ from flwr.common import Context, Message
23
23
  from flwr.common.record import ConfigsRecord
24
24
  from flwr.common.typing import Run, RunStatus, UserConfig
25
25
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
@@ -46,6 +46,24 @@ class LinkState(abc.ABC): # pylint: disable=R0904
46
46
  storing the `task_ins` MUST fail.
47
47
  """
48
48
 
49
@abc.abstractmethod
def store_message_ins(self, message: Message) -> Optional[UUID]:
    """Store one instruction Message.

    Typically invoked by the ServerAppIo API to schedule an instruction.

    The `message` is persisted in the link state. On success its newly
    assigned `message_id` (UUID) is returned; if storing fails for any reason,
    `None` is returned instead.

    Constraints
    -----------
    `message.metadata.dst_node_id` MUST be set (not constant.SUPERLINK_NODE_ID)

    Storing the `message` MUST fail if `message.metadata.run_id` is invalid.
    """
66
+
49
67
  @abc.abstractmethod
50
68
  def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
51
69
  """Get TaskIns optionally filtered by node_id.
@@ -68,6 +86,21 @@ class LinkState(abc.ABC): # pylint: disable=R0904
68
86
  `limit` is set, it has to be greater zero.
69
87
  """
70
88
 
89
@abc.abstractmethod
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
    """Get zero or more `Message` objects for the provided `node_id`.

    Usually, the Fleet API calls this for Nodes planning to work on one or more
    Messages.

    Constraints
    -----------
    Retrieve all Messages where

        1. the `message.metadata.dst_node_id` equals `node_id` AND
        2. the `message.metadata.delivered_at` equals `""` (i.e., the Message
           has not been delivered yet) AND
        3. the Message has not expired (`created_at + ttl` lies in the future).

    A Message whose `delivered_at` is set (i.e., not `""`) MUST NOT be in the
    result. The returned Messages are marked as delivered.

    If `limit` is not `None`, return, at most, `limit` number of Messages. If
    `limit` is set, it has to be greater than zero.
    """
103
+
71
104
  @abc.abstractmethod
72
105
  def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
73
106
  """Store one TaskRes.
@@ -86,6 +119,23 @@ class LinkState(abc.ABC): # pylint: disable=R0904
86
119
  storing the `task_res` MUST fail.
87
120
  """
88
121
 
122
@abc.abstractmethod
def store_message_res(self, message: Message) -> Optional[UUID]:
    """Store one reply Message.

    Usually, the Fleet API calls this for Nodes returning results.

    Stores the Message and, if successful, returns the `message_id` (UUID) of
    the `message`. If storing the `message` fails, `None` is returned.

    Constraints
    -----------
    `message.metadata.reply_to_message` MUST reference a stored instruction
    Message that has not expired, and `message.metadata.src_node_id` MUST equal
    that instruction Message's `dst_node_id`.

    If `message.metadata.run_id` does not match the run ID of the instruction
    Message being replied to, storing the `message` MUST fail.
    """
138
+
89
139
  @abc.abstractmethod
90
140
  def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
91
141
  """Get TaskRes for the given TaskIns IDs.
@@ -111,6 +161,33 @@ class LinkState(abc.ABC): # pylint: disable=R0904
111
161
  TaskRes could be found for any of the task IDs, an empty list is returned.
112
162
  """
113
163
 
164
@abc.abstractmethod
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
    """Get reply Messages for the given Message IDs.

    Typically called by the ServerAppIo API to collect the results (reply
    Messages) of previously scheduled instructions (instruction Messages).
    For every `message_id` passed in, exactly one of the following applies:

    - An error Message is returned if no Message is registered under that ID,
      or if it has expired.
    - An error Message is returned if the reply Message exists but has expired.
    - The reply Message itself is returned.
    - Nothing is returned if the Message with that ID is still valid and is
      waiting for a reply Message.

    Parameters
    ----------
    message_ids : set[UUID]
        IDs of the Messages whose reply Messages are requested.

    Returns
    -------
    list[Message]
        Reply Messages responding to the given IDs and/or Messages carrying an
        Error.
    """
190
+
114
191
  @abc.abstractmethod
115
192
  def num_task_ins(self) -> int:
116
193
  """Calculate the number of task_ins in store.
@@ -118,6 +195,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
118
195
  This includes delivered but not yet deleted task_ins.
119
196
  """
120
197
 
198
@abc.abstractmethod
def num_message_ins(self) -> int:
    """Calculate the number of instruction Messages in store.

    This includes Messages that were already delivered (or replied to) but not
    yet deleted, not only those still awaiting a reply.
    """
201
+
121
202
  @abc.abstractmethod
122
203
  def num_task_res(self) -> int:
123
204
  """Calculate the number of task_res in store.
@@ -125,6 +206,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
125
206
  This includes delivered but not yet deleted task_res.
126
207
  """
127
208
 
209
@abc.abstractmethod
def num_message_res(self) -> int:
    """Return how many reply Messages are held in the store."""
212
+
128
213
  @abc.abstractmethod
129
214
  def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
130
215
  """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
@@ -136,10 +221,25 @@ class LinkState(abc.ABC): # pylint: disable=R0904
136
221
  TaskIns and its associated TaskRes will be deleted.
137
222
  """
138
223
 
224
@abc.abstractmethod
def delete_messages(self, message_ins_ids: set[UUID]) -> None:
    """Delete instruction Messages together with their reply Messages.

    Parameters
    ----------
    message_ins_ids : set[UUID]
        IDs of the instruction Messages to delete. For every ID, both the
        Message and the reply Message associated with it are removed.
    """
234
+
139
235
  @abc.abstractmethod
140
236
  def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
141
237
  """Get all TaskIns IDs for the given run_id."""
142
238
 
239
@abc.abstractmethod
def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
    """Return the IDs of every instruction Message that belongs to `run_id`."""
242
+
143
243
  @abc.abstractmethod
144
244
  def create_node(self, ping_interval: float) -> int:
145
245
  """Create, store in the link state, and return `node_id`."""