flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/app.py +2 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +15 -2
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +9 -13
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +36 -8
- flwr/client/grpc_rere_client/connection.py +1 -12
- flwr/client/rest_client/connection.py +3 -0
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +29 -0
- flwr/clientapp/mod/centraldp_mods.py +248 -0
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -4
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +30 -7
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +0 -11
- flwr/common/inflatable_utils.py +1 -1
- flwr/common/logger.py +1 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/telemetry.py +4 -0
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +25 -17
- flwr/proto/appio_pb2.pyi +46 -2
- flwr/proto/clientappio_pb2.py +3 -11
- flwr/proto/clientappio_pb2.pyi +0 -47
- flwr/proto/clientappio_pb2_grpc.py +19 -20
- flwr/proto/clientappio_pb2_grpc.pyi +10 -11
- flwr/proto/control_pb2.py +66 -0
- flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
- flwr/proto/control_pb2_grpc.pyi +106 -0
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +68 -0
- flwr/proto/serverappio_pb2_grpc.pyi +26 -0
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +142 -152
- flwr/server/grid/grpc_grid.py +3 -0
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +157 -146
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
- flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +64 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
- flwr/serverapp/strategy/fedadagrad.py +159 -0
- flwr/serverapp/strategy/fedadam.py +178 -0
- flwr/serverapp/strategy/fedavg.py +320 -0
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +170 -0
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +299 -0
- flwr/simulation/app.py +161 -164
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/app_utils.py +58 -0
- flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
- flwr/supercore/cli/flower_superexec.py +166 -0
- flwr/supercore/constant.py +19 -0
- flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/supercore/grpc_health/__init__.py +3 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +199 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
- flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
- flwr/supernode/cli/flower_supernode.py +3 -0
- flwr/supernode/cli/flwr_clientapp.py +18 -21
- flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
- flwr/supernode/nodestate/nodestate.py +3 -59
- flwr/supernode/runtime/run_clientapp.py +39 -102
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
- flwr/supernode/start_client_internal.py +35 -76
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
- flwr/proto/exec_pb2.py +0 -62
- flwr/proto/exec_pb2_grpc.pyi +0 -93
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -191
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -129
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
|
@@ -12,9 +12,10 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""Control API servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import hashlib
|
|
18
19
|
import time
|
|
19
20
|
from collections.abc import Generator
|
|
20
21
|
from logging import ERROR, INFO
|
|
@@ -22,11 +23,15 @@ from typing import Any, Optional, cast
|
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
25
|
-
from flwr.
|
|
26
|
-
from flwr.common
|
|
26
|
+
from flwr.cli.config_utils import get_fab_metadata
|
|
27
|
+
from flwr.common import Context, RecordDict, now
|
|
28
|
+
from flwr.common.auth_plugin import ControlAuthPlugin
|
|
27
29
|
from flwr.common.constant import (
|
|
28
30
|
FAB_MAX_SIZE,
|
|
29
31
|
LOG_STREAM_INTERVAL,
|
|
32
|
+
NO_ARTIFACT_PROVIDER_MESSAGE,
|
|
33
|
+
NO_USER_AUTH_MESSAGE,
|
|
34
|
+
PULL_UNFINISHED_RUN_MESSAGE,
|
|
30
35
|
RUN_ID_NOT_FOUND_MESSAGE,
|
|
31
36
|
Status,
|
|
32
37
|
SubStatus,
|
|
@@ -37,15 +42,17 @@ from flwr.common.serde import (
|
|
|
37
42
|
run_to_proto,
|
|
38
43
|
user_config_from_proto,
|
|
39
44
|
)
|
|
40
|
-
from flwr.common.typing import Run, RunStatus
|
|
41
|
-
from flwr.proto import
|
|
42
|
-
from flwr.proto.
|
|
45
|
+
from flwr.common.typing import Fab, Run, RunStatus
|
|
46
|
+
from flwr.proto import control_pb2_grpc # pylint: disable=E0611
|
|
47
|
+
from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
43
48
|
GetAuthTokensRequest,
|
|
44
49
|
GetAuthTokensResponse,
|
|
45
50
|
GetLoginDetailsRequest,
|
|
46
51
|
GetLoginDetailsResponse,
|
|
47
52
|
ListRunsRequest,
|
|
48
53
|
ListRunsResponse,
|
|
54
|
+
PullArtifactsRequest,
|
|
55
|
+
PullArtifactsResponse,
|
|
49
56
|
StartRunRequest,
|
|
50
57
|
StartRunResponse,
|
|
51
58
|
StopRunRequest,
|
|
@@ -56,34 +63,37 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
|
|
|
56
63
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
57
64
|
from flwr.supercore.ffs import FfsFactory
|
|
58
65
|
from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
|
|
66
|
+
from flwr.superlink.artifact_provider import ArtifactProvider
|
|
59
67
|
|
|
60
|
-
from .
|
|
61
|
-
from .executor import Executor
|
|
68
|
+
from .control_user_auth_interceptor import shared_account_info
|
|
62
69
|
|
|
63
70
|
|
|
64
|
-
class
|
|
65
|
-
"""
|
|
71
|
+
class ControlServicer(control_pb2_grpc.ControlServicer):
|
|
72
|
+
"""Control API servicer."""
|
|
66
73
|
|
|
67
74
|
def __init__( # pylint: disable=R0913, R0917
|
|
68
75
|
self,
|
|
69
76
|
linkstate_factory: LinkStateFactory,
|
|
70
77
|
ffs_factory: FfsFactory,
|
|
71
78
|
objectstore_factory: ObjectStoreFactory,
|
|
72
|
-
|
|
73
|
-
auth_plugin: Optional[
|
|
79
|
+
is_simulation: bool,
|
|
80
|
+
auth_plugin: Optional[ControlAuthPlugin] = None,
|
|
81
|
+
artifact_provider: Optional[ArtifactProvider] = None,
|
|
74
82
|
) -> None:
|
|
75
83
|
self.linkstate_factory = linkstate_factory
|
|
76
84
|
self.ffs_factory = ffs_factory
|
|
77
85
|
self.objectstore_factory = objectstore_factory
|
|
78
|
-
self.
|
|
79
|
-
self.executor.initialize(linkstate_factory, ffs_factory)
|
|
86
|
+
self.is_simulation = is_simulation
|
|
80
87
|
self.auth_plugin = auth_plugin
|
|
88
|
+
self.artifact_provider = artifact_provider
|
|
81
89
|
|
|
82
|
-
def StartRun(
|
|
90
|
+
def StartRun( # pylint: disable=too-many-locals
|
|
83
91
|
self, request: StartRunRequest, context: grpc.ServicerContext
|
|
84
92
|
) -> StartRunResponse:
|
|
85
93
|
"""Create run ID."""
|
|
86
|
-
log(INFO, "
|
|
94
|
+
log(INFO, "ControlServicer.StartRun")
|
|
95
|
+
state = self.linkstate_factory.state()
|
|
96
|
+
ffs = self.ffs_factory.ffs()
|
|
87
97
|
|
|
88
98
|
if len(request.fab.content) > FAB_MAX_SIZE:
|
|
89
99
|
log(
|
|
@@ -94,24 +104,69 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
94
104
|
return StartRunResponse()
|
|
95
105
|
|
|
96
106
|
flwr_aid = shared_account_info.get().flwr_aid if self.auth_plugin else None
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
107
|
+
override_config = user_config_from_proto(request.override_config)
|
|
108
|
+
federation_options = config_record_from_proto(request.federation_options)
|
|
109
|
+
fab_file = request.fab.content
|
|
110
|
+
|
|
111
|
+
try:
|
|
112
|
+
# Check that num-supernodes is set
|
|
113
|
+
if self.is_simulation and "num-supernodes" not in federation_options:
|
|
114
|
+
raise ValueError(
|
|
115
|
+
"Federation options doesn't contain key `num-supernodes`."
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
# Create run
|
|
119
|
+
fab = Fab(hashlib.sha256(fab_file).hexdigest(), fab_file)
|
|
120
|
+
fab_hash = ffs.put(fab.content, {})
|
|
121
|
+
if fab_hash != fab.hash_str:
|
|
122
|
+
raise RuntimeError(
|
|
123
|
+
f"FAB ({fab.hash_str}) hash from request doesn't match contents"
|
|
124
|
+
)
|
|
125
|
+
fab_id, fab_version = get_fab_metadata(fab.content)
|
|
126
|
+
|
|
127
|
+
run_id = state.create_run(
|
|
128
|
+
fab_id,
|
|
129
|
+
fab_version,
|
|
130
|
+
fab_hash,
|
|
131
|
+
override_config,
|
|
132
|
+
federation_options,
|
|
133
|
+
flwr_aid,
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
# Initialize node config
|
|
137
|
+
node_config = {}
|
|
138
|
+
if self.artifact_provider is not None:
|
|
139
|
+
node_config = {
|
|
140
|
+
"output_dir": self.artifact_provider.output_dir,
|
|
141
|
+
"tmp_dir": self.artifact_provider.tmp_dir,
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
# Create an empty context for the Run
|
|
145
|
+
context = Context(
|
|
146
|
+
run_id=run_id,
|
|
147
|
+
node_id=0,
|
|
148
|
+
# Dict is invariant in mypy
|
|
149
|
+
node_config=node_config, # type: ignore[arg-type]
|
|
150
|
+
state=RecordDict(),
|
|
151
|
+
run_config={},
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
# Register the context at the LinkState
|
|
155
|
+
state.set_serverapp_context(run_id=run_id, context=context)
|
|
103
156
|
|
|
104
|
-
|
|
105
|
-
|
|
157
|
+
# pylint: disable-next=broad-except
|
|
158
|
+
except Exception as e:
|
|
159
|
+
log(ERROR, "Could not start run: %s", str(e))
|
|
106
160
|
return StartRunResponse()
|
|
107
161
|
|
|
162
|
+
log(INFO, "Created run %s", str(run_id))
|
|
108
163
|
return StartRunResponse(run_id=run_id)
|
|
109
164
|
|
|
110
165
|
def StreamLogs( # pylint: disable=C0103
|
|
111
166
|
self, request: StreamLogsRequest, context: grpc.ServicerContext
|
|
112
167
|
) -> Generator[StreamLogsResponse, Any, None]:
|
|
113
168
|
"""Get logs."""
|
|
114
|
-
log(INFO, "
|
|
169
|
+
log(INFO, "ControlServicer.StreamLogs")
|
|
115
170
|
state = self.linkstate_factory.state()
|
|
116
171
|
|
|
117
172
|
# Retrieve run ID and run
|
|
@@ -158,7 +213,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
158
213
|
self, request: ListRunsRequest, context: grpc.ServicerContext
|
|
159
214
|
) -> ListRunsResponse:
|
|
160
215
|
"""Handle `flwr ls` command."""
|
|
161
|
-
log(INFO, "
|
|
216
|
+
log(INFO, "ControlServicer.List")
|
|
162
217
|
state = self.linkstate_factory.state()
|
|
163
218
|
|
|
164
219
|
# Build a set of run IDs for `flwr ls --runs`
|
|
@@ -204,7 +259,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
204
259
|
self, request: StopRunRequest, context: grpc.ServicerContext
|
|
205
260
|
) -> StopRunResponse:
|
|
206
261
|
"""Stop a given run ID."""
|
|
207
|
-
log(INFO, "
|
|
262
|
+
log(INFO, "ControlServicer.StopRun")
|
|
208
263
|
state = self.linkstate_factory.state()
|
|
209
264
|
|
|
210
265
|
# Retrieve run ID and run
|
|
@@ -249,11 +304,11 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
249
304
|
self, request: GetLoginDetailsRequest, context: grpc.ServicerContext
|
|
250
305
|
) -> GetLoginDetailsResponse:
|
|
251
306
|
"""Start login."""
|
|
252
|
-
log(INFO, "
|
|
307
|
+
log(INFO, "ControlServicer.GetLoginDetails")
|
|
253
308
|
if self.auth_plugin is None:
|
|
254
309
|
context.abort(
|
|
255
310
|
grpc.StatusCode.UNIMPLEMENTED,
|
|
256
|
-
|
|
311
|
+
NO_USER_AUTH_MESSAGE,
|
|
257
312
|
)
|
|
258
313
|
raise grpc.RpcError() # This line is unreachable
|
|
259
314
|
|
|
@@ -276,11 +331,11 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
276
331
|
self, request: GetAuthTokensRequest, context: grpc.ServicerContext
|
|
277
332
|
) -> GetAuthTokensResponse:
|
|
278
333
|
"""Get auth token."""
|
|
279
|
-
log(INFO, "
|
|
334
|
+
log(INFO, "ControlServicer.GetAuthTokens")
|
|
280
335
|
if self.auth_plugin is None:
|
|
281
336
|
context.abort(
|
|
282
337
|
grpc.StatusCode.UNIMPLEMENTED,
|
|
283
|
-
|
|
338
|
+
NO_USER_AUTH_MESSAGE,
|
|
284
339
|
)
|
|
285
340
|
raise grpc.RpcError() # This line is unreachable
|
|
286
341
|
|
|
@@ -296,6 +351,47 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
296
351
|
refresh_token=credentials.refresh_token,
|
|
297
352
|
)
|
|
298
353
|
|
|
354
|
+
def PullArtifacts(
|
|
355
|
+
self, request: PullArtifactsRequest, context: grpc.ServicerContext
|
|
356
|
+
) -> PullArtifactsResponse:
|
|
357
|
+
"""Pull artifacts for a given run ID."""
|
|
358
|
+
log(INFO, "ControlServicer.PullArtifacts")
|
|
359
|
+
|
|
360
|
+
# Check if artifact provider is configured
|
|
361
|
+
if self.artifact_provider is None:
|
|
362
|
+
context.abort(
|
|
363
|
+
grpc.StatusCode.UNIMPLEMENTED,
|
|
364
|
+
NO_ARTIFACT_PROVIDER_MESSAGE,
|
|
365
|
+
)
|
|
366
|
+
raise grpc.RpcError() # This line is unreachable
|
|
367
|
+
|
|
368
|
+
# Init link state
|
|
369
|
+
state = self.linkstate_factory.state()
|
|
370
|
+
|
|
371
|
+
# Retrieve run ID and run
|
|
372
|
+
run_id = request.run_id
|
|
373
|
+
run = state.get_run(run_id)
|
|
374
|
+
|
|
375
|
+
# Exit if `run_id` not found
|
|
376
|
+
if not run:
|
|
377
|
+
context.abort(grpc.StatusCode.NOT_FOUND, RUN_ID_NOT_FOUND_MESSAGE)
|
|
378
|
+
raise grpc.RpcError() # This line is unreachable
|
|
379
|
+
|
|
380
|
+
# Exit if the run is not finished yet
|
|
381
|
+
if run.status.status != Status.FINISHED:
|
|
382
|
+
context.abort(
|
|
383
|
+
grpc.StatusCode.FAILED_PRECONDITION, PULL_UNFINISHED_RUN_MESSAGE
|
|
384
|
+
)
|
|
385
|
+
|
|
386
|
+
# Check if `flwr_aid` matches the run's `flwr_aid` when user auth is enabled
|
|
387
|
+
if self.auth_plugin:
|
|
388
|
+
flwr_aid = shared_account_info.get().flwr_aid
|
|
389
|
+
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=run, context=context)
|
|
390
|
+
|
|
391
|
+
# Call artifact provider
|
|
392
|
+
download_url = self.artifact_provider.get_url(run_id)
|
|
393
|
+
return PullArtifactsResponse(url=download_url)
|
|
394
|
+
|
|
299
395
|
|
|
300
396
|
def _create_list_runs_response(
|
|
301
397
|
run_ids: set[int], state: LinkState, store: ObjectStore
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower
|
|
15
|
+
"""Flower Control API interceptor."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import contextvars
|
|
@@ -20,9 +20,9 @@ from typing import Any, Callable, Union
|
|
|
20
20
|
|
|
21
21
|
import grpc
|
|
22
22
|
|
|
23
|
-
from flwr.common.auth_plugin import
|
|
23
|
+
from flwr.common.auth_plugin import ControlAuthPlugin, ControlAuthzPlugin
|
|
24
24
|
from flwr.common.typing import AccountInfo
|
|
25
|
-
from flwr.proto.
|
|
25
|
+
from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
26
26
|
GetAuthTokensRequest,
|
|
27
27
|
GetAuthTokensResponse,
|
|
28
28
|
GetLoginDetailsRequest,
|
|
@@ -50,13 +50,13 @@ shared_account_info: contextvars.ContextVar[AccountInfo] = contextvars.ContextVa
|
|
|
50
50
|
)
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
class
|
|
54
|
-
"""
|
|
53
|
+
class ControlUserAuthInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
54
|
+
"""Control API interceptor for user authentication."""
|
|
55
55
|
|
|
56
56
|
def __init__(
|
|
57
57
|
self,
|
|
58
|
-
auth_plugin:
|
|
59
|
-
authz_plugin:
|
|
58
|
+
auth_plugin: ControlAuthPlugin,
|
|
59
|
+
authz_plugin: ControlAuthzPlugin,
|
|
60
60
|
):
|
|
61
61
|
self.auth_plugin = auth_plugin
|
|
62
62
|
self.authz_plugin = authz_plugin
|
|
@@ -72,12 +72,12 @@ class ExecUserAuthInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
72
72
|
by validating auth metadata sent by the user. Continue RPC call if user is
|
|
73
73
|
authenticated, else, terminate RPC call by setting context to abort.
|
|
74
74
|
"""
|
|
75
|
-
# Only apply to
|
|
76
|
-
if not handler_call_details.method.startswith("/flwr.proto.
|
|
75
|
+
# Only apply to Control service
|
|
76
|
+
if not handler_call_details.method.startswith("/flwr.proto.Control/"):
|
|
77
77
|
return continuation(handler_call_details)
|
|
78
78
|
|
|
79
79
|
# One of the method handlers in
|
|
80
|
-
# `flwr.
|
|
80
|
+
# `flwr.superlink.servicer.control.ControlServicer`
|
|
81
81
|
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
|
82
82
|
return self._generic_auth_unary_method_handler(method_handler)
|
|
83
83
|
|
|
@@ -41,6 +41,7 @@ from flwr.common.constant import (
|
|
|
41
41
|
)
|
|
42
42
|
from flwr.common.exit import ExitCode, flwr_exit
|
|
43
43
|
from flwr.common.logger import log
|
|
44
|
+
from flwr.supercore.grpc_health import add_args_health
|
|
44
45
|
from flwr.supernode.start_client_internal import start_client_internal
|
|
45
46
|
|
|
46
47
|
|
|
@@ -79,6 +80,7 @@ def flower_supernode() -> None:
|
|
|
79
80
|
flwr_path=args.flwr_dir,
|
|
80
81
|
isolation=args.isolation,
|
|
81
82
|
clientappio_api_address=args.clientappio_api_address,
|
|
83
|
+
health_server_address=args.health_server_address,
|
|
82
84
|
)
|
|
83
85
|
|
|
84
86
|
|
|
@@ -118,6 +120,7 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
|
118
120
|
help="ClientAppIo API (gRPC) server address (IPv4, IPv6, or a domain name). "
|
|
119
121
|
f"By default, it is set to {CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS}.",
|
|
120
122
|
)
|
|
123
|
+
add_args_health(parser)
|
|
121
124
|
|
|
122
125
|
return parser
|
|
123
126
|
|
|
@@ -19,9 +19,12 @@ import argparse
|
|
|
19
19
|
from logging import DEBUG, INFO
|
|
20
20
|
|
|
21
21
|
from flwr.common.args import add_args_flwr_app_common
|
|
22
|
-
from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS
|
|
22
|
+
from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS, ExecPluginType
|
|
23
23
|
from flwr.common.exit import ExitCode, flwr_exit
|
|
24
24
|
from flwr.common.logger import log
|
|
25
|
+
from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub
|
|
26
|
+
from flwr.supercore.superexec.plugin import ClientAppExecPlugin
|
|
27
|
+
from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
|
|
25
28
|
from flwr.supercore.utils import mask_string
|
|
26
29
|
from flwr.supernode.runtime.run_clientapp import run_clientapp
|
|
27
30
|
|
|
@@ -35,6 +38,20 @@ def flwr_clientapp() -> None:
|
|
|
35
38
|
"flwr-clientapp does not support TLS yet.",
|
|
36
39
|
)
|
|
37
40
|
|
|
41
|
+
# Disallow long-running `flwr-clientapp` processes
|
|
42
|
+
if args.token is None:
|
|
43
|
+
run_with_deprecation_warning(
|
|
44
|
+
cmd="flwr-clientapp",
|
|
45
|
+
plugin_type=ExecPluginType.CLIENT_APP,
|
|
46
|
+
plugin_class=ClientAppExecPlugin,
|
|
47
|
+
stub_class=ClientAppIoStub,
|
|
48
|
+
appio_api_address=args.clientappio_api_address,
|
|
49
|
+
flwr_dir=args.flwr_dir,
|
|
50
|
+
parent_pid=args.parent_pid,
|
|
51
|
+
warn_run_once=args.run_once,
|
|
52
|
+
)
|
|
53
|
+
return
|
|
54
|
+
|
|
38
55
|
log(INFO, "Start `flwr-clientapp` process")
|
|
39
56
|
log(
|
|
40
57
|
DEBUG,
|
|
@@ -45,7 +62,6 @@ def flwr_clientapp() -> None:
|
|
|
45
62
|
)
|
|
46
63
|
run_clientapp(
|
|
47
64
|
clientappio_api_address=args.clientappio_api_address,
|
|
48
|
-
run_once=(args.token is not None) or args.run_once,
|
|
49
65
|
token=args.token,
|
|
50
66
|
flwr_dir=args.flwr_dir,
|
|
51
67
|
certificates=None,
|
|
@@ -65,24 +81,5 @@ def _parse_args_run_flwr_clientapp() -> argparse.ArgumentParser:
|
|
|
65
81
|
help="Address of SuperNode's ClientAppIo API (IPv4, IPv6, or a domain name)."
|
|
66
82
|
f"By default, it is set to {CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS}.",
|
|
67
83
|
)
|
|
68
|
-
parser.add_argument(
|
|
69
|
-
"--token",
|
|
70
|
-
type=str,
|
|
71
|
-
required=False,
|
|
72
|
-
help="Unique token generated by SuperNode for each ClientApp execution",
|
|
73
|
-
)
|
|
74
|
-
parser.add_argument(
|
|
75
|
-
"--parent-pid",
|
|
76
|
-
type=int,
|
|
77
|
-
default=None,
|
|
78
|
-
help="The PID of the parent process. When set, the process will terminate "
|
|
79
|
-
"when the parent process exits.",
|
|
80
|
-
)
|
|
81
|
-
parser.add_argument(
|
|
82
|
-
"--run-once",
|
|
83
|
-
action="store_true",
|
|
84
|
-
help="When set, this process will start a single ClientApp for a pending "
|
|
85
|
-
"message. If there is no pending message, the process will exit.",
|
|
86
|
-
)
|
|
87
84
|
add_args_flwr_app_common(parser=parser)
|
|
88
85
|
return parser
|
|
@@ -171,12 +171,12 @@ class InMemoryNodeState(NodeState): # pylint: disable=too-many-instance-attribu
|
|
|
171
171
|
ret -= set(self.token_store.keys())
|
|
172
172
|
return list(ret)
|
|
173
173
|
|
|
174
|
-
def create_token(self, run_id: int) -> str:
|
|
174
|
+
def create_token(self, run_id: int) -> Optional[str]:
|
|
175
175
|
"""Create a token for the given run ID."""
|
|
176
176
|
token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
|
|
177
177
|
with self.lock_token_store:
|
|
178
178
|
if run_id in self.token_store:
|
|
179
|
-
|
|
179
|
+
return None # Token already created for this run ID
|
|
180
180
|
self.token_store[run_id] = token
|
|
181
181
|
self.token_to_run_id[token] = run_id
|
|
182
182
|
return token
|
|
@@ -15,15 +15,16 @@
|
|
|
15
15
|
"""Abstract base class NodeState."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from abc import
|
|
18
|
+
from abc import abstractmethod
|
|
19
19
|
from collections.abc import Sequence
|
|
20
20
|
from typing import Optional
|
|
21
21
|
|
|
22
22
|
from flwr.common import Context, Message
|
|
23
23
|
from flwr.common.typing import Run
|
|
24
|
+
from flwr.supercore.corestate import CoreState
|
|
24
25
|
|
|
25
26
|
|
|
26
|
-
class NodeState(
|
|
27
|
+
class NodeState(CoreState):
|
|
27
28
|
"""Abstract base class for node state."""
|
|
28
29
|
|
|
29
30
|
@abstractmethod
|
|
@@ -168,60 +169,3 @@ class NodeState(ABC):
|
|
|
168
169
|
Sequence[int]
|
|
169
170
|
Sequence of run IDs with pending messages.
|
|
170
171
|
"""
|
|
171
|
-
|
|
172
|
-
@abstractmethod
|
|
173
|
-
def create_token(self, run_id: int) -> str:
|
|
174
|
-
"""Create a token for the given run ID.
|
|
175
|
-
|
|
176
|
-
Parameters
|
|
177
|
-
----------
|
|
178
|
-
run_id : int
|
|
179
|
-
The ID of the run for which to create a token.
|
|
180
|
-
|
|
181
|
-
Returns
|
|
182
|
-
-------
|
|
183
|
-
str
|
|
184
|
-
A unique token associated with the run ID.
|
|
185
|
-
"""
|
|
186
|
-
|
|
187
|
-
@abstractmethod
|
|
188
|
-
def verify_token(self, run_id: int, token: str) -> bool:
|
|
189
|
-
"""Verify a token for the given run ID.
|
|
190
|
-
|
|
191
|
-
Parameters
|
|
192
|
-
----------
|
|
193
|
-
run_id : int
|
|
194
|
-
The ID of the run for which to verify the token.
|
|
195
|
-
token : str
|
|
196
|
-
The token to verify.
|
|
197
|
-
|
|
198
|
-
Returns
|
|
199
|
-
-------
|
|
200
|
-
bool
|
|
201
|
-
True if the token is valid for the run ID, False otherwise.
|
|
202
|
-
"""
|
|
203
|
-
|
|
204
|
-
@abstractmethod
|
|
205
|
-
def delete_token(self, run_id: int) -> None:
|
|
206
|
-
"""Delete the token for the given run ID.
|
|
207
|
-
|
|
208
|
-
Parameters
|
|
209
|
-
----------
|
|
210
|
-
run_id : int
|
|
211
|
-
The ID of the run for which to delete the token.
|
|
212
|
-
"""
|
|
213
|
-
|
|
214
|
-
@abstractmethod
|
|
215
|
-
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
216
|
-
"""Get the run ID associated with a given token.
|
|
217
|
-
|
|
218
|
-
Parameters
|
|
219
|
-
----------
|
|
220
|
-
token : str
|
|
221
|
-
The token to look up.
|
|
222
|
-
|
|
223
|
-
Returns
|
|
224
|
-
-------
|
|
225
|
-
Optional[int]
|
|
226
|
-
The run ID if the token is valid, otherwise None.
|
|
227
|
-
"""
|