flwr-nightly 1.13.0.dev20241019__py3-none-any.whl → 1.13.0.dev20241106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/log.py +63 -97
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +18 -83
- flwr/client/app.py +13 -14
- flwr/client/clientapp/app.py +1 -2
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/constant.py +39 -4
- flwr/common/context.py +9 -4
- flwr/common/date.py +3 -3
- flwr/common/logger.py +103 -0
- flwr/common/serde.py +24 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +9 -0
- flwr/proto/exec_pb2.py +6 -6
- flwr/proto/exec_pb2.pyi +8 -2
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +171 -0
- flwr/proto/simulationio_pb2_grpc.pyi +68 -0
- flwr/server/app.py +247 -105
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +26 -33
- flwr/server/driver/inmemory_driver.py +6 -14
- flwr/server/run_serverapp.py +29 -23
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +270 -0
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +180 -21
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +144 -15
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +300 -50
- flwr/server/superlink/{state → linkstate}/utils.py +84 -2
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
- flwr/simulation/__init__.py +2 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +57 -131
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +60 -65
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +34 -63
- flwr/superexec/executor.py +22 -4
- flwr/superexec/simulation.py +13 -8
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/RECORD +77 -64
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/entry_points.txt +1 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/WHEEL +0 -0
|
@@ -23,7 +23,7 @@ from typing import Optional, cast
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
25
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
26
|
-
from flwr.common.constant import
|
|
26
|
+
from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS
|
|
27
27
|
from flwr.common.grpc import create_channel
|
|
28
28
|
from flwr.common.logger import log
|
|
29
29
|
from flwr.common.serde import (
|
|
@@ -32,7 +32,9 @@ from flwr.common.serde import (
|
|
|
32
32
|
user_config_from_proto,
|
|
33
33
|
)
|
|
34
34
|
from flwr.common.typing import Run
|
|
35
|
-
from flwr.proto.
|
|
35
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
36
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
37
|
+
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
36
38
|
GetNodesRequest,
|
|
37
39
|
GetNodesResponse,
|
|
38
40
|
PullTaskResRequest,
|
|
@@ -40,9 +42,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
40
42
|
PushTaskInsRequest,
|
|
41
43
|
PushTaskInsResponse,
|
|
42
44
|
)
|
|
43
|
-
from flwr.proto.
|
|
44
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
45
|
-
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
45
|
+
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
|
46
46
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
47
47
|
|
|
48
48
|
from .driver import Driver
|
|
@@ -56,14 +56,12 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
|
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class GrpcDriver(Driver):
|
|
59
|
-
"""`GrpcDriver` provides an interface to the
|
|
59
|
+
"""`GrpcDriver` provides an interface to the ServerAppIo API.
|
|
60
60
|
|
|
61
61
|
Parameters
|
|
62
62
|
----------
|
|
63
|
-
|
|
64
|
-
The
|
|
65
|
-
driver_service_address : str (default: "[::]:9091")
|
|
66
|
-
The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
|
|
63
|
+
serverappio_service_address : str (default: "[::]:9091")
|
|
64
|
+
The address (URL, IPv6, IPv4) of the SuperLink ServerAppIo API service.
|
|
67
65
|
root_certificates : Optional[bytes] (default: None)
|
|
68
66
|
The PEM-encoded root certificates as a byte string.
|
|
69
67
|
If provided, a secure connection using the certificates will be
|
|
@@ -72,25 +70,23 @@ class GrpcDriver(Driver):
|
|
|
72
70
|
|
|
73
71
|
def __init__( # pylint: disable=too-many-arguments
|
|
74
72
|
self,
|
|
75
|
-
|
|
76
|
-
driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
|
|
73
|
+
serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_ADDRESS,
|
|
77
74
|
root_certificates: Optional[bytes] = None,
|
|
78
75
|
) -> None:
|
|
79
|
-
self.
|
|
80
|
-
self._addr = driver_service_address
|
|
76
|
+
self._addr = serverappio_service_address
|
|
81
77
|
self._cert = root_certificates
|
|
82
78
|
self._run: Optional[Run] = None
|
|
83
|
-
self._grpc_stub: Optional[
|
|
79
|
+
self._grpc_stub: Optional[ServerAppIoStub] = None
|
|
84
80
|
self._channel: Optional[grpc.Channel] = None
|
|
85
81
|
self.node = Node(node_id=0, anonymous=True)
|
|
86
82
|
|
|
87
83
|
@property
|
|
88
84
|
def _is_connected(self) -> bool:
|
|
89
|
-
"""Check if connected to the
|
|
85
|
+
"""Check if connected to the ServerAppIo API server."""
|
|
90
86
|
return self._channel is not None
|
|
91
87
|
|
|
92
88
|
def _connect(self) -> None:
|
|
93
|
-
"""Connect to the
|
|
89
|
+
"""Connect to the ServerAppIo API.
|
|
94
90
|
|
|
95
91
|
This will not call GetRun.
|
|
96
92
|
"""
|
|
@@ -102,11 +98,11 @@ class GrpcDriver(Driver):
|
|
|
102
98
|
insecure=(self._cert is None),
|
|
103
99
|
root_certificates=self._cert,
|
|
104
100
|
)
|
|
105
|
-
self._grpc_stub =
|
|
101
|
+
self._grpc_stub = ServerAppIoStub(self._channel)
|
|
106
102
|
log(DEBUG, "[Driver] Connected to %s", self._addr)
|
|
107
103
|
|
|
108
104
|
def _disconnect(self) -> None:
|
|
109
|
-
"""Disconnect from the
|
|
105
|
+
"""Disconnect from the ServerAppIo API."""
|
|
110
106
|
if not self._is_connected:
|
|
111
107
|
log(DEBUG, "Already disconnected")
|
|
112
108
|
return
|
|
@@ -116,15 +112,17 @@ class GrpcDriver(Driver):
|
|
|
116
112
|
channel.close()
|
|
117
113
|
log(DEBUG, "[Driver] Disconnected")
|
|
118
114
|
|
|
119
|
-
def
|
|
115
|
+
def init_run(self, run_id: int) -> None:
|
|
116
|
+
"""Initialize the run."""
|
|
120
117
|
# Check if is initialized
|
|
121
118
|
if self._run is not None:
|
|
122
119
|
return
|
|
120
|
+
|
|
123
121
|
# Get the run info
|
|
124
|
-
req = GetRunRequest(run_id=
|
|
122
|
+
req = GetRunRequest(run_id=run_id)
|
|
125
123
|
res: GetRunResponse = self._stub.GetRun(req)
|
|
126
124
|
if not res.HasField("run"):
|
|
127
|
-
raise RuntimeError(f"Cannot find the run with ID: {
|
|
125
|
+
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
|
128
126
|
self._run = Run(
|
|
129
127
|
run_id=res.run.run_id,
|
|
130
128
|
fab_id=res.run.fab_id,
|
|
@@ -136,21 +134,20 @@ class GrpcDriver(Driver):
|
|
|
136
134
|
@property
|
|
137
135
|
def run(self) -> Run:
|
|
138
136
|
"""Run information."""
|
|
139
|
-
self._init_run()
|
|
140
137
|
return Run(**vars(self._run))
|
|
141
138
|
|
|
142
139
|
@property
|
|
143
|
-
def _stub(self) ->
|
|
144
|
-
"""
|
|
140
|
+
def _stub(self) -> ServerAppIoStub:
|
|
141
|
+
"""ServerAppIo stub."""
|
|
145
142
|
if not self._is_connected:
|
|
146
143
|
self._connect()
|
|
147
|
-
return cast(
|
|
144
|
+
return cast(ServerAppIoStub, self._grpc_stub)
|
|
148
145
|
|
|
149
146
|
def _check_message(self, message: Message) -> None:
|
|
150
147
|
# Check if the message is valid
|
|
151
148
|
if not (
|
|
152
149
|
# Assume self._run being initialized
|
|
153
|
-
message.metadata.run_id == self.
|
|
150
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
154
151
|
and message.metadata.src_node_id == self.node.node_id
|
|
155
152
|
and message.metadata.message_id == ""
|
|
156
153
|
and message.metadata.reply_to_message == ""
|
|
@@ -171,7 +168,6 @@ class GrpcDriver(Driver):
|
|
|
171
168
|
This method constructs a new `Message` with given content and metadata.
|
|
172
169
|
The `run_id` and `src_node_id` will be set automatically.
|
|
173
170
|
"""
|
|
174
|
-
self._init_run()
|
|
175
171
|
if ttl:
|
|
176
172
|
warnings.warn(
|
|
177
173
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -182,7 +178,7 @@ class GrpcDriver(Driver):
|
|
|
182
178
|
|
|
183
179
|
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
184
180
|
metadata = Metadata(
|
|
185
|
-
run_id=self.
|
|
181
|
+
run_id=cast(Run, self._run).run_id,
|
|
186
182
|
message_id="", # Will be set by the server
|
|
187
183
|
src_node_id=self.node.node_id,
|
|
188
184
|
dst_node_id=dst_node_id,
|
|
@@ -195,10 +191,9 @@ class GrpcDriver(Driver):
|
|
|
195
191
|
|
|
196
192
|
def get_node_ids(self) -> list[int]:
|
|
197
193
|
"""Get node IDs."""
|
|
198
|
-
self._init_run()
|
|
199
194
|
# Call GrpcDriverStub method
|
|
200
195
|
res: GetNodesResponse = self._stub.GetNodes(
|
|
201
|
-
GetNodesRequest(run_id=self.
|
|
196
|
+
GetNodesRequest(run_id=cast(Run, self._run).run_id)
|
|
202
197
|
)
|
|
203
198
|
return [node.node_id for node in res.nodes]
|
|
204
199
|
|
|
@@ -208,7 +203,6 @@ class GrpcDriver(Driver):
|
|
|
208
203
|
This method takes an iterable of messages and sends each message
|
|
209
204
|
to the node specified in `dst_node_id`.
|
|
210
205
|
"""
|
|
211
|
-
self._init_run()
|
|
212
206
|
# Construct TaskIns
|
|
213
207
|
task_ins_list: list[TaskIns] = []
|
|
214
208
|
for msg in messages:
|
|
@@ -230,7 +224,6 @@ class GrpcDriver(Driver):
|
|
|
230
224
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
231
225
|
set of given message IDs.
|
|
232
226
|
"""
|
|
233
|
-
self._init_run()
|
|
234
227
|
# Pull TaskRes
|
|
235
228
|
res: PullTaskResResponse = self._stub.PullTaskRes(
|
|
236
229
|
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
@@ -25,18 +25,16 @@ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
|
25
25
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
26
26
|
from flwr.common.typing import Run
|
|
27
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
28
|
-
from flwr.server.superlink.
|
|
28
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
29
29
|
|
|
30
30
|
from .driver import Driver
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
class InMemoryDriver(Driver):
|
|
34
|
-
"""`InMemoryDriver` class provides an interface to the
|
|
34
|
+
"""`InMemoryDriver` class provides an interface to the ServerAppIo API.
|
|
35
35
|
|
|
36
36
|
Parameters
|
|
37
37
|
----------
|
|
38
|
-
run_id : int
|
|
39
|
-
The identifier of the run.
|
|
40
38
|
state_factory : StateFactory
|
|
41
39
|
A StateFactory embedding a state that this driver can interface with.
|
|
42
40
|
pull_interval : float (default=0.1)
|
|
@@ -45,18 +43,15 @@ class InMemoryDriver(Driver):
|
|
|
45
43
|
|
|
46
44
|
def __init__(
|
|
47
45
|
self,
|
|
48
|
-
|
|
49
|
-
state_factory: StateFactory,
|
|
46
|
+
state_factory: LinkStateFactory,
|
|
50
47
|
pull_interval: float = 0.1,
|
|
51
48
|
) -> None:
|
|
52
|
-
self._run_id = run_id
|
|
53
49
|
self._run: Optional[Run] = None
|
|
54
50
|
self.state = state_factory.state()
|
|
55
51
|
self.pull_interval = pull_interval
|
|
56
52
|
self.node = Node(node_id=0, anonymous=True)
|
|
57
53
|
|
|
58
54
|
def _check_message(self, message: Message) -> None:
|
|
59
|
-
self._init_run()
|
|
60
55
|
# Check if the message is valid
|
|
61
56
|
if not (
|
|
62
57
|
message.metadata.run_id == cast(Run, self._run).run_id
|
|
@@ -67,19 +62,18 @@ class InMemoryDriver(Driver):
|
|
|
67
62
|
):
|
|
68
63
|
raise ValueError(f"Invalid message: {message}")
|
|
69
64
|
|
|
70
|
-
def
|
|
65
|
+
def init_run(self, run_id: int) -> None:
|
|
71
66
|
"""Initialize the run."""
|
|
72
67
|
if self._run is not None:
|
|
73
68
|
return
|
|
74
|
-
run = self.state.get_run(
|
|
69
|
+
run = self.state.get_run(run_id)
|
|
75
70
|
if run is None:
|
|
76
|
-
raise RuntimeError(f"Cannot find the run with ID: {
|
|
71
|
+
raise RuntimeError(f"Cannot find the run with ID: {run_id}")
|
|
77
72
|
self._run = run
|
|
78
73
|
|
|
79
74
|
@property
|
|
80
75
|
def run(self) -> Run:
|
|
81
76
|
"""Run ID."""
|
|
82
|
-
self._init_run()
|
|
83
77
|
return Run(**vars(cast(Run, self._run)))
|
|
84
78
|
|
|
85
79
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
@@ -95,7 +89,6 @@ class InMemoryDriver(Driver):
|
|
|
95
89
|
This method constructs a new `Message` with given content and metadata.
|
|
96
90
|
The `run_id` and `src_node_id` will be set automatically.
|
|
97
91
|
"""
|
|
98
|
-
self._init_run()
|
|
99
92
|
if ttl:
|
|
100
93
|
warnings.warn(
|
|
101
94
|
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
@@ -119,7 +112,6 @@ class InMemoryDriver(Driver):
|
|
|
119
112
|
|
|
120
113
|
def get_node_ids(self) -> list[int]:
|
|
121
114
|
"""Get node IDs."""
|
|
122
|
-
self._init_run()
|
|
123
115
|
return list(self.state.get_nodes(cast(Run, self._run).run_id))
|
|
124
116
|
|
|
125
117
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -31,10 +31,9 @@ from flwr.common.config import (
|
|
|
31
31
|
get_project_config,
|
|
32
32
|
get_project_dir,
|
|
33
33
|
)
|
|
34
|
-
from flwr.common.constant import
|
|
34
|
+
from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS
|
|
35
35
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
36
36
|
from flwr.common.object_ref import load_app
|
|
37
|
-
from flwr.common.typing import UserConfig
|
|
38
37
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
39
38
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
40
39
|
CreateRunRequest,
|
|
@@ -48,11 +47,11 @@ from .server_app import LoadServerAppError, ServerApp
|
|
|
48
47
|
|
|
49
48
|
def run(
|
|
50
49
|
driver: Driver,
|
|
50
|
+
context: Context,
|
|
51
51
|
server_app_dir: str,
|
|
52
|
-
server_app_run_config: UserConfig,
|
|
53
52
|
server_app_attr: Optional[str] = None,
|
|
54
53
|
loaded_server_app: Optional[ServerApp] = None,
|
|
55
|
-
) ->
|
|
54
|
+
) -> Context:
|
|
56
55
|
"""Run ServerApp with a given Driver."""
|
|
57
56
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
|
58
57
|
raise ValueError(
|
|
@@ -78,15 +77,11 @@ def run(
|
|
|
78
77
|
|
|
79
78
|
server_app = _load()
|
|
80
79
|
|
|
81
|
-
# Initialize Context
|
|
82
|
-
context = Context(
|
|
83
|
-
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
|
|
84
|
-
)
|
|
85
|
-
|
|
86
80
|
# Call ServerApp
|
|
87
81
|
server_app(driver=driver, context=context)
|
|
88
82
|
|
|
89
83
|
log(DEBUG, "ServerApp finished running.")
|
|
84
|
+
return context
|
|
90
85
|
|
|
91
86
|
|
|
92
87
|
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
|
@@ -111,18 +106,18 @@ def run_server_app() -> None:
|
|
|
111
106
|
"app by executing `flwr new` and following the prompt."
|
|
112
107
|
)
|
|
113
108
|
|
|
114
|
-
if args.server !=
|
|
109
|
+
if args.server != SERVERAPPIO_API_DEFAULT_ADDRESS:
|
|
115
110
|
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
116
111
|
warn_deprecated_feature(warn)
|
|
117
112
|
|
|
118
|
-
if args.superlink !=
|
|
113
|
+
if args.superlink != SERVERAPPIO_API_DEFAULT_ADDRESS:
|
|
119
114
|
# if `--superlink` also passed, then
|
|
120
115
|
# warn user that this argument overrides what was passed with `--server`
|
|
121
116
|
log(
|
|
122
117
|
WARN,
|
|
123
118
|
"Both `--server` and `--superlink` were passed. "
|
|
124
|
-
"`--server` will be ignored. Connecting to the
|
|
125
|
-
"at %s.",
|
|
119
|
+
"`--server` will be ignored. Connecting to the "
|
|
120
|
+
"SuperLink ServerAppIo API at %s.",
|
|
126
121
|
args.superlink,
|
|
127
122
|
)
|
|
128
123
|
else:
|
|
@@ -175,11 +170,11 @@ def run_server_app() -> None:
|
|
|
175
170
|
if app_path is None:
|
|
176
171
|
# User provided `--run-id`, but not `app_dir`
|
|
177
172
|
driver = GrpcDriver(
|
|
178
|
-
|
|
179
|
-
driver_service_address=args.superlink,
|
|
173
|
+
serverappio_service_address=args.superlink,
|
|
180
174
|
root_certificates=root_certificates,
|
|
181
175
|
)
|
|
182
176
|
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
177
|
+
driver.init_run(args.run_id)
|
|
183
178
|
run_ = driver.run
|
|
184
179
|
if not run_.fab_hash:
|
|
185
180
|
raise ValueError("FAB hash not provided.")
|
|
@@ -193,12 +188,12 @@ def run_server_app() -> None:
|
|
|
193
188
|
|
|
194
189
|
app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
|
|
195
190
|
config = get_project_config(app_path)
|
|
191
|
+
run_id = run_.run_id
|
|
196
192
|
else:
|
|
197
193
|
# User provided `app_dir`, but not `--run-id`
|
|
198
194
|
# Create run if run_id is not provided
|
|
199
195
|
driver = GrpcDriver(
|
|
200
|
-
|
|
201
|
-
driver_service_address=args.superlink,
|
|
196
|
+
serverappio_service_address=args.superlink,
|
|
202
197
|
root_certificates=root_certificates,
|
|
203
198
|
)
|
|
204
199
|
# Load config from the project directory
|
|
@@ -208,8 +203,9 @@ def run_server_app() -> None:
|
|
|
208
203
|
# Create run
|
|
209
204
|
req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
|
|
210
205
|
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
|
|
211
|
-
#
|
|
212
|
-
driver.
|
|
206
|
+
# Fetch full `Run` using `run_id`
|
|
207
|
+
driver.init_run(res.run_id) # pylint: disable=W0212
|
|
208
|
+
run_id = res.run_id
|
|
213
209
|
|
|
214
210
|
# Obtain server app reference and the run config
|
|
215
211
|
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
@@ -225,11 +221,20 @@ def run_server_app() -> None:
|
|
|
225
221
|
root_certificates,
|
|
226
222
|
)
|
|
227
223
|
|
|
224
|
+
# Initialize Context
|
|
225
|
+
context = Context(
|
|
226
|
+
run_id=run_id,
|
|
227
|
+
node_id=0,
|
|
228
|
+
node_config={},
|
|
229
|
+
state=RecordSet(),
|
|
230
|
+
run_config=server_app_run_config,
|
|
231
|
+
)
|
|
232
|
+
|
|
228
233
|
# Run the ServerApp with the Driver
|
|
229
234
|
run(
|
|
230
235
|
driver=driver,
|
|
236
|
+
context=context,
|
|
231
237
|
server_app_dir=app_path,
|
|
232
|
-
server_app_run_config=server_app_run_config,
|
|
233
238
|
server_app_attr=server_app_attr,
|
|
234
239
|
)
|
|
235
240
|
|
|
@@ -272,13 +277,14 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
272
277
|
)
|
|
273
278
|
parser.add_argument(
|
|
274
279
|
"--server",
|
|
275
|
-
default=
|
|
280
|
+
default=SERVERAPPIO_API_DEFAULT_ADDRESS,
|
|
276
281
|
help="Server address",
|
|
277
282
|
)
|
|
278
283
|
parser.add_argument(
|
|
279
284
|
"--superlink",
|
|
280
|
-
default=
|
|
281
|
-
help="SuperLink
|
|
285
|
+
default=SERVERAPPIO_API_DEFAULT_ADDRESS,
|
|
286
|
+
help="SuperLink ServerAppIo API (gRPC-rere) address "
|
|
287
|
+
"(IPv4, IPv6, or a domain name)",
|
|
282
288
|
)
|
|
283
289
|
parser.add_argument(
|
|
284
290
|
"--run-id",
|
|
@@ -12,17 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower AppIO service."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .
|
|
19
|
-
from .sqlite_state import SqliteState as SqliteState
|
|
20
|
-
from .state import State as State
|
|
21
|
-
from .state_factory import StateFactory as StateFactory
|
|
18
|
+
from .app import flwr_serverapp as flwr_serverapp
|
|
22
19
|
|
|
23
20
|
__all__ = [
|
|
24
|
-
"
|
|
25
|
-
"SqliteState",
|
|
26
|
-
"State",
|
|
27
|
-
"StateFactory",
|
|
21
|
+
"flwr_serverapp",
|
|
28
22
|
]
|