flwr-nightly 1.12.0.dev20240916__py3-none-any.whl → 1.12.0.dev20241006__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (53)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/log.py +234 -0
  3. flwr/cli/new/new.py +1 -1
  4. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/run/run.py +17 -1
  7. flwr/client/grpc_rere_client/client_interceptor.py +3 -0
  8. flwr/client/grpc_rere_client/connection.py +3 -3
  9. flwr/client/grpc_rere_client/grpc_adapter.py +14 -3
  10. flwr/client/rest_client/connection.py +3 -3
  11. flwr/client/supernode/app.py +1 -0
  12. flwr/common/constant.py +6 -3
  13. flwr/common/secure_aggregation/secaggplus_utils.py +4 -4
  14. flwr/common/serde.py +22 -7
  15. flwr/proto/clientappio_pb2.py +1 -1
  16. flwr/proto/control_pb2.py +27 -0
  17. flwr/proto/control_pb2.pyi +7 -0
  18. flwr/proto/control_pb2_grpc.py +135 -0
  19. flwr/proto/control_pb2_grpc.pyi +53 -0
  20. flwr/proto/driver_pb2.py +15 -24
  21. flwr/proto/driver_pb2.pyi +0 -52
  22. flwr/proto/driver_pb2_grpc.py +6 -6
  23. flwr/proto/driver_pb2_grpc.pyi +4 -4
  24. flwr/proto/exec_pb2.py +1 -1
  25. flwr/proto/fab_pb2.py +8 -7
  26. flwr/proto/fab_pb2.pyi +7 -1
  27. flwr/proto/fleet_pb2.py +10 -10
  28. flwr/proto/fleet_pb2.pyi +6 -1
  29. flwr/proto/message_pb2.py +1 -1
  30. flwr/proto/node_pb2.py +1 -1
  31. flwr/proto/recordset_pb2.py +35 -33
  32. flwr/proto/recordset_pb2.pyi +40 -14
  33. flwr/proto/run_pb2.py +33 -9
  34. flwr/proto/run_pb2.pyi +150 -1
  35. flwr/proto/task_pb2.py +1 -1
  36. flwr/proto/transport_pb2.py +8 -8
  37. flwr/proto/transport_pb2.pyi +9 -6
  38. flwr/server/run_serverapp.py +2 -2
  39. flwr/server/superlink/driver/driver_servicer.py +2 -2
  40. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -2
  41. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +4 -0
  42. flwr/server/superlink/state/in_memory_state.py +17 -0
  43. flwr/server/superlink/state/sqlite_state.py +142 -24
  44. flwr/server/superlink/state/utils.py +98 -2
  45. flwr/server/utils/validator.py +6 -0
  46. flwr/superexec/deployment.py +3 -1
  47. flwr/superexec/exec_servicer.py +68 -3
  48. flwr/superexec/executor.py +2 -1
  49. {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/METADATA +4 -2
  50. {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/RECORD +53 -48
  51. {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/LICENSE +0 -0
  52. {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/WHEEL +0 -0
  53. {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/entry_points.txt +0 -0
@@ -33,7 +33,14 @@ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
33
33
  from flwr.server.utils.validator import validate_task_ins_or_res
34
34
 
35
35
  from .state import State
36
- from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
36
+ from .utils import (
37
+ convert_sint64_to_uint64,
38
+ convert_sint64_values_in_dict_to_uint64,
39
+ convert_uint64_to_sint64,
40
+ convert_uint64_values_in_dict_to_sint64,
41
+ generate_rand_int_from_bytes,
42
+ make_node_unavailable_taskres,
43
+ )
37
44
 
38
45
  SQL_CREATE_TABLE_NODE = """
39
46
  CREATE TABLE IF NOT EXISTS node(
@@ -223,6 +230,12 @@ class SqliteState(State): # pylint: disable=R0904
223
230
  # Store TaskIns
224
231
  task_ins.task_id = str(task_id)
225
232
  data = (task_ins_to_dict(task_ins),)
233
+
234
+ # Convert values from uint64 to sint64 for SQLite
235
+ convert_uint64_values_in_dict_to_sint64(
236
+ data[0], ["run_id", "producer_node_id", "consumer_node_id"]
237
+ )
238
+
226
239
  columns = ", ".join([f":{key}" for key in data[0]])
227
240
  query = f"INSERT INTO task_ins VALUES({columns});"
228
241
 
@@ -284,6 +297,9 @@ class SqliteState(State): # pylint: disable=R0904
284
297
  AND delivered_at = ""
285
298
  """
286
299
  else:
300
+ # Convert the uint64 value to sint64 for SQLite
301
+ data["node_id"] = convert_uint64_to_sint64(node_id)
302
+
287
303
  # Retrieve all TaskIns for node_id
288
304
  query = """
289
305
  SELECT task_id
@@ -292,7 +308,6 @@ class SqliteState(State): # pylint: disable=R0904
292
308
  AND consumer_node_id == :node_id
293
309
  AND delivered_at = ""
294
310
  """
295
- data["node_id"] = node_id
296
311
 
297
312
  if limit is not None:
298
313
  query += " LIMIT :limit"
@@ -322,6 +337,12 @@ class SqliteState(State): # pylint: disable=R0904
322
337
  # Run query
323
338
  rows = self.query(query, data)
324
339
 
340
+ for row in rows:
341
+ # Convert values from sint64 to uint64
342
+ convert_sint64_values_in_dict_to_uint64(
343
+ row, ["run_id", "producer_node_id", "consumer_node_id"]
344
+ )
345
+
325
346
  result = [dict_to_task_ins(row) for row in rows]
326
347
 
327
348
  return result
@@ -351,9 +372,26 @@ class SqliteState(State): # pylint: disable=R0904
351
372
  # Create task_id
352
373
  task_id = uuid4()
353
374
 
354
- # Store TaskIns
375
+ task_ins_id = task_res.task.ancestry[0]
376
+ task_ins = self.get_valid_task_ins(task_ins_id)
377
+ if task_ins is None:
378
+ log(
379
+ ERROR,
380
+ "Failed to store TaskRes: "
381
+ "TaskIns with task_id %s does not exist or has expired.",
382
+ task_ins_id,
383
+ )
384
+ return None
385
+
386
+ # Store TaskRes
355
387
  task_res.task_id = str(task_id)
356
388
  data = (task_res_to_dict(task_res),)
389
+
390
+ # Convert values from uint64 to sint64 for SQLite
391
+ convert_uint64_values_in_dict_to_sint64(
392
+ data[0], ["run_id", "producer_node_id", "consumer_node_id"]
393
+ )
394
+
357
395
  columns = ", ".join([f":{key}" for key in data[0]])
358
396
  query = f"INSERT INTO task_res VALUES({columns});"
359
397
 
@@ -431,6 +469,12 @@ class SqliteState(State): # pylint: disable=R0904
431
469
  # Run query
432
470
  rows = self.query(query, data)
433
471
 
472
+ for row in rows:
473
+ # Convert values from sint64 to uint64
474
+ convert_sint64_values_in_dict_to_uint64(
475
+ row, ["run_id", "producer_node_id", "consumer_node_id"]
476
+ )
477
+
434
478
  result = [dict_to_task_res(row) for row in rows]
435
479
 
436
480
  # 1. Query: Fetch consumer_node_id of remaining task_ids
@@ -474,6 +518,13 @@ class SqliteState(State): # pylint: disable=R0904
474
518
  for row in task_ins_rows:
475
519
  if limit and len(result) == limit:
476
520
  break
521
+
522
+ for row in rows:
523
+ # Convert values from sint64 to uint64
524
+ convert_sint64_values_in_dict_to_uint64(
525
+ row, ["run_id", "producer_node_id", "consumer_node_id"]
526
+ )
527
+
477
528
  task_ins = dict_to_task_ins(row)
478
529
  err_taskres = make_node_unavailable_taskres(
479
530
  ref_taskins=task_ins,
@@ -544,8 +595,11 @@ class SqliteState(State): # pylint: disable=R0904
544
595
  self, ping_interval: float, public_key: Optional[bytes] = None
545
596
  ) -> int:
546
597
  """Create, store in state, and return `node_id`."""
547
- # Sample a random int64 as node_id
548
- node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
598
+ # Sample a random uint64 as node_id
599
+ uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
600
+
601
+ # Convert the uint64 value to sint64 for SQLite
602
+ sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
549
603
 
550
604
  query = "SELECT node_id FROM node WHERE public_key = :public_key;"
551
605
  row = self.query(query, {"public_key": public_key})
@@ -562,17 +616,28 @@ class SqliteState(State): # pylint: disable=R0904
562
616
 
563
617
  try:
564
618
  self.query(
565
- query, (node_id, time.time() + ping_interval, ping_interval, public_key)
619
+ query,
620
+ (
621
+ sint64_node_id,
622
+ time.time() + ping_interval,
623
+ ping_interval,
624
+ public_key,
625
+ ),
566
626
  )
567
627
  except sqlite3.IntegrityError:
568
628
  log(ERROR, "Unexpected node registration failure.")
569
629
  return 0
570
- return node_id
630
+
631
+ # Note: we need to return the uint64 value of the node_id
632
+ return uint64_node_id
571
633
 
572
634
  def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
573
635
  """Delete a node."""
636
+ # Convert the uint64 value to sint64 for SQLite
637
+ sint64_node_id = convert_uint64_to_sint64(node_id)
638
+
574
639
  query = "DELETE FROM node WHERE node_id = ?"
575
- params = (node_id,)
640
+ params = (sint64_node_id,)
576
641
 
577
642
  if public_key is not None:
578
643
  query += " AND public_key = ?"
@@ -597,15 +662,20 @@ class SqliteState(State): # pylint: disable=R0904
597
662
  If the provided `run_id` does not exist or has no matching nodes,
598
663
  an empty `Set` MUST be returned.
599
664
  """
665
+ # Convert the uint64 value to sint64 for SQLite
666
+ sint64_run_id = convert_uint64_to_sint64(run_id)
667
+
600
668
  # Validate run ID
601
669
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
602
- if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
670
+ if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
603
671
  return set()
604
672
 
605
673
  # Get nodes
606
674
  query = "SELECT node_id FROM node WHERE online_until > ?;"
607
675
  rows = self.query(query, (time.time(),))
608
- result: set[int] = {row["node_id"] for row in rows}
676
+
677
+ # Convert sint64 node_ids to uint64
678
+ result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
609
679
  return result
610
680
 
611
681
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
@@ -614,7 +684,11 @@ class SqliteState(State): # pylint: disable=R0904
614
684
  row = self.query(query, {"public_key": node_public_key})
615
685
  if len(row) > 0:
616
686
  node_id: int = row[0]["node_id"]
617
- return node_id
687
+
688
+ # Convert the sint64 value to uint64 after reading from SQLite
689
+ uint64_node_id = convert_sint64_to_uint64(node_id)
690
+
691
+ return uint64_node_id
618
692
  return None
619
693
 
620
694
  def create_run(
@@ -626,12 +700,15 @@ class SqliteState(State): # pylint: disable=R0904
626
700
  ) -> int:
627
701
  """Create a new run for the specified `fab_id` and `fab_version`."""
628
702
  # Sample a random int64 as run_id
629
- run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
703
+ uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
704
+
705
+ # Convert the uint64 value to sint64 for SQLite
706
+ sint64_run_id = convert_uint64_to_sint64(uint64_run_id)
630
707
 
631
708
  # Check conflicts
632
709
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
633
- # If run_id does not exist
634
- if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
710
+ # If sint64_run_id does not exist
711
+ if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
635
712
  query = (
636
713
  "INSERT INTO run "
637
714
  "(run_id, fab_id, fab_version, fab_hash, override_config)"
@@ -639,14 +716,22 @@ class SqliteState(State): # pylint: disable=R0904
639
716
  )
640
717
  if fab_hash:
641
718
  self.query(
642
- query, (run_id, "", "", fab_hash, json.dumps(override_config))
719
+ query,
720
+ (sint64_run_id, "", "", fab_hash, json.dumps(override_config)),
643
721
  )
644
722
  else:
645
723
  self.query(
646
724
  query,
647
- (run_id, fab_id, fab_version, "", json.dumps(override_config)),
725
+ (
726
+ sint64_run_id,
727
+ fab_id,
728
+ fab_version,
729
+ "",
730
+ json.dumps(override_config),
731
+ ),
648
732
  )
649
- return run_id
733
+ # Note: we need to return the uint64 value of the run_id
734
+ return uint64_run_id
650
735
  log(ERROR, "Unexpected run creation failure.")
651
736
  return 0
652
737
 
@@ -705,31 +790,64 @@ class SqliteState(State): # pylint: disable=R0904
705
790
 
706
791
  def get_run(self, run_id: int) -> Optional[Run]:
707
792
  """Retrieve information about the run with the specified `run_id`."""
793
+ # Convert the uint64 value to sint64 for SQLite
794
+ sint64_run_id = convert_uint64_to_sint64(run_id)
708
795
  query = "SELECT * FROM run WHERE run_id = ?;"
709
- try:
710
- row = self.query(query, (run_id,))[0]
796
+ rows = self.query(query, (sint64_run_id,))
797
+ if rows:
798
+ row = rows[0]
711
799
  return Run(
712
- run_id=run_id,
800
+ run_id=convert_sint64_to_uint64(row["run_id"]),
713
801
  fab_id=row["fab_id"],
714
802
  fab_version=row["fab_version"],
715
803
  fab_hash=row["fab_hash"],
716
804
  override_config=json.loads(row["override_config"]),
717
805
  )
718
- except sqlite3.IntegrityError:
719
- log(ERROR, "`run_id` does not exist.")
720
- return None
806
+ log(ERROR, "`run_id` does not exist.")
807
+ return None
721
808
 
722
809
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
723
810
  """Acknowledge a ping received from a node, serving as a heartbeat."""
811
+ sint64_node_id = convert_uint64_to_sint64(node_id)
812
+
724
813
  # Update `online_until` and `ping_interval` for the given `node_id`
725
814
  query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
726
815
  try:
727
- self.query(query, (time.time() + ping_interval, ping_interval, node_id))
816
+ self.query(
817
+ query, (time.time() + ping_interval, ping_interval, sint64_node_id)
818
+ )
728
819
  return True
729
820
  except sqlite3.IntegrityError:
730
821
  log(ERROR, "`node_id` does not exist.")
731
822
  return False
732
823
 
824
+ def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
825
+ """Check if the TaskIns exists and is valid (not expired).
826
+
827
+ Return TaskIns if valid.
828
+ """
829
+ query = """
830
+ SELECT *
831
+ FROM task_ins
832
+ WHERE task_id = :task_id
833
+ """
834
+ data = {"task_id": task_id}
835
+ rows = self.query(query, data)
836
+ if not rows:
837
+ # TaskIns does not exist
838
+ return None
839
+
840
+ task_ins = rows[0]
841
+ created_at = task_ins["created_at"]
842
+ ttl = task_ins["ttl"]
843
+ current_time = time.time()
844
+
845
+ # Check if TaskIns is expired
846
+ if ttl is not None and created_at + ttl <= current_time:
847
+ return None
848
+
849
+ return task_ins
850
+
733
851
 
734
852
  def dict_factory(
735
853
  cursor: sqlite3.Cursor,
@@ -33,8 +33,104 @@ NODE_UNAVAILABLE_ERROR_REASON = (
33
33
 
34
34
 
35
35
  def generate_rand_int_from_bytes(num_bytes: int) -> int:
36
- """Generate a random `num_bytes` integer."""
37
- return int.from_bytes(urandom(num_bytes), "little", signed=True)
36
+ """Generate a random unsigned integer from `num_bytes` bytes."""
37
+ return int.from_bytes(urandom(num_bytes), "little", signed=False)
38
+
39
+
40
+ def convert_uint64_to_sint64(u: int) -> int:
41
+ """Convert a uint64 value to a sint64 value with the same bit sequence.
42
+
43
+ Parameters
44
+ ----------
45
+ u : int
46
+ The unsigned 64-bit integer to convert.
47
+
48
+ Returns
49
+ -------
50
+ int
51
+ The signed 64-bit integer equivalent.
52
+
53
+ The signed 64-bit integer will have the same bit pattern as the
54
+ unsigned 64-bit integer but may have a different decimal value.
55
+
56
+ For numbers within the range [0, `sint64` max value], the decimal
57
+ value remains the same. However, for numbers greater than the `sint64`
58
+ max value, the decimal value will differ due to the wraparound caused
59
+ by the sign bit.
60
+ """
61
+ if u >= (1 << 63):
62
+ return u - (1 << 64)
63
+ return u
64
+
65
+
66
+ def convert_sint64_to_uint64(s: int) -> int:
67
+ """Convert a sint64 value to a uint64 value with the same bit sequence.
68
+
69
+ Parameters
70
+ ----------
71
+ s : int
72
+ The signed 64-bit integer to convert.
73
+
74
+ Returns
75
+ -------
76
+ int
77
+ The unsigned 64-bit integer equivalent.
78
+
79
+ The unsigned 64-bit integer will have the same bit pattern as the
80
+ signed 64-bit integer but may have a different decimal value.
81
+
82
+ For negative `sint64` values, the conversion adds 2^64 to the
83
+ signed value to obtain the equivalent `uint64` value. For non-negative
84
+ `sint64` values, the decimal value remains unchanged in the `uint64`
85
+ representation.
86
+ """
87
+ if s < 0:
88
+ return s + (1 << 64)
89
+ return s
90
+
91
+
92
+ def convert_uint64_values_in_dict_to_sint64(
93
+ data_dict: dict[str, int], keys: list[str]
94
+ ) -> None:
95
+ """Convert uint64 values to sint64 in the given dictionary.
96
+
97
+ Parameters
98
+ ----------
99
+ data_dict : dict[str, int]
100
+ A dictionary where the values are integers to be converted.
101
+ keys : list[str]
102
+ A list of keys in the dictionary whose values need to be converted.
103
+
104
+ Returns
105
+ -------
106
+ None
107
+ This function does not return a value. It modifies `data_dict` in place.
108
+ """
109
+ for key in keys:
110
+ if key in data_dict:
111
+ data_dict[key] = convert_uint64_to_sint64(data_dict[key])
112
+
113
+
114
+ def convert_sint64_values_in_dict_to_uint64(
115
+ data_dict: dict[str, int], keys: list[str]
116
+ ) -> None:
117
+ """Convert sint64 values to uint64 in the given dictionary.
118
+
119
+ Parameters
120
+ ----------
121
+ data_dict : dict[str, int]
122
+ A dictionary where the values are integers to be converted.
123
+ keys : list[str]
124
+ A list of keys in the dictionary whose values need to be converted.
125
+
126
+ Returns
127
+ -------
128
+ None
129
+ This function does not return a value. It modifies `data_dict` in place.
130
+ """
131
+ for key in keys:
132
+ if key in data_dict:
133
+ data_dict[key] = convert_sint64_to_uint64(data_dict[key])
38
134
 
39
135
 
40
136
  def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
@@ -15,6 +15,7 @@
15
15
  """Validators."""
16
16
 
17
17
 
18
+ import time
18
19
  from typing import Union
19
20
 
20
21
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
@@ -47,6 +48,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str
47
48
  # unix timestamp of 27 March 2024 00h:00m:00s UTC
48
49
  validation_errors.append("`pushed_at` is not a recent timestamp")
49
50
 
51
+ # Verify TTL and created_at time
52
+ current_time = time.time()
53
+ if tasks_ins_res.task.created_at + tasks_ins_res.task.ttl <= current_time:
54
+ validation_errors.append("Task TTL has expired")
55
+
50
56
  # TaskIns specific
51
57
  if isinstance(tasks_ins_res, TaskIns):
52
58
  # Task producer
@@ -28,8 +28,8 @@ from flwr.common.grpc import create_channel
28
28
  from flwr.common.logger import log
29
29
  from flwr.common.serde import fab_to_proto, user_config_to_proto
30
30
  from flwr.common.typing import Fab, UserConfig
31
- from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
32
31
  from flwr.proto.driver_pb2_grpc import DriverStub
32
+ from flwr.proto.run_pb2 import CreateRunRequest # pylint: disable=E0611
33
33
 
34
34
  from .executor import Executor, RunTracker
35
35
 
@@ -167,6 +167,8 @@ class DeploymentEngine(Executor):
167
167
  # Execute the command
168
168
  proc = subprocess.Popen( # pylint: disable=consider-using-with
169
169
  command,
170
+ stdout=subprocess.PIPE,
171
+ stderr=subprocess.PIPE,
170
172
  text=True,
171
173
  )
172
174
  log(INFO, "Started run %s", str(run_id))
@@ -15,6 +15,10 @@
15
15
  """SuperExec API servicer."""
16
16
 
17
17
 
18
+ import select
19
+ import sys
20
+ import threading
21
+ import time
18
22
  from collections.abc import Generator
19
23
  from logging import ERROR, INFO
20
24
  from typing import Any
@@ -33,6 +37,8 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
33
37
 
34
38
  from .executor import Executor, RunTracker
35
39
 
40
+ SELECT_TIMEOUT = 1 # Timeout for selecting ready-to-read file descriptors (in seconds)
41
+
36
42
 
37
43
  class ExecServicer(exec_pb2_grpc.ExecServicer):
38
44
  """SuperExec API servicer."""
@@ -59,13 +65,72 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
59
65
 
60
66
  self.runs[run.run_id] = run
61
67
 
68
+ # Start a background thread to capture the log output
69
+ capture_thread = threading.Thread(
70
+ target=_capture_logs, args=(run,), daemon=True
71
+ )
72
+ capture_thread.start()
73
+
62
74
  return StartRunResponse(run_id=run.run_id)
63
75
 
64
- def StreamLogs(
76
+ def StreamLogs( # pylint: disable=C0103
65
77
  self, request: StreamLogsRequest, context: grpc.ServicerContext
66
78
  ) -> Generator[StreamLogsResponse, Any, None]:
67
79
  """Get logs."""
68
- logs = ["a", "b", "c"]
80
+ log(INFO, "ExecServicer.StreamLogs")
81
+
82
+ # Exit if `run_id` not found
83
+ if request.run_id not in self.runs:
84
+ context.abort(grpc.StatusCode.NOT_FOUND, "Run ID not found")
85
+
86
+ last_sent_index = 0
69
87
  while context.is_active():
70
- for i in range(len(logs)): # pylint: disable=C0200
88
+ # Yield n'th row of logs, if n'th row < len(logs)
89
+ logs = self.runs[request.run_id].logs
90
+ for i in range(last_sent_index, len(logs)):
71
91
  yield StreamLogsResponse(log_output=logs[i])
92
+ last_sent_index = len(logs)
93
+
94
+ # Wait for and continue to yield more log responses only if the
95
+ # run isn't completed yet. If the run is finished, the entire log
96
+ # is returned at this point and the server ends the stream.
97
+ if self.runs[request.run_id].proc.poll() is not None:
98
+ log(INFO, "All logs for run ID `%s` returned", request.run_id)
99
+ context.set_code(grpc.StatusCode.OK)
100
+ context.cancel()
101
+
102
+ time.sleep(1.0) # Sleep briefly to avoid busy waiting
103
+
104
+
105
+ def _capture_logs(
106
+ run: RunTracker,
107
+ ) -> None:
108
+ while True:
109
+ # Explicitly check if Popen.poll() is None. Required for `pytest`.
110
+ if run.proc.poll() is None:
111
+ # Select streams only when ready to read
112
+ ready_to_read, _, _ = select.select(
113
+ [run.proc.stdout, run.proc.stderr],
114
+ [],
115
+ [],
116
+ SELECT_TIMEOUT,
117
+ )
118
+ # Read from std* and append to RunTracker.logs
119
+ for stream in ready_to_read:
120
+ # Flush stdout to view output in real time
121
+ readline = stream.readline()
122
+ sys.stdout.write(readline)
123
+ sys.stdout.flush()
124
+ # Append to logs
125
+ line = readline.rstrip()
126
+ if line:
127
+ run.logs.append(f"{line}")
128
+
129
+ # Close std* to prevent blocking
130
+ elif run.proc.poll() is not None:
131
+ log(INFO, "Subprocess finished, exiting log capture")
132
+ if run.proc.stdout:
133
+ run.proc.stdout.close()
134
+ if run.proc.stderr:
135
+ run.proc.stderr.close()
136
+ break
@@ -15,7 +15,7 @@
15
15
  """Execute and monitor a Flower run."""
16
16
 
17
17
  from abc import ABC, abstractmethod
18
- from dataclasses import dataclass
18
+ from dataclasses import dataclass, field
19
19
  from subprocess import Popen
20
20
  from typing import Optional
21
21
 
@@ -28,6 +28,7 @@ class RunTracker:
28
28
 
29
29
  run_id: int
30
30
  proc: Popen # type: ignore
31
+ logs: list[str] = field(default_factory=list)
31
32
 
32
33
 
33
34
  class Executor(ABC):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.12.0.dev20240916
3
+ Version: 1.12.0.dev20241006
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -43,7 +43,7 @@ Requires-Dist: requests (>=2.31.0,<3.0.0) ; extra == "rest"
43
43
  Requires-Dist: starlette (>=0.31.0,<0.32.0) ; extra == "rest"
44
44
  Requires-Dist: tomli (>=2.0.1,<3.0.0)
45
45
  Requires-Dist: tomli-w (>=1.0.0,<2.0.0)
46
- Requires-Dist: typer[all] (>=0.9.0,<0.10.0)
46
+ Requires-Dist: typer (>=0.12.5,<0.13.0)
47
47
  Requires-Dist: uvicorn[standard] (>=0.23.0,<0.24.0) ; extra == "rest"
48
48
  Project-URL: Documentation, https://flower.ai
49
49
  Project-URL: Repository, https://github.com/adap/flower
@@ -69,6 +69,7 @@ Description-Content-Type: text/markdown
69
69
  [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/adap/flower/blob/main/CONTRIBUTING.md)
70
70
  ![Build](https://github.com/adap/flower/actions/workflows/framework.yml/badge.svg)
71
71
  [![Downloads](https://static.pepy.tech/badge/flwr)](https://pepy.tech/project/flwr)
72
+ [![Docker Hub](https://img.shields.io/badge/Docker%20Hub-flwr-blue)](https://hub.docker.com/u/flwr)
72
73
  [![Slack](https://img.shields.io/badge/Chat-Slack-red)](https://flower.ai/join-slack)
73
74
 
74
75
  Flower (`flwr`) is a framework for building federated learning systems. The
@@ -152,6 +153,7 @@ Flower Baselines is a collection of community-contributed projects that reproduc
152
153
  - [FedNova](https://github.com/adap/flower/tree/main/baselines/fednova)
153
154
  - [HeteroFL](https://github.com/adap/flower/tree/main/baselines/heterofl)
154
155
  - [FedAvgM](https://github.com/adap/flower/tree/main/baselines/fedavgm)
156
+ - [FedRep](https://github.com/adap/flower/tree/main/baselines/fedrep)
155
157
  - [FedStar](https://github.com/adap/flower/tree/main/baselines/fedstar)
156
158
  - [FedWav2vec2](https://github.com/adap/flower/tree/main/baselines/fedwav2vec2)
157
159
  - [FjORD](https://github.com/adap/flower/tree/main/baselines/fjord)