flwr-nightly 1.11.0.dev20240823__py3-none-any.whl → 1.11.1.dev20240912__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +0 -2
- flwr/cli/new/new.py +41 -40
- flwr/cli/new/templates/app/LICENSE.tpl +202 -0
- flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
- flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +2 -2
- flwr/client/__init__.py +0 -4
- flwr/client/app.py +3 -4
- flwr/client/client_app.py +2 -2
- flwr/client/grpc_rere_client/client_interceptor.py +15 -7
- flwr/client/supernode/app.py +8 -7
- flwr/common/config.py +14 -11
- flwr/common/constant.py +12 -1
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +24 -1
- flwr/common/telemetry.py +36 -30
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +21 -22
- flwr/server/compat/app.py +0 -5
- flwr/server/driver/grpc_driver.py +3 -6
- flwr/server/run_serverapp.py +20 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +13 -12
- flwr/server/superlink/fleet/vce/backend/raybackend.py +21 -12
- flwr/server/superlink/state/in_memory_state.py +15 -15
- flwr/server/superlink/state/sqlite_state.py +10 -10
- flwr/server/superlink/state/state.py +8 -8
- flwr/simulation/ray_transport/ray_actor.py +2 -2
- flwr/simulation/run_simulation.py +37 -8
- flwr/superexec/__init__.py +0 -6
- flwr/superexec/app.py +5 -3
- flwr/superexec/deployment.py +2 -2
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/METADATA +3 -3
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/RECORD +56 -48
- flwr_nightly-1.11.1.dev20240912.dist-info/entry_points.txt +10 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
- flwr_nightly-1.11.0.dev20240823.dist-info/entry_points.txt +0 -10
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/WHEEL +0 -0
flwr/common/telemetry.py
CHANGED
|
@@ -132,53 +132,59 @@ class EventType(str, Enum):
|
|
|
132
132
|
# Ping
|
|
133
133
|
PING = auto()
|
|
134
134
|
|
|
135
|
-
#
|
|
135
|
+
# --- LEGACY FUNCTIONS -------------------------------------------------------------
|
|
136
|
+
|
|
137
|
+
# Legacy: `start_client` function
|
|
136
138
|
START_CLIENT_ENTER = auto()
|
|
137
139
|
START_CLIENT_LEAVE = auto()
|
|
138
140
|
|
|
139
|
-
#
|
|
141
|
+
# Legacy: `start_server` function
|
|
140
142
|
START_SERVER_ENTER = auto()
|
|
141
143
|
START_SERVER_LEAVE = auto()
|
|
142
144
|
|
|
143
|
-
#
|
|
144
|
-
|
|
145
|
-
|
|
145
|
+
# Legacy: `start_simulation` function
|
|
146
|
+
START_SIMULATION_ENTER = auto()
|
|
147
|
+
START_SIMULATION_LEAVE = auto()
|
|
146
148
|
|
|
147
|
-
#
|
|
148
|
-
RUN_FLEET_API_ENTER = auto()
|
|
149
|
-
RUN_FLEET_API_LEAVE = auto()
|
|
149
|
+
# --- `flwr` CLI -------------------------------------------------------------------
|
|
150
150
|
|
|
151
|
-
#
|
|
152
|
-
RUN_SUPERLINK_ENTER = auto()
|
|
153
|
-
RUN_SUPERLINK_LEAVE = auto()
|
|
151
|
+
# Not yet implemented
|
|
154
152
|
|
|
155
|
-
#
|
|
156
|
-
START_SIMULATION_ENTER = auto()
|
|
157
|
-
START_SIMULATION_LEAVE = auto()
|
|
153
|
+
# --- SuperExec --------------------------------------------------------------------
|
|
158
154
|
|
|
159
|
-
#
|
|
160
|
-
|
|
161
|
-
|
|
155
|
+
# SuperExec
|
|
156
|
+
RUN_SUPEREXEC_ENTER = auto()
|
|
157
|
+
RUN_SUPEREXEC_LEAVE = auto()
|
|
162
158
|
|
|
163
|
-
#
|
|
164
|
-
START_DRIVER_ENTER = auto()
|
|
165
|
-
START_DRIVER_LEAVE = auto()
|
|
159
|
+
# --- Simulation Engine ------------------------------------------------------------
|
|
166
160
|
|
|
167
|
-
# flower-
|
|
168
|
-
|
|
169
|
-
|
|
161
|
+
# CLI: flower-simulation
|
|
162
|
+
CLI_FLOWER_SIMULATION_ENTER = auto()
|
|
163
|
+
CLI_FLOWER_SIMULATION_LEAVE = auto()
|
|
170
164
|
|
|
171
|
-
#
|
|
172
|
-
|
|
173
|
-
|
|
165
|
+
# Python API: `run_simulation`
|
|
166
|
+
PYTHON_API_RUN_SIMULATION_ENTER = auto()
|
|
167
|
+
PYTHON_API_RUN_SIMULATION_LEAVE = auto()
|
|
174
168
|
|
|
175
|
-
#
|
|
169
|
+
# --- Deployment Engine ------------------------------------------------------------
|
|
170
|
+
|
|
171
|
+
# CLI: `flower-superlink`
|
|
172
|
+
RUN_SUPERLINK_ENTER = auto()
|
|
173
|
+
RUN_SUPERLINK_LEAVE = auto()
|
|
174
|
+
|
|
175
|
+
# CLI: `flower-supernode`
|
|
176
176
|
RUN_SUPERNODE_ENTER = auto()
|
|
177
177
|
RUN_SUPERNODE_LEAVE = auto()
|
|
178
178
|
|
|
179
|
-
#
|
|
180
|
-
|
|
181
|
-
|
|
179
|
+
# CLI: `flower-server-app`
|
|
180
|
+
RUN_SERVER_APP_ENTER = auto()
|
|
181
|
+
RUN_SERVER_APP_LEAVE = auto()
|
|
182
|
+
|
|
183
|
+
# --- DEPRECATED -------------------------------------------------------------------
|
|
184
|
+
|
|
185
|
+
# [DEPRECATED] CLI: `flower-client-app`
|
|
186
|
+
RUN_CLIENT_APP_ENTER = auto()
|
|
187
|
+
RUN_CLIENT_APP_LEAVE = auto()
|
|
182
188
|
|
|
183
189
|
|
|
184
190
|
# Use the ThreadPoolExecutor with max_workers=1 to have a queue
|
flwr/server/__init__.py
CHANGED
|
@@ -17,14 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
from . import strategy
|
|
19
19
|
from . import workflow as workflow
|
|
20
|
-
from .app import run_superlink as run_superlink
|
|
21
20
|
from .app import start_server as start_server
|
|
22
21
|
from .client_manager import ClientManager as ClientManager
|
|
23
22
|
from .client_manager import SimpleClientManager as SimpleClientManager
|
|
24
23
|
from .compat import LegacyContext as LegacyContext
|
|
25
24
|
from .driver import Driver as Driver
|
|
26
25
|
from .history import History as History
|
|
27
|
-
from .run_serverapp import run_server_app as run_server_app
|
|
28
26
|
from .server import Server as Server
|
|
29
27
|
from .server_app import ServerApp as ServerApp
|
|
30
28
|
from .server_config import ServerConfig as ServerConfig
|
|
@@ -40,8 +38,6 @@ __all__ = [
|
|
|
40
38
|
"ServerAppComponents",
|
|
41
39
|
"ServerConfig",
|
|
42
40
|
"SimpleClientManager",
|
|
43
|
-
"run_server_app",
|
|
44
|
-
"run_superlink",
|
|
45
41
|
"start_server",
|
|
46
42
|
"strategy",
|
|
47
43
|
"workflow",
|
flwr/server/app.py
CHANGED
|
@@ -36,6 +36,10 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
|
36
36
|
from flwr.common.address import parse_address
|
|
37
37
|
from flwr.common.config import get_flwr_dir
|
|
38
38
|
from flwr.common.constant import (
|
|
39
|
+
DRIVER_API_DEFAULT_ADDRESS,
|
|
40
|
+
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
41
|
+
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
42
|
+
FLEET_API_REST_DEFAULT_ADDRESS,
|
|
39
43
|
MISSING_EXTRA_REST,
|
|
40
44
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
41
45
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
@@ -68,18 +72,13 @@ from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
|
68
72
|
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
69
73
|
from .superlink.state import StateFactory
|
|
70
74
|
|
|
71
|
-
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
72
|
-
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
|
|
73
|
-
ADDRESS_FLEET_API_GRPC_BIDI = "[::]:8080" # IPv6 to keep start_server compatible
|
|
74
|
-
ADDRESS_FLEET_API_REST = "0.0.0.0:9093"
|
|
75
|
-
|
|
76
75
|
DATABASE = ":flwr-in-memory-state:"
|
|
77
76
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
78
77
|
|
|
79
78
|
|
|
80
79
|
def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
81
80
|
*,
|
|
82
|
-
server_address: str =
|
|
81
|
+
server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
83
82
|
server: Optional[Server] = None,
|
|
84
83
|
config: Optional[ServerConfig] = None,
|
|
85
84
|
strategy: Optional[Strategy] = None,
|
|
@@ -232,9 +231,9 @@ def run_superlink() -> None:
|
|
|
232
231
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
233
232
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
234
233
|
]:
|
|
235
|
-
args.fleet_api_address =
|
|
234
|
+
args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
|
|
236
235
|
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
237
|
-
args.fleet_api_address =
|
|
236
|
+
args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS
|
|
238
237
|
|
|
239
238
|
fleet_address, host, port = _format_address(args.fleet_api_address)
|
|
240
239
|
|
|
@@ -278,24 +277,24 @@ def run_superlink() -> None:
|
|
|
278
277
|
fleet_thread.start()
|
|
279
278
|
bckg_threads.append(fleet_thread)
|
|
280
279
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
281
|
-
maybe_keys =
|
|
280
|
+
maybe_keys = _try_setup_node_authentication(args, certificates)
|
|
282
281
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
283
282
|
if maybe_keys is not None:
|
|
284
283
|
(
|
|
285
|
-
|
|
284
|
+
node_public_keys,
|
|
286
285
|
server_private_key,
|
|
287
286
|
server_public_key,
|
|
288
287
|
) = maybe_keys
|
|
289
288
|
state = state_factory.state()
|
|
290
|
-
state.
|
|
289
|
+
state.store_node_public_keys(node_public_keys)
|
|
291
290
|
state.store_server_private_public_key(
|
|
292
291
|
private_key_to_bytes(server_private_key),
|
|
293
292
|
public_key_to_bytes(server_public_key),
|
|
294
293
|
)
|
|
295
294
|
log(
|
|
296
295
|
INFO,
|
|
297
|
-
"
|
|
298
|
-
len(
|
|
296
|
+
"Node authentication enabled with %d known public keys",
|
|
297
|
+
len(node_public_keys),
|
|
299
298
|
)
|
|
300
299
|
interceptors = [AuthenticateServerInterceptor(state)]
|
|
301
300
|
|
|
@@ -344,7 +343,7 @@ def _format_address(address: str) -> Tuple[str, str, int]:
|
|
|
344
343
|
return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
|
|
345
344
|
|
|
346
345
|
|
|
347
|
-
def
|
|
346
|
+
def _try_setup_node_authentication(
|
|
348
347
|
args: argparse.Namespace,
|
|
349
348
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
350
349
|
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
@@ -373,16 +372,16 @@ def _try_setup_client_authentication(
|
|
|
373
372
|
"`--ssl-keyfile`, and `—-ssl-ca-certfile` and try again."
|
|
374
373
|
)
|
|
375
374
|
|
|
376
|
-
|
|
377
|
-
if not
|
|
375
|
+
node_keys_file_path = Path(args.auth_list_public_keys)
|
|
376
|
+
if not node_keys_file_path.exists():
|
|
378
377
|
sys.exit(
|
|
379
378
|
"The provided path to the known public keys CSV file does not exist: "
|
|
380
|
-
f"{
|
|
379
|
+
f"{node_keys_file_path}. "
|
|
381
380
|
"Please provide the CSV file path containing known public keys "
|
|
382
381
|
"to '--auth-list-public-keys'."
|
|
383
382
|
)
|
|
384
383
|
|
|
385
|
-
|
|
384
|
+
node_public_keys: Set[bytes] = set()
|
|
386
385
|
|
|
387
386
|
try:
|
|
388
387
|
ssh_private_key = load_ssh_private_key(
|
|
@@ -413,13 +412,13 @@ def _try_setup_client_authentication(
|
|
|
413
412
|
"path points to a valid public key file and try again."
|
|
414
413
|
)
|
|
415
414
|
|
|
416
|
-
with open(
|
|
415
|
+
with open(node_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
417
416
|
reader = csv.reader(csvfile)
|
|
418
417
|
for row in reader:
|
|
419
418
|
for element in row:
|
|
420
419
|
public_key = load_ssh_public_key(element.encode())
|
|
421
420
|
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
422
|
-
|
|
421
|
+
node_public_keys.add(public_key_to_bytes(public_key))
|
|
423
422
|
else:
|
|
424
423
|
sys.exit(
|
|
425
424
|
"Error: Unable to parse the public keys in the CSV "
|
|
@@ -427,7 +426,7 @@ def _try_setup_client_authentication(
|
|
|
427
426
|
"known SSH public keys files and try again."
|
|
428
427
|
)
|
|
429
428
|
return (
|
|
430
|
-
|
|
429
|
+
node_public_keys,
|
|
431
430
|
ssh_private_key,
|
|
432
431
|
ssh_public_key,
|
|
433
432
|
)
|
|
@@ -653,7 +652,7 @@ def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
|
|
653
652
|
parser.add_argument(
|
|
654
653
|
"--driver-api-address",
|
|
655
654
|
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
|
|
656
|
-
default=
|
|
655
|
+
default=DRIVER_API_DEFAULT_ADDRESS,
|
|
657
656
|
)
|
|
658
657
|
|
|
659
658
|
|
flwr/server/compat/app.py
CHANGED
|
@@ -18,7 +18,6 @@
|
|
|
18
18
|
from logging import INFO
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
from flwr.common import EventType, event
|
|
22
21
|
from flwr.common.logger import log
|
|
23
22
|
from flwr.server.client_manager import ClientManager
|
|
24
23
|
from flwr.server.history import History
|
|
@@ -65,8 +64,6 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
65
64
|
hist : flwr.server.history.History
|
|
66
65
|
Object containing training and evaluation metrics.
|
|
67
66
|
"""
|
|
68
|
-
event(EventType.START_DRIVER_ENTER)
|
|
69
|
-
|
|
70
67
|
# Initialize the Driver API server and config
|
|
71
68
|
initialized_server, initialized_config = init_defaults(
|
|
72
69
|
server=server,
|
|
@@ -96,6 +93,4 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
96
93
|
f_stop.set()
|
|
97
94
|
thread.join()
|
|
98
95
|
|
|
99
|
-
event(EventType.START_SERVER_LEAVE)
|
|
100
|
-
|
|
101
96
|
return hist
|
|
flwr/server/driver/grpc_driver.py
CHANGED
|
@@ -21,7 +21,8 @@ from typing import Iterable, List, Optional, cast
|
|
|
21
21
|
|
|
22
22
|
import grpc
|
|
23
23
|
|
|
24
|
-
from flwr.common import DEFAULT_TTL,
|
|
24
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
25
|
+
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
|
|
25
26
|
from flwr.common.grpc import create_channel
|
|
26
27
|
from flwr.common.logger import log
|
|
27
28
|
from flwr.common.serde import (
|
|
@@ -45,8 +46,6 @@ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
|
45
46
|
|
|
46
47
|
from .driver import Driver
|
|
47
48
|
|
|
48
|
-
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
49
|
-
|
|
50
49
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
51
50
|
[Driver] Error: Not connected.
|
|
52
51
|
|
|
@@ -73,7 +72,7 @@ class GrpcDriver(Driver):
|
|
|
73
72
|
def __init__( # pylint: disable=too-many-arguments
|
|
74
73
|
self,
|
|
75
74
|
run_id: int,
|
|
76
|
-
driver_service_address: str =
|
|
75
|
+
driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
|
|
77
76
|
root_certificates: Optional[bytes] = None,
|
|
78
77
|
) -> None:
|
|
79
78
|
self._run_id = run_id
|
|
@@ -94,7 +93,6 @@ class GrpcDriver(Driver):
|
|
|
94
93
|
|
|
95
94
|
This will not call GetRun.
|
|
96
95
|
"""
|
|
97
|
-
event(EventType.DRIVER_CONNECT)
|
|
98
96
|
if self._is_connected:
|
|
99
97
|
log(WARNING, "Already connected")
|
|
100
98
|
return
|
|
@@ -108,7 +106,6 @@ class GrpcDriver(Driver):
|
|
|
108
106
|
|
|
109
107
|
def _disconnect(self) -> None:
|
|
110
108
|
"""Disconnect from the Driver API."""
|
|
111
|
-
event(EventType.DRIVER_DISCONNECT)
|
|
112
109
|
if not self._is_connected:
|
|
113
110
|
log(DEBUG, "Already disconnected")
|
|
114
111
|
return
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -31,6 +31,7 @@ from flwr.common.config import (
|
|
|
31
31
|
get_project_config,
|
|
32
32
|
get_project_dir,
|
|
33
33
|
)
|
|
34
|
+
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
|
|
34
35
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
35
36
|
from flwr.common.object_ref import load_app
|
|
36
37
|
from flwr.common.typing import UserConfig
|
|
@@ -44,8 +45,6 @@ from .driver import Driver
|
|
|
44
45
|
from .driver.grpc_driver import GrpcDriver
|
|
45
46
|
from .server_app import LoadServerAppError, ServerApp
|
|
46
47
|
|
|
47
|
-
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
48
|
-
|
|
49
48
|
|
|
50
49
|
def run(
|
|
51
50
|
driver: Driver,
|
|
@@ -97,11 +96,26 @@ def run_server_app() -> None:
|
|
|
97
96
|
|
|
98
97
|
args = _parse_args_run_server_app().parse_args()
|
|
99
98
|
|
|
100
|
-
if
|
|
99
|
+
# Check if the server app reference is passed.
|
|
100
|
+
# Since Flower 1.11, passing a reference is not allowed.
|
|
101
|
+
app_path: Optional[str] = args.app
|
|
102
|
+
# If the provided app_path doesn't exist, and contains a ":",
|
|
103
|
+
# it is likely to be a server app reference instead of a path.
|
|
104
|
+
if app_path is not None and not Path(app_path).exists() and ":" in app_path:
|
|
105
|
+
sys.exit(
|
|
106
|
+
"It appears you've passed a reference like `server:app`.\n\n"
|
|
107
|
+
"Note that since version `1.11.0`, `flower-server-app` no longer supports "
|
|
108
|
+
"passing a reference to a `ServerApp` attribute. Instead, you need to pass "
|
|
109
|
+
"the path to Flower app via the argument `--app`. This is the path to a "
|
|
110
|
+
"directory containing a `pyproject.toml`. You can create a valid Flower "
|
|
111
|
+
"app by executing `flwr new` and following the prompt."
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
if args.server != DRIVER_API_DEFAULT_ADDRESS:
|
|
101
115
|
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
102
116
|
warn_deprecated_feature(warn)
|
|
103
117
|
|
|
104
|
-
if args.superlink !=
|
|
118
|
+
if args.superlink != DRIVER_API_DEFAULT_ADDRESS:
|
|
105
119
|
# if `--superlink` also passed, then
|
|
106
120
|
# warn user that this argument overrides what was passed with `--server`
|
|
107
121
|
log(
|
|
@@ -151,7 +165,6 @@ def run_server_app() -> None:
|
|
|
151
165
|
cert_path,
|
|
152
166
|
)
|
|
153
167
|
|
|
154
|
-
app_path: Optional[str] = args.app
|
|
155
168
|
if not (app_path is None) ^ (args.run_id is None):
|
|
156
169
|
raise sys.exit(
|
|
157
170
|
"Please provide either a Flower App path or a Run ID, but not both. "
|
|
@@ -261,12 +274,12 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
261
274
|
)
|
|
262
275
|
parser.add_argument(
|
|
263
276
|
"--server",
|
|
264
|
-
default=
|
|
277
|
+
default=DRIVER_API_DEFAULT_ADDRESS,
|
|
265
278
|
help="Server address",
|
|
266
279
|
)
|
|
267
280
|
parser.add_argument(
|
|
268
281
|
"--superlink",
|
|
269
|
-
default=
|
|
282
|
+
default=DRIVER_API_DEFAULT_ADDRESS,
|
|
270
283
|
help="SuperLink Driver API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
271
284
|
)
|
|
272
285
|
parser.add_argument(
|
|
flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
CHANGED
|
@@ -51,19 +51,22 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
51
51
|
self, request: CreateNodeRequest, context: grpc.ServicerContext
|
|
52
52
|
) -> CreateNodeResponse:
|
|
53
53
|
"""."""
|
|
54
|
-
log(INFO, "
|
|
54
|
+
log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
|
|
55
|
+
log(DEBUG, "[Fleet.CreateNode] Request: %s", request)
|
|
55
56
|
response = message_handler.create_node(
|
|
56
57
|
request=request,
|
|
57
58
|
state=self.state_factory.state(),
|
|
58
59
|
)
|
|
59
|
-
log(INFO, "
|
|
60
|
+
log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
|
|
61
|
+
log(DEBUG, "[Fleet.CreateNode] Response: %s", response)
|
|
60
62
|
return response
|
|
61
63
|
|
|
62
64
|
def DeleteNode(
|
|
63
65
|
self, request: DeleteNodeRequest, context: grpc.ServicerContext
|
|
64
66
|
) -> DeleteNodeResponse:
|
|
65
67
|
"""."""
|
|
66
|
-
log(INFO, "
|
|
68
|
+
log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
|
|
69
|
+
log(DEBUG, "[Fleet.DeleteNode] Request: %s", request)
|
|
67
70
|
return message_handler.delete_node(
|
|
68
71
|
request=request,
|
|
69
72
|
state=self.state_factory.state(),
|
|
@@ -71,7 +74,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
71
74
|
|
|
72
75
|
def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
|
|
73
76
|
"""."""
|
|
74
|
-
log(DEBUG, "
|
|
77
|
+
log(DEBUG, "[Fleet.Ping] Request: %s", request)
|
|
75
78
|
return message_handler.ping(
|
|
76
79
|
request=request,
|
|
77
80
|
state=self.state_factory.state(),
|
|
@@ -81,7 +84,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
81
84
|
self, request: PullTaskInsRequest, context: grpc.ServicerContext
|
|
82
85
|
) -> PullTaskInsResponse:
|
|
83
86
|
"""Pull TaskIns."""
|
|
84
|
-
log(INFO, "
|
|
87
|
+
log(INFO, "[Fleet.PullTaskIns] node_id=%s", request.node.node_id)
|
|
88
|
+
log(DEBUG, "[Fleet.PullTaskIns] Request: %s", request)
|
|
85
89
|
return message_handler.pull_task_ins(
|
|
86
90
|
request=request,
|
|
87
91
|
state=self.state_factory.state(),
|
|
@@ -91,7 +95,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
91
95
|
self, request: PushTaskResRequest, context: grpc.ServicerContext
|
|
92
96
|
) -> PushTaskResResponse:
|
|
93
97
|
"""Push TaskRes."""
|
|
94
|
-
|
|
98
|
+
if request.task_res_list:
|
|
99
|
+
log(
|
|
100
|
+
INFO,
|
|
101
|
+
"[Fleet.PushTaskRes] Push results from node_id=%s",
|
|
102
|
+
request.task_res_list[0].task.producer.node_id,
|
|
103
|
+
)
|
|
104
|
+
else:
|
|
105
|
+
log(INFO, "[Fleet.PushTaskRes] No task results to push")
|
|
95
106
|
return message_handler.push_task_res(
|
|
96
107
|
request=request,
|
|
97
108
|
state=self.state_factory.state(),
|
|
@@ -101,7 +112,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
101
112
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
102
113
|
) -> GetRunResponse:
|
|
103
114
|
"""Get run information."""
|
|
104
|
-
log(INFO, "
|
|
115
|
+
log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
|
|
105
116
|
return message_handler.get_run(
|
|
106
117
|
request=request,
|
|
107
118
|
state=self.state_factory.state(),
|
|
@@ -111,7 +122,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
111
122
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
112
123
|
) -> GetFabResponse:
|
|
113
124
|
"""Get FAB."""
|
|
114
|
-
log(
|
|
125
|
+
log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
|
|
115
126
|
return message_handler.get_fab(
|
|
116
127
|
request=request,
|
|
117
128
|
ffs=self.ffs_factory.ffs(),
|
|
flwr/server/superlink/fleet/grpc_rere/server_interceptor.py
CHANGED
|
@@ -78,13 +78,13 @@ def _get_value_from_tuples(
|
|
|
78
78
|
|
|
79
79
|
|
|
80
80
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
81
|
-
"""Server interceptor for
|
|
81
|
+
"""Server interceptor for node authentication."""
|
|
82
82
|
|
|
83
83
|
def __init__(self, state: State):
|
|
84
84
|
self.state = state
|
|
85
85
|
|
|
86
|
-
self.
|
|
87
|
-
if len(self.
|
|
86
|
+
self.node_public_keys = state.get_node_public_keys()
|
|
87
|
+
if len(self.node_public_keys) == 0:
|
|
88
88
|
log(WARNING, "Authentication enabled, but no known public keys configured")
|
|
89
89
|
|
|
90
90
|
private_key = self.state.get_server_private_key()
|
|
@@ -103,9 +103,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
103
103
|
) -> grpc.RpcMethodHandler:
|
|
104
104
|
"""Flower server interceptor authentication logic.
|
|
105
105
|
|
|
106
|
-
Intercept all unary calls from
|
|
107
|
-
|
|
108
|
-
|
|
106
|
+
Intercept all unary calls from nodes and authenticate nodes by validating auth
|
|
107
|
+
metadata sent by the node. Continue RPC call if node is authenticated, else,
|
|
108
|
+
terminate RPC call by setting context to abort.
|
|
109
109
|
"""
|
|
110
110
|
# One of the method handlers in
|
|
111
111
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
|
@@ -119,17 +119,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
119
119
|
request: Request,
|
|
120
120
|
context: grpc.ServicerContext,
|
|
121
121
|
) -> Response:
|
|
122
|
-
|
|
122
|
+
node_public_key_bytes = base64.urlsafe_b64decode(
|
|
123
123
|
_get_value_from_tuples(
|
|
124
124
|
_PUBLIC_KEY_HEADER, context.invocation_metadata()
|
|
125
125
|
)
|
|
126
126
|
)
|
|
127
|
-
if
|
|
127
|
+
if node_public_key_bytes not in self.node_public_keys:
|
|
128
128
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
129
129
|
|
|
130
130
|
if isinstance(request, CreateNodeRequest):
|
|
131
131
|
response = self._create_authenticated_node(
|
|
132
|
-
|
|
132
|
+
node_public_key_bytes, request, context
|
|
133
133
|
)
|
|
134
134
|
log(
|
|
135
135
|
INFO,
|
|
@@ -144,13 +144,13 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
144
144
|
_AUTH_TOKEN_HEADER, context.invocation_metadata()
|
|
145
145
|
)
|
|
146
146
|
)
|
|
147
|
-
public_key = bytes_to_public_key(
|
|
147
|
+
public_key = bytes_to_public_key(node_public_key_bytes)
|
|
148
148
|
|
|
149
149
|
if not self._verify_hmac(public_key, request, hmac_value):
|
|
150
150
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
151
151
|
|
|
152
152
|
# Verify node_id
|
|
153
|
-
node_id = self.state.get_node_id(
|
|
153
|
+
node_id = self.state.get_node_id(node_public_key_bytes)
|
|
154
154
|
|
|
155
155
|
if not self._verify_node_id(node_id, request):
|
|
156
156
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
@@ -188,7 +188,8 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
188
188
|
self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
|
|
189
189
|
) -> bool:
|
|
190
190
|
shared_secret = generate_shared_key(self.server_private_key, public_key)
|
|
191
|
-
|
|
191
|
+
message_bytes = request.SerializeToString(deterministic=True)
|
|
192
|
+
return verify_hmac(shared_secret, message_bytes, hmac_value)
|
|
192
193
|
|
|
193
194
|
def _create_authenticated_node(
|
|
194
195
|
self,
|
|
flwr/server/superlink/fleet/vce/backend/raybackend.py
CHANGED
|
@@ -52,16 +52,11 @@ class RayBackend(Backend):
|
|
|
52
52
|
|
|
53
53
|
# Validate client resources
|
|
54
54
|
self.client_resources_key = "client_resources"
|
|
55
|
-
client_resources = self._validate_client_resources(config=backend_config)
|
|
55
|
+
self.client_resources = self._validate_client_resources(config=backend_config)
|
|
56
56
|
|
|
57
|
-
#
|
|
58
|
-
actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
|
59
|
-
|
|
60
|
-
self.pool = BasicActorPool(
|
|
61
|
-
actor_type=ClientAppActor,
|
|
62
|
-
client_resources=client_resources,
|
|
63
|
-
actor_kwargs=actor_kwargs,
|
|
64
|
-
)
|
|
57
|
+
# Validate actor resources
|
|
58
|
+
self.actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
|
59
|
+
self.pool: Optional[BasicActorPool] = None
|
|
65
60
|
|
|
66
61
|
self.app_fn: Optional[Callable[[], ClientApp]] = None
|
|
67
62
|
|
|
@@ -122,14 +117,24 @@ class RayBackend(Backend):
|
|
|
122
117
|
@property
|
|
123
118
|
def num_workers(self) -> int:
|
|
124
119
|
"""Return number of actors in pool."""
|
|
125
|
-
return self.pool.num_actors
|
|
120
|
+
return self.pool.num_actors if self.pool else 0
|
|
126
121
|
|
|
127
122
|
def is_worker_idle(self) -> bool:
|
|
128
123
|
"""Report whether the pool has idle actors."""
|
|
129
|
-
return self.pool.is_actor_available()
|
|
124
|
+
return self.pool.is_actor_available() if self.pool else False
|
|
130
125
|
|
|
131
126
|
def build(self, app_fn: Callable[[], ClientApp]) -> None:
|
|
132
127
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
|
128
|
+
# Create Actor Pool
|
|
129
|
+
try:
|
|
130
|
+
self.pool = BasicActorPool(
|
|
131
|
+
actor_type=ClientAppActor,
|
|
132
|
+
client_resources=self.client_resources,
|
|
133
|
+
actor_kwargs=self.actor_kwargs,
|
|
134
|
+
)
|
|
135
|
+
except Exception as ex:
|
|
136
|
+
raise ex
|
|
137
|
+
|
|
133
138
|
self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
|
134
139
|
# Set ClientApp callable that ray actors will use
|
|
135
140
|
self.app_fn = app_fn
|
|
@@ -146,6 +151,9 @@ class RayBackend(Backend):
|
|
|
146
151
|
"""
|
|
147
152
|
partition_id = context.node_config[PARTITION_ID_KEY]
|
|
148
153
|
|
|
154
|
+
if self.pool is None:
|
|
155
|
+
raise ValueError("The actor pool is empty, unfit to process messages.")
|
|
156
|
+
|
|
149
157
|
if self.app_fn is None:
|
|
150
158
|
raise ValueError(
|
|
151
159
|
"Unspecified function to load a `ClientApp`. "
|
|
@@ -179,6 +187,7 @@ class RayBackend(Backend):
|
|
|
179
187
|
|
|
180
188
|
def terminate(self) -> None:
|
|
181
189
|
"""Terminate all actors in actor pool."""
|
|
182
|
-
self.pool
|
|
190
|
+
if self.pool:
|
|
191
|
+
self.pool.terminate_all_actors()
|
|
183
192
|
ray.shutdown()
|
|
184
193
|
log(DEBUG, "Terminated %s", self.__class__.__name__)
|