flwr-nightly 1.13.0.dev20241023__py3-none-any.whl → 1.13.0.dev20241025__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/proto/driver_pb2.py +24 -15
- flwr/proto/driver_pb2.pyi +59 -0
- flwr/proto/driver_pb2_grpc.py +68 -0
- flwr/proto/driver_pb2_grpc.pyi +26 -0
- flwr/server/app.py +4 -2
- flwr/server/run_serverapp.py +13 -9
- flwr/server/superlink/driver/driver_servicer.py +65 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +12 -1
- flwr/server/superlink/linkstate/linkstate.py +29 -0
- flwr/server/superlink/linkstate/sqlite_linkstate.py +51 -6
- flwr/server/superlink/linkstate/utils.py +12 -1
- flwr/simulation/run_simulation.py +12 -4
- flwr/superexec/app.py +3 -138
- flwr/superexec/deployment.py +34 -25
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +11 -1
- flwr/superexec/executor.py +19 -0
- flwr/superexec/simulation.py +8 -0
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/RECORD +24 -25
- flwr/client/node_state_tests.py +0 -65
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/entry_points.txt +0 -0
|
@@ -19,13 +19,14 @@
|
|
|
19
19
|
import json
|
|
20
20
|
import re
|
|
21
21
|
import sqlite3
|
|
22
|
+
import threading
|
|
22
23
|
import time
|
|
23
24
|
from collections.abc import Sequence
|
|
24
25
|
from logging import DEBUG, ERROR, WARNING
|
|
25
26
|
from typing import Any, Optional, Union, cast
|
|
26
27
|
from uuid import UUID, uuid4
|
|
27
28
|
|
|
28
|
-
from flwr.common import log, now
|
|
29
|
+
from flwr.common import Context, log, now
|
|
29
30
|
from flwr.common.constant import (
|
|
30
31
|
MESSAGE_TTL_TOLERANCE,
|
|
31
32
|
NODE_ID_NUM_BYTES,
|
|
@@ -33,13 +34,19 @@ from flwr.common.constant import (
|
|
|
33
34
|
Status,
|
|
34
35
|
)
|
|
35
36
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
from flwr.proto.
|
|
37
|
+
|
|
38
|
+
# pylint: disable=E0611
|
|
39
|
+
from flwr.proto.node_pb2 import Node
|
|
40
|
+
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
41
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
42
|
+
|
|
43
|
+
# pylint: enable=E0611
|
|
39
44
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
40
45
|
|
|
41
46
|
from .linkstate import LinkState
|
|
42
47
|
from .utils import (
|
|
48
|
+
context_from_bytes,
|
|
49
|
+
context_to_bytes,
|
|
43
50
|
convert_sint64_to_uint64,
|
|
44
51
|
convert_sint64_values_in_dict_to_uint64,
|
|
45
52
|
convert_uint64_to_sint64,
|
|
@@ -92,6 +99,14 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
92
99
|
);
|
|
93
100
|
"""
|
|
94
101
|
|
|
102
|
+
SQL_CREATE_TABLE_CONTEXT = """
|
|
103
|
+
CREATE TABLE IF NOT EXISTS context(
|
|
104
|
+
run_id INTEGER UNIQUE,
|
|
105
|
+
context BLOB,
|
|
106
|
+
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
107
|
+
);
|
|
108
|
+
"""
|
|
109
|
+
|
|
95
110
|
SQL_CREATE_TABLE_TASK_INS = """
|
|
96
111
|
CREATE TABLE IF NOT EXISTS task_ins(
|
|
97
112
|
task_id TEXT UNIQUE,
|
|
@@ -152,6 +167,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
152
167
|
"""
|
|
153
168
|
self.database_path = database_path
|
|
154
169
|
self.conn: Optional[sqlite3.Connection] = None
|
|
170
|
+
self.lock = threading.RLock()
|
|
155
171
|
|
|
156
172
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
157
173
|
"""Create tables if they don't exist yet.
|
|
@@ -175,6 +191,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
175
191
|
|
|
176
192
|
# Create each table if not exists queries
|
|
177
193
|
cur.execute(SQL_CREATE_TABLE_RUN)
|
|
194
|
+
cur.execute(SQL_CREATE_TABLE_CONTEXT)
|
|
178
195
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
179
196
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
180
197
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
@@ -970,6 +987,34 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
970
987
|
log(ERROR, "`node_id` does not exist.")
|
|
971
988
|
return False
|
|
972
989
|
|
|
990
|
+
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
991
|
+
"""Get the context for the specified `run_id`."""
|
|
992
|
+
# Retrieve context if any
|
|
993
|
+
query = "SELECT context FROM context WHERE run_id = ?;"
|
|
994
|
+
rows = self.query(query, (convert_uint64_to_sint64(run_id),))
|
|
995
|
+
context = context_from_bytes(rows[0]["context"]) if rows else None
|
|
996
|
+
return context
|
|
997
|
+
|
|
998
|
+
def set_serverapp_context(self, run_id: int, context: Context) -> None:
|
|
999
|
+
"""Set the context for the specified `run_id`."""
|
|
1000
|
+
# Convert context to bytes
|
|
1001
|
+
context_bytes = context_to_bytes(context)
|
|
1002
|
+
sint_run_id = convert_uint64_to_sint64(run_id)
|
|
1003
|
+
|
|
1004
|
+
# Check if any existing Context assigned to the run_id
|
|
1005
|
+
query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
|
|
1006
|
+
if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0:
|
|
1007
|
+
# Update context
|
|
1008
|
+
query = "UPDATE context SET context = ? WHERE run_id = ?;"
|
|
1009
|
+
self.query(query, (context_bytes, sint_run_id))
|
|
1010
|
+
else:
|
|
1011
|
+
try:
|
|
1012
|
+
# Store context
|
|
1013
|
+
query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
|
|
1014
|
+
self.query(query, (sint_run_id, context_bytes))
|
|
1015
|
+
except sqlite3.IntegrityError:
|
|
1016
|
+
raise ValueError(f"Run {run_id} not found") from None
|
|
1017
|
+
|
|
973
1018
|
def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
|
|
974
1019
|
"""Check if the TaskIns exists and is valid (not expired).
|
|
975
1020
|
|
|
@@ -1054,7 +1099,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
|
1054
1099
|
|
|
1055
1100
|
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
1056
1101
|
"""Turn task_dict into protobuf message."""
|
|
1057
|
-
recordset =
|
|
1102
|
+
recordset = ProtoRecordSet()
|
|
1058
1103
|
recordset.ParseFromString(task_dict["recordset"])
|
|
1059
1104
|
|
|
1060
1105
|
result = TaskIns(
|
|
@@ -1084,7 +1129,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
|
1084
1129
|
|
|
1085
1130
|
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
1086
1131
|
"""Turn task_dict into protobuf message."""
|
|
1087
|
-
recordset =
|
|
1132
|
+
recordset = ProtoRecordSet()
|
|
1088
1133
|
recordset.ParseFromString(task_dict["recordset"])
|
|
1089
1134
|
|
|
1090
1135
|
result = TaskRes(
|
|
@@ -20,10 +20,11 @@ from logging import ERROR
|
|
|
20
20
|
from os import urandom
|
|
21
21
|
from uuid import uuid4
|
|
22
22
|
|
|
23
|
-
from flwr.common import log
|
|
23
|
+
from flwr.common import Context, log, serde
|
|
24
24
|
from flwr.common.constant import ErrorCode, Status, SubStatus
|
|
25
25
|
from flwr.common.typing import RunStatus
|
|
26
26
|
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
|
|
27
|
+
from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611
|
|
27
28
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
28
29
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
29
30
|
|
|
@@ -135,6 +136,16 @@ def convert_sint64_values_in_dict_to_uint64(
|
|
|
135
136
|
data_dict[key] = convert_sint64_to_uint64(data_dict[key])
|
|
136
137
|
|
|
137
138
|
|
|
139
|
+
def context_to_bytes(context: Context) -> bytes:
|
|
140
|
+
"""Serialize `Context` to bytes."""
|
|
141
|
+
return serde.context_to_proto(context).SerializeToString()
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def context_from_bytes(context_bytes: bytes) -> Context:
|
|
145
|
+
"""Deserialize `Context` from bytes."""
|
|
146
|
+
return serde.context_from_proto(ProtoContext.FromString(context_bytes))
|
|
147
|
+
|
|
148
|
+
|
|
138
149
|
def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
139
150
|
"""Generate a TaskRes with a node unavailable error from a TaskIns."""
|
|
140
151
|
current_time = time.time()
|
|
@@ -29,7 +29,7 @@ from typing import Any, Optional
|
|
|
29
29
|
|
|
30
30
|
from flwr.cli.config_utils import load_and_validate
|
|
31
31
|
from flwr.client import ClientApp
|
|
32
|
-
from flwr.common import EventType, event, log, now
|
|
32
|
+
from flwr.common import Context, EventType, RecordSet, event, log, now
|
|
33
33
|
from flwr.common.config import get_fused_config_from_dir, parse_config_args
|
|
34
34
|
from flwr.common.constant import RUN_ID_NUM_BYTES, Status
|
|
35
35
|
from flwr.common.logger import (
|
|
@@ -40,7 +40,7 @@ from flwr.common.logger import (
|
|
|
40
40
|
)
|
|
41
41
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
42
42
|
from flwr.server.driver import Driver, InMemoryDriver
|
|
43
|
-
from flwr.server.run_serverapp import run as
|
|
43
|
+
from flwr.server.run_serverapp import run as _run
|
|
44
44
|
from flwr.server.server_app import ServerApp
|
|
45
45
|
from flwr.server.superlink.fleet import vce
|
|
46
46
|
from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
|
|
@@ -333,11 +333,19 @@ def run_serverapp_th(
|
|
|
333
333
|
log(INFO, "Enabling GPU growth for Tensorflow on the server thread.")
|
|
334
334
|
enable_gpu_growth()
|
|
335
335
|
|
|
336
|
+
# Initialize Context
|
|
337
|
+
context = Context(
|
|
338
|
+
node_id=0,
|
|
339
|
+
node_config={},
|
|
340
|
+
state=RecordSet(),
|
|
341
|
+
run_config=_server_app_run_config,
|
|
342
|
+
)
|
|
343
|
+
|
|
336
344
|
# Run ServerApp
|
|
337
|
-
|
|
345
|
+
_run(
|
|
338
346
|
driver=_driver,
|
|
347
|
+
context=context,
|
|
339
348
|
server_app_dir=_server_app_dir,
|
|
340
|
-
server_app_run_config=_server_app_run_config,
|
|
341
349
|
server_app_attr=_server_app_attr,
|
|
342
350
|
loaded_server_app=_server_app,
|
|
343
351
|
)
|
flwr/superexec/app.py
CHANGED
|
@@ -16,21 +16,11 @@
|
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
18
|
import sys
|
|
19
|
-
from logging import INFO
|
|
20
|
-
from pathlib import Path
|
|
21
|
-
from typing import Optional
|
|
19
|
+
from logging import INFO
|
|
22
20
|
|
|
23
|
-
import
|
|
24
|
-
|
|
25
|
-
from flwr.common import EventType, event, log
|
|
26
|
-
from flwr.common.address import parse_address
|
|
27
|
-
from flwr.common.config import parse_config_args
|
|
28
|
-
from flwr.common.constant import EXEC_API_DEFAULT_ADDRESS
|
|
29
|
-
from flwr.common.exit_handlers import register_exit_handlers
|
|
30
|
-
from flwr.common.logger import warn_deprecated_feature
|
|
21
|
+
from flwr.common import log
|
|
31
22
|
from flwr.common.object_ref import load_app, validate
|
|
32
23
|
|
|
33
|
-
from .exec_grpc import run_superexec_api_grpc
|
|
34
24
|
from .executor import Executor
|
|
35
25
|
|
|
36
26
|
|
|
@@ -38,137 +28,12 @@ def run_superexec() -> None:
|
|
|
38
28
|
"""Run Flower SuperExec."""
|
|
39
29
|
log(INFO, "Starting Flower SuperExec")
|
|
40
30
|
|
|
41
|
-
|
|
31
|
+
sys.exit(
|
|
42
32
|
"Manually launching the SuperExec is deprecated. Since `flwr 1.13.0` "
|
|
43
33
|
"the executor service runs in the SuperLink. Launching it manually is not "
|
|
44
34
|
"recommended."
|
|
45
35
|
)
|
|
46
36
|
|
|
47
|
-
event(EventType.RUN_SUPEREXEC_ENTER)
|
|
48
|
-
|
|
49
|
-
args = _parse_args_run_superexec().parse_args()
|
|
50
|
-
|
|
51
|
-
# Parse IP address
|
|
52
|
-
parsed_address = parse_address(args.address)
|
|
53
|
-
if not parsed_address:
|
|
54
|
-
sys.exit(f"SuperExec IP address ({args.address}) cannot be parsed.")
|
|
55
|
-
host, port, is_v6 = parsed_address
|
|
56
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
57
|
-
|
|
58
|
-
# Obtain certificates
|
|
59
|
-
certificates = _try_obtain_certificates(args)
|
|
60
|
-
|
|
61
|
-
# Start SuperExec API
|
|
62
|
-
superexec_server: grpc.Server = run_superexec_api_grpc(
|
|
63
|
-
address=address,
|
|
64
|
-
executor=load_executor(args),
|
|
65
|
-
certificates=certificates,
|
|
66
|
-
config=parse_config_args(
|
|
67
|
-
[args.executor_config] if args.executor_config else args.executor_config
|
|
68
|
-
),
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
grpc_servers = [superexec_server]
|
|
72
|
-
|
|
73
|
-
# Graceful shutdown
|
|
74
|
-
register_exit_handlers(
|
|
75
|
-
event_type=EventType.RUN_SUPEREXEC_LEAVE,
|
|
76
|
-
grpc_servers=grpc_servers,
|
|
77
|
-
bckg_threads=None,
|
|
78
|
-
)
|
|
79
|
-
|
|
80
|
-
superexec_server.wait_for_termination()
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def _parse_args_run_superexec() -> argparse.ArgumentParser:
|
|
84
|
-
"""Parse command line arguments for SuperExec."""
|
|
85
|
-
parser = argparse.ArgumentParser(
|
|
86
|
-
description="Start a Flower SuperExec",
|
|
87
|
-
)
|
|
88
|
-
parser.add_argument(
|
|
89
|
-
"--address",
|
|
90
|
-
help="SuperExec (gRPC) server address (IPv4, IPv6, or a domain name)",
|
|
91
|
-
default=EXEC_API_DEFAULT_ADDRESS,
|
|
92
|
-
)
|
|
93
|
-
parser.add_argument(
|
|
94
|
-
"--executor",
|
|
95
|
-
help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.",
|
|
96
|
-
default="flwr.superexec.deployment:executor",
|
|
97
|
-
)
|
|
98
|
-
parser.add_argument(
|
|
99
|
-
"--executor-dir",
|
|
100
|
-
help="The directory for the executor.",
|
|
101
|
-
default=".",
|
|
102
|
-
)
|
|
103
|
-
parser.add_argument(
|
|
104
|
-
"--executor-config",
|
|
105
|
-
help="Key-value pairs for the executor config, separated by spaces. "
|
|
106
|
-
'For example:\n\n`--executor-config \'superlink="superlink:9091" '
|
|
107
|
-
'root-certificates="certificates/superlink-ca.crt"\'`',
|
|
108
|
-
)
|
|
109
|
-
parser.add_argument(
|
|
110
|
-
"--insecure",
|
|
111
|
-
action="store_true",
|
|
112
|
-
help="Run the SuperExec without HTTPS, regardless of whether certificate "
|
|
113
|
-
"paths are provided. By default, the server runs with HTTPS enabled. "
|
|
114
|
-
"Use this flag only if you understand the risks.",
|
|
115
|
-
)
|
|
116
|
-
parser.add_argument(
|
|
117
|
-
"--ssl-certfile",
|
|
118
|
-
help="SuperExec server SSL certificate file (as a path str) "
|
|
119
|
-
"to create a secure connection.",
|
|
120
|
-
type=str,
|
|
121
|
-
default=None,
|
|
122
|
-
)
|
|
123
|
-
parser.add_argument(
|
|
124
|
-
"--ssl-keyfile",
|
|
125
|
-
help="SuperExec server SSL private key file (as a path str) "
|
|
126
|
-
"to create a secure connection.",
|
|
127
|
-
type=str,
|
|
128
|
-
)
|
|
129
|
-
parser.add_argument(
|
|
130
|
-
"--ssl-ca-certfile",
|
|
131
|
-
help="SuperExec server SSL CA certificate file (as a path str) "
|
|
132
|
-
"to create a secure connection.",
|
|
133
|
-
type=str,
|
|
134
|
-
)
|
|
135
|
-
return parser
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
def _try_obtain_certificates(
|
|
139
|
-
args: argparse.Namespace,
|
|
140
|
-
) -> Optional[tuple[bytes, bytes, bytes]]:
|
|
141
|
-
# Obtain certificates
|
|
142
|
-
if args.insecure:
|
|
143
|
-
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
|
|
144
|
-
return None
|
|
145
|
-
# Check if certificates are provided
|
|
146
|
-
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
|
|
147
|
-
if not Path(args.ssl_ca_certfile).is_file():
|
|
148
|
-
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
|
|
149
|
-
if not Path(args.ssl_certfile).is_file():
|
|
150
|
-
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
|
|
151
|
-
if not Path(args.ssl_keyfile).is_file():
|
|
152
|
-
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
|
|
153
|
-
certificates = (
|
|
154
|
-
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate
|
|
155
|
-
Path(args.ssl_certfile).read_bytes(), # server certificate
|
|
156
|
-
Path(args.ssl_keyfile).read_bytes(), # server private key
|
|
157
|
-
)
|
|
158
|
-
return certificates
|
|
159
|
-
if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
|
|
160
|
-
sys.exit(
|
|
161
|
-
"You need to provide valid file paths to `--ssl-certfile`, "
|
|
162
|
-
"`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure "
|
|
163
|
-
"connection in SuperExec server (gRPC-rere)."
|
|
164
|
-
)
|
|
165
|
-
sys.exit(
|
|
166
|
-
"Certificates are required unless running in insecure mode. "
|
|
167
|
-
"Please provide certificate paths to `--ssl-certfile`, "
|
|
168
|
-
"`--ssl-keyfile`, and `—-ssl-ca-certfile` or run the server "
|
|
169
|
-
"in insecure mode using '--insecure' if you understand the risks."
|
|
170
|
-
)
|
|
171
|
-
|
|
172
37
|
|
|
173
38
|
def load_executor(
|
|
174
39
|
args: argparse.Namespace,
|
flwr/superexec/deployment.py
CHANGED
|
@@ -24,12 +24,11 @@ from typing_extensions import override
|
|
|
24
24
|
|
|
25
25
|
from flwr.cli.install import install_from_fab
|
|
26
26
|
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
|
|
27
|
-
from flwr.common.grpc import create_channel
|
|
28
27
|
from flwr.common.logger import log
|
|
29
|
-
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
30
28
|
from flwr.common.typing import Fab, UserConfig
|
|
31
|
-
from flwr.
|
|
32
|
-
from flwr.
|
|
29
|
+
from flwr.server.superlink.ffs import Ffs
|
|
30
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
31
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
33
32
|
|
|
34
33
|
from .executor import Executor, RunTracker
|
|
35
34
|
|
|
@@ -62,7 +61,30 @@ class DeploymentEngine(Executor):
|
|
|
62
61
|
self.root_certificates = root_certificates
|
|
63
62
|
self.root_certificates_bytes = Path(root_certificates).read_bytes()
|
|
64
63
|
self.flwr_dir = flwr_dir
|
|
65
|
-
self.
|
|
64
|
+
self.linkstate_factory: Optional[LinkStateFactory] = None
|
|
65
|
+
self.ffs_factory: Optional[FfsFactory] = None
|
|
66
|
+
|
|
67
|
+
@override
|
|
68
|
+
def initialize(
|
|
69
|
+
self, linkstate_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
70
|
+
) -> None:
|
|
71
|
+
"""Initialize the executor with the necessary factories."""
|
|
72
|
+
self.linkstate_factory = linkstate_factory
|
|
73
|
+
self.ffs_factory = ffs_factory
|
|
74
|
+
|
|
75
|
+
@property
|
|
76
|
+
def linkstate(self) -> LinkState:
|
|
77
|
+
"""Return the LinkState."""
|
|
78
|
+
if self.linkstate_factory is None:
|
|
79
|
+
raise RuntimeError("Executor is not initialized.")
|
|
80
|
+
return self.linkstate_factory.state()
|
|
81
|
+
|
|
82
|
+
@property
|
|
83
|
+
def ffs(self) -> Ffs:
|
|
84
|
+
"""Return the Flower File Storage (FFS)."""
|
|
85
|
+
if self.ffs_factory is None:
|
|
86
|
+
raise RuntimeError("Executor is not initialized.")
|
|
87
|
+
return self.ffs_factory.ffs()
|
|
66
88
|
|
|
67
89
|
@override
|
|
68
90
|
def set_config(
|
|
@@ -101,32 +123,19 @@ class DeploymentEngine(Executor):
|
|
|
101
123
|
raise ValueError("The `flwr-dir` value should be of type `str`.")
|
|
102
124
|
self.flwr_dir = str(flwr_dir)
|
|
103
125
|
|
|
104
|
-
def _connect(self) -> None:
|
|
105
|
-
if self.stub is not None:
|
|
106
|
-
return
|
|
107
|
-
channel = create_channel(
|
|
108
|
-
server_address=self.superlink,
|
|
109
|
-
insecure=(self.root_certificates_bytes is None),
|
|
110
|
-
root_certificates=self.root_certificates_bytes,
|
|
111
|
-
)
|
|
112
|
-
self.stub = DriverStub(channel)
|
|
113
|
-
|
|
114
126
|
def _create_run(
|
|
115
127
|
self,
|
|
116
128
|
fab: Fab,
|
|
117
129
|
override_config: UserConfig,
|
|
118
130
|
) -> int:
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
131
|
+
fab_hash = self.ffs.put(fab.content, {})
|
|
132
|
+
if fab_hash != fab.hash_str:
|
|
133
|
+
raise RuntimeError(
|
|
134
|
+
f"FAB ({fab.hash_str}) hash from request doesn't match contents"
|
|
135
|
+
)
|
|
123
136
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
override_config=user_config_to_proto(override_config),
|
|
127
|
-
)
|
|
128
|
-
res = self.stub.CreateRun(request=req)
|
|
129
|
-
return int(res.run_id)
|
|
137
|
+
run_id = self.linkstate.create_run(None, None, fab_hash, override_config)
|
|
138
|
+
return run_id
|
|
130
139
|
|
|
131
140
|
@override
|
|
132
141
|
def start_run(
|
flwr/superexec/exec_grpc.py
CHANGED
|
@@ -23,33 +23,40 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
|
23
23
|
from flwr.common.logger import log
|
|
24
24
|
from flwr.common.typing import UserConfig
|
|
25
25
|
from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
|
|
26
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
26
27
|
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
28
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
27
29
|
|
|
28
30
|
from .exec_servicer import ExecServicer
|
|
29
31
|
from .executor import Executor
|
|
30
32
|
|
|
31
33
|
|
|
32
|
-
|
|
34
|
+
# pylint: disable-next=too-many-arguments, too-many-positional-arguments
|
|
35
|
+
def run_exec_api_grpc(
|
|
33
36
|
address: str,
|
|
34
37
|
executor: Executor,
|
|
38
|
+
state_factory: LinkStateFactory,
|
|
39
|
+
ffs_factory: FfsFactory,
|
|
35
40
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
36
41
|
config: UserConfig,
|
|
37
42
|
) -> grpc.Server:
|
|
38
|
-
"""Run
|
|
43
|
+
"""Run Exec API (gRPC, request-response)."""
|
|
39
44
|
executor.set_config(config)
|
|
40
45
|
|
|
41
46
|
exec_servicer: grpc.Server = ExecServicer(
|
|
47
|
+
linkstate_factory=state_factory,
|
|
48
|
+
ffs_factory=ffs_factory,
|
|
42
49
|
executor=executor,
|
|
43
50
|
)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
servicer_and_add_fn=(exec_servicer,
|
|
51
|
+
exec_add_servicer_to_server_fn = add_ExecServicer_to_server
|
|
52
|
+
exec_grpc_server = generic_create_grpc_server(
|
|
53
|
+
servicer_and_add_fn=(exec_servicer, exec_add_servicer_to_server_fn),
|
|
47
54
|
server_address=address,
|
|
48
55
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
49
56
|
certificates=certificates,
|
|
50
57
|
)
|
|
51
58
|
|
|
52
|
-
log(INFO, "
|
|
53
|
-
|
|
59
|
+
log(INFO, "Flower Deployment Engine: Starting Exec API on %s", address)
|
|
60
|
+
exec_grpc_server.start()
|
|
54
61
|
|
|
55
|
-
return
|
|
62
|
+
return exec_grpc_server
|
flwr/superexec/exec_servicer.py
CHANGED
|
@@ -34,6 +34,8 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
|
|
|
34
34
|
StreamLogsRequest,
|
|
35
35
|
StreamLogsResponse,
|
|
36
36
|
)
|
|
37
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
38
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
37
39
|
|
|
38
40
|
from .executor import Executor, RunTracker
|
|
39
41
|
|
|
@@ -43,8 +45,16 @@ SELECT_TIMEOUT = 1 # Timeout for selecting ready-to-read file descriptors (in s
|
|
|
43
45
|
class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
44
46
|
"""SuperExec API servicer."""
|
|
45
47
|
|
|
46
|
-
def __init__(
|
|
48
|
+
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
linkstate_factory: LinkStateFactory,
|
|
51
|
+
ffs_factory: FfsFactory,
|
|
52
|
+
executor: Executor,
|
|
53
|
+
) -> None:
|
|
54
|
+
self.linkstate_factory = linkstate_factory
|
|
55
|
+
self.ffs_factory = ffs_factory
|
|
47
56
|
self.executor = executor
|
|
57
|
+
self.executor.initialize(linkstate_factory, ffs_factory)
|
|
48
58
|
self.runs: dict[int, RunTracker] = {}
|
|
49
59
|
|
|
50
60
|
def StartRun(
|
flwr/superexec/executor.py
CHANGED
|
@@ -20,6 +20,8 @@ from subprocess import Popen
|
|
|
20
20
|
from typing import Optional
|
|
21
21
|
|
|
22
22
|
from flwr.common.typing import UserConfig
|
|
23
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
24
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
23
25
|
|
|
24
26
|
|
|
25
27
|
@dataclass
|
|
@@ -34,6 +36,23 @@ class RunTracker:
|
|
|
34
36
|
class Executor(ABC):
|
|
35
37
|
"""Execute and monitor a Flower run."""
|
|
36
38
|
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def initialize(
|
|
41
|
+
self, linkstate_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
42
|
+
) -> None:
|
|
43
|
+
"""Initialize the executor with the necessary factories.
|
|
44
|
+
|
|
45
|
+
This method sets up the executor by providing it with the factories required
|
|
46
|
+
to access the LinkState and the Flower File Storage (FFS) in the SuperLink.
|
|
47
|
+
|
|
48
|
+
Parameters
|
|
49
|
+
----------
|
|
50
|
+
linkstate_factory : LinkStateFactory
|
|
51
|
+
The factory to create access to the LinkState.
|
|
52
|
+
ffs_factory : FfsFactory
|
|
53
|
+
The factory to create access to the Flower File Storage (FFS).
|
|
54
|
+
"""
|
|
55
|
+
|
|
37
56
|
@abstractmethod
|
|
38
57
|
def set_config(
|
|
39
58
|
self,
|
flwr/superexec/simulation.py
CHANGED
|
@@ -29,6 +29,8 @@ from flwr.common.config import unflatten_dict
|
|
|
29
29
|
from flwr.common.constant import RUN_ID_NUM_BYTES
|
|
30
30
|
from flwr.common.logger import log
|
|
31
31
|
from flwr.common.typing import UserConfig
|
|
32
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
33
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
32
34
|
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
|
|
33
35
|
|
|
34
36
|
from .executor import Executor, RunTracker
|
|
@@ -70,6 +72,12 @@ class SimulationEngine(Executor):
|
|
|
70
72
|
self.num_supernodes = num_supernodes
|
|
71
73
|
self.verbose = verbose
|
|
72
74
|
|
|
75
|
+
@override
|
|
76
|
+
def initialize(
|
|
77
|
+
self, linkstate_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
78
|
+
) -> None:
|
|
79
|
+
"""Initialize the executor with the necessary factories."""
|
|
80
|
+
|
|
73
81
|
@override
|
|
74
82
|
def set_config(
|
|
75
83
|
self,
|