flwr-nightly 1.12.0.dev20240916__py3-none-any.whl → 1.12.0.dev20241006__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of flwr-nightly might be problematic.
- flwr/cli/app.py +2 -0
- flwr/cli/log.py +234 -0
- flwr/cli/new/new.py +1 -1
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/run/run.py +17 -1
- flwr/client/grpc_rere_client/client_interceptor.py +3 -0
- flwr/client/grpc_rere_client/connection.py +3 -3
- flwr/client/grpc_rere_client/grpc_adapter.py +14 -3
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +1 -0
- flwr/common/constant.py +6 -3
- flwr/common/secure_aggregation/secaggplus_utils.py +4 -4
- flwr/common/serde.py +22 -7
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/control_pb2.py +27 -0
- flwr/proto/control_pb2.pyi +7 -0
- flwr/proto/control_pb2_grpc.py +135 -0
- flwr/proto/control_pb2_grpc.pyi +53 -0
- flwr/proto/driver_pb2.py +15 -24
- flwr/proto/driver_pb2.pyi +0 -52
- flwr/proto/driver_pb2_grpc.py +6 -6
- flwr/proto/driver_pb2_grpc.pyi +4 -4
- flwr/proto/exec_pb2.py +1 -1
- flwr/proto/fab_pb2.py +8 -7
- flwr/proto/fab_pb2.pyi +7 -1
- flwr/proto/fleet_pb2.py +10 -10
- flwr/proto/fleet_pb2.pyi +6 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +1 -1
- flwr/proto/recordset_pb2.py +35 -33
- flwr/proto/recordset_pb2.pyi +40 -14
- flwr/proto/run_pb2.py +33 -9
- flwr/proto/run_pb2.pyi +150 -1
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +8 -8
- flwr/proto/transport_pb2.pyi +9 -6
- flwr/server/run_serverapp.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +2 -2
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +4 -0
- flwr/server/superlink/state/in_memory_state.py +17 -0
- flwr/server/superlink/state/sqlite_state.py +142 -24
- flwr/server/superlink/state/utils.py +98 -2
- flwr/server/utils/validator.py +6 -0
- flwr/superexec/deployment.py +3 -1
- flwr/superexec/exec_servicer.py +68 -3
- flwr/superexec/executor.py +2 -1
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/METADATA +4 -2
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/RECORD +53 -48
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/entry_points.txt +0 -0
flwr/server/superlink/state/sqlite_state.py
CHANGED

@@ -33,7 +33,14 @@ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.utils.validator import validate_task_ins_or_res
 
 from .state import State
-from .utils import
+from .utils import (
+    convert_sint64_to_uint64,
+    convert_sint64_values_in_dict_to_uint64,
+    convert_uint64_to_sint64,
+    convert_uint64_values_in_dict_to_sint64,
+    generate_rand_int_from_bytes,
+    make_node_unavailable_taskres,
+)
 
 SQL_CREATE_TABLE_NODE = """
 CREATE TABLE IF NOT EXISTS node(
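For context on why these conversion helpers are needed: SQLite stores every INTEGER as a signed 64-bit value, and Python's sqlite3 driver rejects anything outside that range, while Flower samples node and run IDs as unsigned 64-bit integers. A minimal sketch (not part of the package) of the failure mode the conversion avoids:

```python
# Why uint64 IDs cannot be stored in SQLite directly: the sqlite3 driver
# rejects integers outside the signed 64-bit range.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE node(node_id INTEGER)")

try:
    # 2**63 exceeds the sint64 maximum (2**63 - 1)
    conn.execute("INSERT INTO node VALUES (?)", (2**63,))
except OverflowError as err:
    print(f"Rejected: {err}")  # Python int too large to convert to SQLite INTEGER
```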
@@ -223,6 +230,12 @@ class SqliteState(State):  # pylint: disable=R0904
         # Store TaskIns
         task_ins.task_id = str(task_id)
         data = (task_ins_to_dict(task_ins),)
+
+        # Convert values from uint64 to sint64 for SQLite
+        convert_uint64_values_in_dict_to_sint64(
+            data[0], ["run_id", "producer_node_id", "consumer_node_id"]
+        )
+
         columns = ", ".join([f":{key}" for key in data[0]])
         query = f"INSERT INTO task_ins VALUES({columns});"
 
@@ -284,6 +297,9 @@ class SqliteState(State):  # pylint: disable=R0904
                 AND delivered_at = ""
             """
         else:
+            # Convert the uint64 value to sint64 for SQLite
+            data["node_id"] = convert_uint64_to_sint64(node_id)
+
             # Retrieve all TaskIns for node_id
             query = """
                 SELECT task_id
@@ -292,7 +308,6 @@ class SqliteState(State):  # pylint: disable=R0904
                 AND consumer_node_id == :node_id
                 AND delivered_at = ""
             """
-            data["node_id"] = node_id
 
         if limit is not None:
             query += " LIMIT :limit"
@@ -322,6 +337,12 @@ class SqliteState(State):  # pylint: disable=R0904
         # Run query
         rows = self.query(query, data)
 
+        for row in rows:
+            # Convert values from sint64 to uint64
+            convert_sint64_values_in_dict_to_uint64(
+                row, ["run_id", "producer_node_id", "consumer_node_id"]
+            )
+
         result = [dict_to_task_ins(row) for row in rows]
 
         return result
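The pattern above repeats throughout this file: IDs are mapped uint64 → sint64 in place just before an INSERT, and back sint64 → uint64 on every row a SELECT returns. A compact sketch of that write/read symmetry, with stand-in helpers mirroring the ones this release adds to flwr/server/superlink/state/utils.py:

```python
# Stand-ins for the convert_* helpers; same arithmetic as the diff below.
def to_sint64(u: int) -> int:
    return u - (1 << 64) if u >= (1 << 63) else u

def to_uint64(s: int) -> int:
    return s + (1 << 64) if s < 0 else s

row = {"run_id": 2**64 - 1, "producer_node_id": 7}

# Write path: make every ID fit SQLite's signed INTEGER column
for key in ("run_id", "producer_node_id", "consumer_node_id"):
    if key in row:
        row[key] = to_sint64(row[key])
assert row == {"run_id": -1, "producer_node_id": 7}  # same bits, new decimal value

# Read path: restore the original uint64 values in place
for key in ("run_id", "producer_node_id", "consumer_node_id"):
    if key in row:
        row[key] = to_uint64(row[key])
assert row == {"run_id": 2**64 - 1, "producer_node_id": 7}
```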
@@ -351,9 +372,26 @@ class SqliteState(State):  # pylint: disable=R0904
         # Create task_id
         task_id = uuid4()
 
-
+        task_ins_id = task_res.task.ancestry[0]
+        task_ins = self.get_valid_task_ins(task_ins_id)
+        if task_ins is None:
+            log(
+                ERROR,
+                "Failed to store TaskRes: "
+                "TaskIns with task_id %s does not exist or has expired.",
+                task_ins_id,
+            )
+            return None
+
+        # Store TaskRes
         task_res.task_id = str(task_id)
         data = (task_res_to_dict(task_res),)
+
+        # Convert values from uint64 to sint64 for SQLite
+        convert_uint64_values_in_dict_to_sint64(
+            data[0], ["run_id", "producer_node_id", "consumer_node_id"]
+        )
+
         columns = ", ".join([f":{key}" for key in data[0]])
         query = f"INSERT INTO task_res VALUES({columns});"
 
@@ -431,6 +469,12 @@ class SqliteState(State):  # pylint: disable=R0904
         # Run query
         rows = self.query(query, data)
 
+        for row in rows:
+            # Convert values from sint64 to uint64
+            convert_sint64_values_in_dict_to_uint64(
+                row, ["run_id", "producer_node_id", "consumer_node_id"]
+            )
+
         result = [dict_to_task_res(row) for row in rows]
 
         # 1. Query: Fetch consumer_node_id of remaining task_ids
@@ -474,6 +518,13 @@ class SqliteState(State):  # pylint: disable=R0904
         for row in task_ins_rows:
             if limit and len(result) == limit:
                 break
+
+            for row in rows:
+                # Convert values from sint64 to uint64
+                convert_sint64_values_in_dict_to_uint64(
+                    row, ["run_id", "producer_node_id", "consumer_node_id"]
+                )
+
             task_ins = dict_to_task_ins(row)
             err_taskres = make_node_unavailable_taskres(
                 ref_taskins=task_ins,
@@ -544,8 +595,11 @@ class SqliteState(State):  # pylint: disable=R0904
         self, ping_interval: float, public_key: Optional[bytes] = None
     ) -> int:
         """Create, store in state, and return `node_id`."""
-        # Sample a random
-
+        # Sample a random uint64 as node_id
+        uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
+
+        # Convert the uint64 value to sint64 for SQLite
+        sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
 
         query = "SELECT node_id FROM node WHERE public_key = :public_key;"
         row = self.query(query, {"public_key": public_key})
@@ -562,17 +616,28 @@ class SqliteState(State):  # pylint: disable=R0904
 
         try:
             self.query(
-                query,
+                query,
+                (
+                    sint64_node_id,
+                    time.time() + ping_interval,
+                    ping_interval,
+                    public_key,
+                ),
             )
         except sqlite3.IntegrityError:
             log(ERROR, "Unexpected node registration failure.")
             return 0
-
+
+        # Note: we need to return the uint64 value of the node_id
+        return uint64_node_id
 
     def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
         """Delete a node."""
+        # Convert the uint64 value to sint64 for SQLite
+        sint64_node_id = convert_uint64_to_sint64(node_id)
+
         query = "DELETE FROM node WHERE node_id = ?"
-        params = (
+        params = (sint64_node_id,)
 
         if public_key is not None:
             query += " AND public_key = ?"
@@ -597,15 +662,20 @@ class SqliteState(State):  # pylint: disable=R0904
         If the provided `run_id` does not exist or has no matching nodes,
         an empty `Set` MUST be returned.
         """
+        # Convert the uint64 value to sint64 for SQLite
+        sint64_run_id = convert_uint64_to_sint64(run_id)
+
         # Validate run ID
         query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
-        if self.query(query, (
+        if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
             return set()
 
         # Get nodes
         query = "SELECT node_id FROM node WHERE online_until > ?;"
         rows = self.query(query, (time.time(),))
-
+
+        # Convert sint64 node_ids to uint64
+        result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
         return result
 
     def get_node_id(self, node_public_key: bytes) -> Optional[int]:
@@ -614,7 +684,11 @@ class SqliteState(State):  # pylint: disable=R0904
         row = self.query(query, {"public_key": node_public_key})
         if len(row) > 0:
             node_id: int = row[0]["node_id"]
-
+
+            # Convert the sint64 value to uint64 after reading from SQLite
+            uint64_node_id = convert_sint64_to_uint64(node_id)
+
+            return uint64_node_id
         return None
 
     def create_run(
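create_node now samples the ID explicitly as a uint64. Assuming NODE_ID_NUM_BYTES is 8 (the constant lives elsewhere in the package; its value is an assumption here), the sampling boils down to reading random bytes as an unsigned little-endian integer:

```python
# Illustrative sketch of the ID sampling used by create_node.
from os import urandom

NODE_ID_NUM_BYTES = 8  # assumed value for this sketch

def generate_rand_int_from_bytes(num_bytes: int) -> int:
    # Same construction as the helper in state/utils.py
    return int.from_bytes(urandom(num_bytes), "little", signed=False)

node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
assert 0 <= node_id < 2**64  # always a valid uint64
```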
@@ -626,12 +700,15 @@ class SqliteState(State):  # pylint: disable=R0904
     ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
-
+        uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
+
+        # Convert the uint64 value to sint64 for SQLite
+        sint64_run_id = convert_uint64_to_sint64(uint64_run_id)
 
         # Check conflicts
         query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
-        # If
-        if self.query(query, (
+        # If sint64_run_id does not exist
+        if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
             query = (
                 "INSERT INTO run "
                 "(run_id, fab_id, fab_version, fab_hash, override_config)"
@@ -639,14 +716,22 @@ class SqliteState(State):  # pylint: disable=R0904
             )
             if fab_hash:
                 self.query(
-                    query,
+                    query,
+                    (sint64_run_id, "", "", fab_hash, json.dumps(override_config)),
                 )
             else:
                 self.query(
                     query,
-                    (
+                    (
+                        sint64_run_id,
+                        fab_id,
+                        fab_version,
+                        "",
+                        json.dumps(override_config),
+                    ),
                 )
-            return run_id
+            # Note: we need to return the uint64 value of the run_id
+            return uint64_run_id
         log(ERROR, "Unexpected run creation failure.")
         return 0
 
@@ -705,31 +790,64 @@ class SqliteState(State):  # pylint: disable=R0904
 
     def get_run(self, run_id: int) -> Optional[Run]:
         """Retrieve information about the run with the specified `run_id`."""
+        # Convert the uint64 value to sint64 for SQLite
+        sint64_run_id = convert_uint64_to_sint64(run_id)
         query = "SELECT * FROM run WHERE run_id = ?;"
-
-
+        rows = self.query(query, (sint64_run_id,))
+        if rows:
+            row = rows[0]
             return Run(
-                run_id=run_id,
+                run_id=convert_sint64_to_uint64(row["run_id"]),
                 fab_id=row["fab_id"],
                 fab_version=row["fab_version"],
                 fab_hash=row["fab_hash"],
                 override_config=json.loads(row["override_config"]),
             )
-
-
-        return None
+        log(ERROR, "`run_id` does not exist.")
+        return None
 
     def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
         """Acknowledge a ping received from a node, serving as a heartbeat."""
+        sint64_node_id = convert_uint64_to_sint64(node_id)
+
         # Update `online_until` and `ping_interval` for the given `node_id`
         query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
         try:
-            self.query(
+            self.query(
+                query, (time.time() + ping_interval, ping_interval, sint64_node_id)
+            )
             return True
         except sqlite3.IntegrityError:
             log(ERROR, "`node_id` does not exist.")
             return False
 
+    def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
+        """Check if the TaskIns exists and is valid (not expired).
+
+        Return TaskIns if valid.
+        """
+        query = """
+            SELECT *
+            FROM task_ins
+            WHERE task_id = :task_id
+        """
+        data = {"task_id": task_id}
+        rows = self.query(query, data)
+        if not rows:
+            # TaskIns does not exist
+            return None
+
+        task_ins = rows[0]
+        created_at = task_ins["created_at"]
+        ttl = task_ins["ttl"]
+        current_time = time.time()
+
+        # Check if TaskIns is expired
+        if ttl is not None and created_at + ttl <= current_time:
+            return None
+
+        return task_ins
+
 
 def dict_factory(
     cursor: sqlite3.Cursor,
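The new get_valid_task_ins helper treats a TaskIns as stale once created_at + ttl lies in the past; the validator change further below applies the same rule. A small sketch of that rule in isolation:

```python
# Sketch of the expiry rule used by get_valid_task_ins and the validator.
import time
from typing import Optional

def is_valid(created_at: float, ttl: float, now: Optional[float] = None) -> bool:
    # Valid while created_at + ttl still lies in the future
    if now is None:
        now = time.time()
    return created_at + ttl > now

now = time.time()
assert is_valid(created_at=now, ttl=3600.0, now=now)          # fresh for an hour
assert not is_valid(created_at=now - 10.0, ttl=5.0, now=now)  # expired 5s ago
```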
flwr/server/superlink/state/utils.py
CHANGED

@@ -33,8 +33,104 @@ NODE_UNAVAILABLE_ERROR_REASON = (
 
 
 def generate_rand_int_from_bytes(num_bytes: int) -> int:
-    """Generate a random `num_bytes`
-    return int.from_bytes(urandom(num_bytes), "little", signed=
+    """Generate a random unsigned integer from `num_bytes` bytes."""
+    return int.from_bytes(urandom(num_bytes), "little", signed=False)
+
+
+def convert_uint64_to_sint64(u: int) -> int:
+    """Convert a uint64 value to a sint64 value with the same bit sequence.
+
+    Parameters
+    ----------
+    u : int
+        The unsigned 64-bit integer to convert.
+
+    Returns
+    -------
+    int
+        The signed 64-bit integer equivalent.
+
+        The signed 64-bit integer will have the same bit pattern as the
+        unsigned 64-bit integer but may have a different decimal value.
+
+        For numbers within the range [0, `sint64` max value], the decimal
+        value remains the same. However, for numbers greater than the `sint64`
+        max value, the decimal value will differ due to the wraparound caused
+        by the sign bit.
+    """
+    if u >= (1 << 63):
+        return u - (1 << 64)
+    return u
+
+
+def convert_sint64_to_uint64(s: int) -> int:
+    """Convert a sint64 value to a uint64 value with the same bit sequence.
+
+    Parameters
+    ----------
+    s : int
+        The signed 64-bit integer to convert.
+
+    Returns
+    -------
+    int
+        The unsigned 64-bit integer equivalent.
+
+        The unsigned 64-bit integer will have the same bit pattern as the
+        signed 64-bit integer but may have a different decimal value.
+
+        For negative `sint64` values, the conversion adds 2^64 to the
+        signed value to obtain the equivalent `uint64` value. For non-negative
+        `sint64` values, the decimal value remains unchanged in the `uint64`
+        representation.
+    """
+    if s < 0:
+        return s + (1 << 64)
+    return s
+
+
+def convert_uint64_values_in_dict_to_sint64(
+    data_dict: dict[str, int], keys: list[str]
+) -> None:
+    """Convert uint64 values to sint64 in the given dictionary.
+
+    Parameters
+    ----------
+    data_dict : dict[str, int]
+        A dictionary where the values are integers to be converted.
+    keys : list[str]
+        A list of keys in the dictionary whose values need to be converted.
+
+    Returns
+    -------
+    None
+        This function does not return a value. It modifies `data_dict` in place.
+    """
+    for key in keys:
+        if key in data_dict:
+            data_dict[key] = convert_uint64_to_sint64(data_dict[key])
+
+
+def convert_sint64_values_in_dict_to_uint64(
+    data_dict: dict[str, int], keys: list[str]
+) -> None:
+    """Convert sint64 values to uint64 in the given dictionary.
+
+    Parameters
+    ----------
+    data_dict : dict[str, int]
+        A dictionary where the values are integers to be converted.
+    keys : list[str]
+        A list of keys in the dictionary whose values need to be converted.
+
+    Returns
+    -------
+    None
+        This function does not return a value. It modifies `data_dict` in place.
+    """
+    for key in keys:
+        if key in data_dict:
+            data_dict[key] = convert_sint64_to_uint64(data_dict[key])
 
 
 def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
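Since these converters promise "the same bit sequence", here is an illustrative check (not from the package) that reinterpreting the raw bytes with struct agrees with the arithmetic, and that the round trip is lossless:

```python
# Verify the bit-pattern claim with struct, using the same definitions as above.
import struct

def convert_uint64_to_sint64(u: int) -> int:
    return u - (1 << 64) if u >= (1 << 63) else u

def convert_sint64_to_uint64(s: int) -> int:
    return s + (1 << 64) if s < 0 else s

for u in (0, 1, 2**63 - 1, 2**63, 2**64 - 1):
    s = convert_uint64_to_sint64(u)
    # Reinterpreting the unsigned bytes as signed yields the same value ...
    assert struct.unpack("<q", struct.pack("<Q", u))[0] == s
    # ... and the conversion round-trips losslessly
    assert convert_sint64_to_uint64(s) == u
```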
flwr/server/utils/validator.py
CHANGED
@@ -15,6 +15,7 @@
 """Validators."""
 
 
+import time
 from typing import Union
 
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
@@ -47,6 +48,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str
         # unix timestamp of 27 March 2024 00h:00m:00s UTC
         validation_errors.append("`pushed_at` is not a recent timestamp")
 
+    # Verify TTL and created_at time
+    current_time = time.time()
+    if tasks_ins_res.task.created_at + tasks_ins_res.task.ttl <= current_time:
+        validation_errors.append("Task TTL has expired")
+
     # TaskIns specific
     if isinstance(tasks_ins_res, TaskIns):
         # Task producer
flwr/superexec/deployment.py
CHANGED
@@ -28,8 +28,8 @@ from flwr.common.grpc import create_channel
 from flwr.common.logger import log
 from flwr.common.serde import fab_to_proto, user_config_to_proto
 from flwr.common.typing import Fab, UserConfig
-from flwr.proto.driver_pb2 import CreateRunRequest  # pylint: disable=E0611
 from flwr.proto.driver_pb2_grpc import DriverStub
+from flwr.proto.run_pb2 import CreateRunRequest  # pylint: disable=E0611
 
 from .executor import Executor, RunTracker
 
@@ -167,6 +167,8 @@ class DeploymentEngine(Executor):
         # Execute the command
         proc = subprocess.Popen(  # pylint: disable=consider-using-with
             command,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
             text=True,
         )
         log(INFO, "Started run %s", str(run_id))
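With stdout and stderr redirected to pipes and text=True, the parent process can read the run's output as strings instead of letting it flow to the console; this is what enables the log capture added in exec_servicer.py below. A minimal sketch of the mechanism (not the package's code):

```python
# Capture a child process's output through pipes, as the diff above sets up.
import subprocess
import sys

proc = subprocess.Popen(
    [sys.executable, "-c", "print('hello from the run')"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,  # decode bytes to str automatically
)
out, _err = proc.communicate()
print(out.strip())  # -> hello from the run
```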
flwr/superexec/exec_servicer.py
CHANGED
@@ -15,6 +15,10 @@
 """SuperExec API servicer."""
 
 
+import select
+import sys
+import threading
+import time
 from collections.abc import Generator
 from logging import ERROR, INFO
 from typing import Any
@@ -33,6 +37,8 @@ from flwr.proto.exec_pb2 import (  # pylint: disable=E0611
 
 from .executor import Executor, RunTracker
 
+SELECT_TIMEOUT = 1  # Timeout for selecting ready-to-read file descriptors (in seconds)
+
 
 class ExecServicer(exec_pb2_grpc.ExecServicer):
     """SuperExec API servicer."""
@@ -59,13 +65,72 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
 
         self.runs[run.run_id] = run
 
+        # Start a background thread to capture the log output
+        capture_thread = threading.Thread(
+            target=_capture_logs, args=(run,), daemon=True
+        )
+        capture_thread.start()
+
         return StartRunResponse(run_id=run.run_id)
 
-    def StreamLogs(
+    def StreamLogs(  # pylint: disable=C0103
        self, request: StreamLogsRequest, context: grpc.ServicerContext
     ) -> Generator[StreamLogsResponse, Any, None]:
         """Get logs."""
-
+        log(INFO, "ExecServicer.StreamLogs")
+
+        # Exit if `run_id` not found
+        if request.run_id not in self.runs:
+            context.abort(grpc.StatusCode.NOT_FOUND, "Run ID not found")
+
+        last_sent_index = 0
         while context.is_active():
-
+            # Yield n'th row of logs, if n'th row < len(logs)
+            logs = self.runs[request.run_id].logs
+            for i in range(last_sent_index, len(logs)):
                 yield StreamLogsResponse(log_output=logs[i])
+            last_sent_index = len(logs)
+
+            # Wait for and continue to yield more log responses only if the
+            # run isn't completed yet. If the run is finished, the entire log
+            # is returned at this point and the server ends the stream.
+            if self.runs[request.run_id].proc.poll() is not None:
+                log(INFO, "All logs for run ID `%s` returned", request.run_id)
+                context.set_code(grpc.StatusCode.OK)
+                context.cancel()
+
+            time.sleep(1.0)  # Sleep briefly to avoid busy waiting
+
+
+def _capture_logs(
+    run: RunTracker,
+) -> None:
+    while True:
+        # Explicitly check if Popen.poll() is None. Required for `pytest`.
+        if run.proc.poll() is None:
+            # Select streams only when ready to read
+            ready_to_read, _, _ = select.select(
+                [run.proc.stdout, run.proc.stderr],
+                [],
+                [],
+                SELECT_TIMEOUT,
+            )
+            # Read from std* and append to RunTracker.logs
+            for stream in ready_to_read:
+                # Flush stdout to view output in real time
+                readline = stream.readline()
+                sys.stdout.write(readline)
+                sys.stdout.flush()
+                # Append to logs
+                line = readline.rstrip()
+                if line:
+                    run.logs.append(f"{line}")
+
+        # Close std* to prevent blocking
+        elif run.proc.poll() is not None:
+            log(INFO, "Subprocess finished, exiting log capture")
+            if run.proc.stdout:
+                run.proc.stdout.close()
+            if run.proc.stderr:
+                run.proc.stderr.close()
+            break
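A toy version of the _capture_logs loop (assuming a POSIX platform, where select.select() accepts pipe file objects): the selector waits until one of the child's pipes actually has data, so a quiet stream never blocks the reader.

```python
# Non-blocking line capture from a subprocess, mirroring _capture_logs above.
import select
import subprocess
import sys

SELECT_TIMEOUT = 1  # seconds to wait for a readable pipe

proc = subprocess.Popen(
    [sys.executable, "-c", "print('line 1'); print('line 2')"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
)

logs: list[str] = []
while proc.poll() is None:
    # Block until stdout/stderr has data or the timeout elapses
    ready, _, _ = select.select([proc.stdout, proc.stderr], [], [], SELECT_TIMEOUT)
    for stream in ready:
        line = stream.readline().rstrip()
        if line:
            logs.append(line)

print(logs)
```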
flwr/superexec/executor.py
CHANGED
@@ -15,7 +15,7 @@
 """Execute and monitor a Flower run."""
 
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from subprocess import Popen
 from typing import Optional
 
@@ -28,6 +28,7 @@ class RunTracker:
 
     run_id: int
     proc: Popen  # type: ignore
+    logs: list[str] = field(default_factory=list)
 
 
 class Executor(ABC):
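The field(default_factory=list) default matters here: a bare mutable default would be shared across every RunTracker instance, which is why dataclasses refuses it outright. A quick illustration:

```python
# Why the new `logs` field uses a factory instead of a literal [] default.
from dataclasses import dataclass, field

@dataclass
class Tracker:  # simplified stand-in for RunTracker
    run_id: int
    logs: list[str] = field(default_factory=list)  # `= []` would raise ValueError

a, b = Tracker(1), Tracker(2)
a.logs.append("only in a")
assert b.logs == []  # each instance gets its own fresh list
```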
{flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flwr-nightly
-Version: 1.12.0.dev20240916
+Version: 1.12.0.dev20241006
 Summary: Flower: A Friendly Federated Learning Framework
 Home-page: https://flower.ai
 License: Apache-2.0
@@ -43,7 +43,7 @@ Requires-Dist: requests (>=2.31.0,<3.0.0) ; extra == "rest"
 Requires-Dist: starlette (>=0.31.0,<0.32.0) ; extra == "rest"
 Requires-Dist: tomli (>=2.0.1,<3.0.0)
 Requires-Dist: tomli-w (>=1.0.0,<2.0.0)
-Requires-Dist: typer
+Requires-Dist: typer (>=0.12.5,<0.13.0)
 Requires-Dist: uvicorn[standard] (>=0.23.0,<0.24.0) ; extra == "rest"
 Project-URL: Documentation, https://flower.ai
 Project-URL: Repository, https://github.com/adap/flower
@@ -69,6 +69,7 @@ Description-Content-Type: text/markdown
 [](https://github.com/adap/flower/blob/main/CONTRIBUTING.md)
 
 [](https://pepy.tech/project/flwr)
+[](https://hub.docker.com/u/flwr)
 [](https://flower.ai/join-slack)
 
 Flower (`flwr`) is a framework for building federated learning systems. The
@@ -152,6 +153,7 @@ Flower Baselines is a collection of community-contributed projects that reproduc
 - [FedNova](https://github.com/adap/flower/tree/main/baselines/fednova)
 - [HeteroFL](https://github.com/adap/flower/tree/main/baselines/heterofl)
 - [FedAvgM](https://github.com/adap/flower/tree/main/baselines/fedavgm)
+- [FedRep](https://github.com/adap/flower/tree/main/baselines/fedrep)
 - [FedStar](https://github.com/adap/flower/tree/main/baselines/fedstar)
 - [FedWav2vec2](https://github.com/adap/flower/tree/main/baselines/fedwav2vec2)
 - [FjORD](https://github.com/adap/flower/tree/main/baselines/fjord)