flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +94 -59
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/new.py +12 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +48 -49
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +38 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +15 -8
- flwr/client/grpc_rere_client/connection.py +142 -97
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +176 -103
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +39 -8
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit_code.py +16 -1
- flwr/common/exit_handlers.py +30 -0
- flwr/common/grpc.py +12 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_protobuf_utils.py +141 -0
- flwr/common/inflatable_utils.py +508 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +402 -0
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -211
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +28 -185
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/appio_pb2.py +43 -0
- flwr/proto/appio_pb2.pyi +151 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +12 -19
- flwr/proto/clientappio_pb2.pyi +23 -101
- flwr/proto/clientappio_pb2_grpc.py +269 -28
- flwr/proto/clientappio_pb2_grpc.pyi +114 -20
- flwr/proto/fleet_pb2.py +24 -27
- flwr/proto/fleet_pb2.pyi +19 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +9 -23
- flwr/proto/serverappio_pb2.pyi +0 -110
- flwr/proto/serverappio_pb2_grpc.py +177 -72
- flwr/proto/serverappio_pb2_grpc.pyi +75 -33
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +69 -187
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +6 -2
- flwr/server/grid/grpc_grid.py +148 -41
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +45 -17
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +9 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +25 -0
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +170 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/supercore/object_store/utils.py +43 -0
- flwr/supercore/scheduler/__init__.py +22 -0
- flwr/supercore/scheduler/plugin.py +71 -0
- flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
- flwr/superexec/deployment.py +7 -4
- flwr/superexec/exec_event_log_interceptor.py +8 -4
- flwr/superexec/exec_grpc.py +25 -5
- flwr/superexec/exec_license_interceptor.py +82 -0
- flwr/superexec/exec_servicer.py +135 -24
- flwr/superexec/exec_user_auth_interceptor.py +45 -8
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -3
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/supernode/cli/__init__.py +24 -0
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
- flwr/supernode/cli/flwr_clientapp.py +88 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
- flwr/supernode/nodestate/nodestate.py +227 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
- flwr/supernode/scheduler/__init__.py +22 -0
- flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +22 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
- flwr/supernode/start_client_internal.py +589 -0
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
- flwr/client/clientapp/clientappio_servicer.py +0 -244
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
flwr/cli/run/run.py
CHANGED
|
@@ -24,9 +24,8 @@ from typing import Annotated, Any, Optional
|
|
|
24
24
|
import typer
|
|
25
25
|
from rich.console import Console
|
|
26
26
|
|
|
27
|
-
from flwr.cli.build import
|
|
27
|
+
from flwr.cli.build import build_fab, get_fab_filename
|
|
28
28
|
from flwr.cli.config_utils import (
|
|
29
|
-
get_fab_metadata,
|
|
30
29
|
load_and_validate,
|
|
31
30
|
process_loaded_project_config,
|
|
32
31
|
validate_federation_in_project_config,
|
|
@@ -34,6 +33,7 @@ from flwr.cli.config_utils import (
|
|
|
34
33
|
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
|
|
35
34
|
from flwr.common.config import (
|
|
36
35
|
flatten_dict,
|
|
36
|
+
get_metadata_from_config,
|
|
37
37
|
parse_config_args,
|
|
38
38
|
user_config_to_configrecord,
|
|
39
39
|
)
|
|
@@ -45,11 +45,7 @@ from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
|
|
|
45
45
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
46
46
|
|
|
47
47
|
from ..log import start_stream
|
|
48
|
-
from ..utils import
|
|
49
|
-
init_channel,
|
|
50
|
-
try_obtain_cli_auth_plugin,
|
|
51
|
-
unauthenticated_exc_handler,
|
|
52
|
-
)
|
|
48
|
+
from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
|
53
49
|
|
|
54
50
|
CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
|
|
55
51
|
|
|
@@ -154,54 +150,57 @@ def _run_with_exec_api(
|
|
|
154
150
|
stream: bool,
|
|
155
151
|
output_format: str,
|
|
156
152
|
) -> None:
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
content = Path(fab_path).read_bytes()
|
|
163
|
-
fab_id, fab_version = get_fab_metadata(Path(fab_path))
|
|
153
|
+
channel = None
|
|
154
|
+
try:
|
|
155
|
+
auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
|
|
156
|
+
channel = init_channel(app, federation_config, auth_plugin)
|
|
157
|
+
stub = ExecStub(channel)
|
|
164
158
|
|
|
165
|
-
|
|
166
|
-
|
|
159
|
+
fab_bytes, fab_hash, config = build_fab(app)
|
|
160
|
+
fab_id, fab_version = get_metadata_from_config(config)
|
|
167
161
|
|
|
168
|
-
|
|
162
|
+
fab = Fab(fab_hash, fab_bytes)
|
|
169
163
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
164
|
+
# Construct a `ConfigRecord` out of a flattened `UserConfig`
|
|
165
|
+
fed_config = flatten_dict(federation_config.get("options", {}))
|
|
166
|
+
c_record = user_config_to_configrecord(fed_config)
|
|
173
167
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
)
|
|
179
|
-
with unauthenticated_exc_handler():
|
|
180
|
-
res = stub.StartRun(req)
|
|
181
|
-
|
|
182
|
-
if res.HasField("run_id"):
|
|
183
|
-
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
|
184
|
-
else:
|
|
185
|
-
typer.secho("❌ Failed to start run", fg=typer.colors.RED)
|
|
186
|
-
raise typer.Exit(code=1)
|
|
187
|
-
|
|
188
|
-
if output_format == CliOutputFormat.JSON:
|
|
189
|
-
run_output = json.dumps(
|
|
190
|
-
{
|
|
191
|
-
"success": res.HasField("run_id"),
|
|
192
|
-
"run-id": res.run_id if res.HasField("run_id") else None,
|
|
193
|
-
"fab-id": fab_id,
|
|
194
|
-
"fab-name": fab_id.rsplit("/", maxsplit=1)[-1],
|
|
195
|
-
"fab-version": fab_version,
|
|
196
|
-
"fab-hash": fab_hash[:8],
|
|
197
|
-
"fab-filename": fab_path,
|
|
198
|
-
}
|
|
168
|
+
req = StartRunRequest(
|
|
169
|
+
fab=fab_to_proto(fab),
|
|
170
|
+
override_config=user_config_to_proto(parse_config_args(config_overrides)),
|
|
171
|
+
federation_options=config_record_to_proto(c_record),
|
|
199
172
|
)
|
|
200
|
-
|
|
201
|
-
|
|
173
|
+
with flwr_cli_grpc_exc_handler():
|
|
174
|
+
res = stub.StartRun(req)
|
|
202
175
|
|
|
203
|
-
|
|
204
|
-
|
|
176
|
+
if res.HasField("run_id"):
|
|
177
|
+
typer.secho(
|
|
178
|
+
f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN
|
|
179
|
+
)
|
|
180
|
+
else:
|
|
181
|
+
typer.secho("❌ Failed to start run", fg=typer.colors.RED)
|
|
182
|
+
raise typer.Exit(code=1)
|
|
183
|
+
|
|
184
|
+
if output_format == CliOutputFormat.JSON:
|
|
185
|
+
run_output = json.dumps(
|
|
186
|
+
{
|
|
187
|
+
"success": res.HasField("run_id"),
|
|
188
|
+
"run-id": res.run_id if res.HasField("run_id") else None,
|
|
189
|
+
"fab-id": fab_id,
|
|
190
|
+
"fab-name": fab_id.rsplit("/", maxsplit=1)[-1],
|
|
191
|
+
"fab-version": fab_version,
|
|
192
|
+
"fab-hash": fab_hash[:8],
|
|
193
|
+
"fab-filename": get_fab_filename(config, fab_hash),
|
|
194
|
+
}
|
|
195
|
+
)
|
|
196
|
+
restore_output()
|
|
197
|
+
Console().print_json(run_output)
|
|
198
|
+
|
|
199
|
+
if stream:
|
|
200
|
+
start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)
|
|
201
|
+
finally:
|
|
202
|
+
if channel:
|
|
203
|
+
channel.close()
|
|
205
204
|
|
|
206
205
|
|
|
207
206
|
def _run_without_exec_api(
|
flwr/cli/stop.py
CHANGED
|
@@ -35,7 +35,7 @@ from flwr.common.logger import print_json_error, redirect_output, restore_output
|
|
|
35
35
|
from flwr.proto.exec_pb2 import StopRunRequest, StopRunResponse # pylint: disable=E0611
|
|
36
36
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
37
37
|
|
|
38
|
-
from .utils import init_channel, try_obtain_cli_auth_plugin
|
|
38
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
def stop( # pylint: disable=R0914
|
|
@@ -122,7 +122,7 @@ def stop( # pylint: disable=R0914
|
|
|
122
122
|
|
|
123
123
|
def _stop_run(stub: ExecStub, run_id: int, output_format: str) -> None:
|
|
124
124
|
"""Stop a run."""
|
|
125
|
-
with
|
|
125
|
+
with flwr_cli_grpc_exc_handler():
|
|
126
126
|
response: StopRunResponse = stub.StopRun(request=StopRunRequest(run_id=run_id))
|
|
127
127
|
if response.success:
|
|
128
128
|
typer.secho(f"✅ Run {run_id} successfully stopped.", fg=typer.colors.GREEN)
|
flwr/cli/utils.py
CHANGED
|
@@ -28,7 +28,12 @@ import typer
|
|
|
28
28
|
|
|
29
29
|
from flwr.cli.cli_user_auth_interceptor import CliUserAuthInterceptor
|
|
30
30
|
from flwr.common.auth_plugin import CliAuthPlugin
|
|
31
|
-
from flwr.common.constant import
|
|
31
|
+
from flwr.common.constant import (
|
|
32
|
+
AUTH_TYPE_JSON_KEY,
|
|
33
|
+
CREDENTIALS_DIR,
|
|
34
|
+
FLWR_DIR,
|
|
35
|
+
RUN_ID_NOT_FOUND_MESSAGE,
|
|
36
|
+
)
|
|
32
37
|
from flwr.common.grpc import (
|
|
33
38
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
34
39
|
create_channel,
|
|
@@ -288,11 +293,12 @@ def init_channel(
|
|
|
288
293
|
|
|
289
294
|
|
|
290
295
|
@contextmanager
|
|
291
|
-
def
|
|
292
|
-
"""Context manager to handle gRPC
|
|
296
|
+
def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
297
|
+
"""Context manager to handle specific gRPC errors.
|
|
293
298
|
|
|
294
|
-
It catches grpc.RpcError exceptions with UNAUTHENTICATED
|
|
295
|
-
and
|
|
299
|
+
It catches grpc.RpcError exceptions with UNAUTHENTICATED, UNIMPLEMENTED,
|
|
300
|
+
UNAVAILABLE, and PERMISSION_DENIED statuses, informs the user, and exits the
|
|
301
|
+
application. A NOT_FOUND status carrying the run-ID-not-found message is handled the same way. All other exceptions will be allowed to escape.
|
|
296
302
|
"""
|
|
297
303
|
try:
|
|
298
304
|
yield
|
|
@@ -312,4 +318,31 @@ def unauthenticated_exc_handler() -> Iterator[None]:
|
|
|
312
318
|
bold=True,
|
|
313
319
|
)
|
|
314
320
|
raise typer.Exit(code=1) from None
|
|
321
|
+
if e.code() == grpc.StatusCode.PERMISSION_DENIED:
|
|
322
|
+
typer.secho(
|
|
323
|
+
"❌ Permission denied.",
|
|
324
|
+
fg=typer.colors.RED,
|
|
325
|
+
bold=True,
|
|
326
|
+
)
|
|
327
|
+
# pylint: disable=E1101
|
|
328
|
+
typer.secho(e.details(), fg=typer.colors.RED, bold=True)
|
|
329
|
+
raise typer.Exit(code=1) from None
|
|
330
|
+
if e.code() == grpc.StatusCode.UNAVAILABLE:
|
|
331
|
+
typer.secho(
|
|
332
|
+
"Connection to the SuperLink is unavailable. Please check your network "
|
|
333
|
+
"connection and 'address' in the federation configuration.",
|
|
334
|
+
fg=typer.colors.RED,
|
|
335
|
+
bold=True,
|
|
336
|
+
)
|
|
337
|
+
raise typer.Exit(code=1) from None
|
|
338
|
+
if (
|
|
339
|
+
e.code() == grpc.StatusCode.NOT_FOUND
|
|
340
|
+
and e.details() == RUN_ID_NOT_FOUND_MESSAGE
|
|
341
|
+
):
|
|
342
|
+
typer.secho(
|
|
343
|
+
"❌ Run ID not found.",
|
|
344
|
+
fg=typer.colors.RED,
|
|
345
|
+
bold=True,
|
|
346
|
+
)
|
|
347
|
+
raise typer.Exit(code=1) from None
|
|
315
348
|
raise
|
flwr/client/__init__.py
CHANGED
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
"""Flower client."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .app import start_client as start_client
|
|
19
|
-
from .app import start_numpy_client as start_numpy_client
|
|
18
|
+
from ..compat.client.app import start_client as start_client # Deprecated
|
|
19
|
+
from ..compat.client.app import start_numpy_client as start_numpy_client # Deprecated
|
|
20
20
|
from .client import Client as Client
|
|
21
21
|
from .client_app import ClientApp as ClientApp
|
|
22
22
|
from .numpy_client import NumPyClient as NumPyClient
|
flwr/client/client_app.py
CHANGED
|
@@ -20,6 +20,7 @@ from collections.abc import Iterator
|
|
|
20
20
|
from contextlib import contextmanager
|
|
21
21
|
from typing import Callable, Optional
|
|
22
22
|
|
|
23
|
+
from flwr.app.metadata import validate_message_type
|
|
23
24
|
from flwr.client.client import Client
|
|
24
25
|
from flwr.client.message_handler.message_handler import (
|
|
25
26
|
handle_legacy_message_from_msgtype,
|
|
@@ -28,7 +29,6 @@ from flwr.client.mod.utils import make_ffn
|
|
|
28
29
|
from flwr.client.typing import ClientFnExt, Mod
|
|
29
30
|
from flwr.common import Context, Message, MessageType
|
|
30
31
|
from flwr.common.logger import warn_deprecated_feature
|
|
31
|
-
from flwr.common.message import validate_message_type
|
|
32
32
|
|
|
33
33
|
from .typing import ClientAppCallable
|
|
34
34
|
|
|
@@ -29,6 +29,7 @@ from flwr.common.logger import log
|
|
|
29
29
|
from flwr.common.message import Message
|
|
30
30
|
from flwr.common.retry_invoker import RetryInvoker
|
|
31
31
|
from flwr.common.typing import Fab, Run
|
|
32
|
+
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
|
32
33
|
|
|
33
34
|
|
|
34
35
|
@contextmanager
|
|
@@ -43,12 +44,15 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
|
|
|
43
44
|
] = None,
|
|
44
45
|
) -> Iterator[
|
|
45
46
|
tuple[
|
|
46
|
-
Callable[[], Optional[Message]],
|
|
47
|
-
Callable[[Message],
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
47
|
+
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
48
|
+
Callable[[Message, ObjectTree], set[str]],
|
|
49
|
+
Callable[[], Optional[int]],
|
|
50
|
+
Callable[[], None],
|
|
51
|
+
Callable[[int], Run],
|
|
52
|
+
Callable[[str, int], Fab],
|
|
53
|
+
Callable[[int, str], bytes],
|
|
54
|
+
Callable[[int, str, bytes], None],
|
|
55
|
+
Callable[[int, str], None],
|
|
52
56
|
]
|
|
53
57
|
]:
|
|
54
58
|
"""Primitives for request/response-based interaction with a server via GrpcAdapter.
|
|
@@ -77,12 +81,15 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
|
|
|
77
81
|
|
|
78
82
|
Returns
|
|
79
83
|
-------
|
|
80
|
-
receive : Callable
|
|
81
|
-
send : Callable
|
|
84
|
+
receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
|
|
85
|
+
send : Callable[[Message, ObjectTree], set[str]]
|
|
82
86
|
create_node : Optional[Callable]
|
|
83
87
|
delete_node : Optional[Callable]
|
|
84
88
|
get_run : Optional[Callable]
|
|
85
89
|
get_fab : Optional[Callable]
|
|
90
|
+
pull_object : Callable[[int, str], bytes]
|
|
91
|
+
push_object : Callable[[int, str, bytes], None]
|
|
92
|
+
confirm_message_received : Callable[[int, str], None]
|
|
86
93
|
"""
|
|
87
94
|
if authentication_keys is not None:
|
|
88
95
|
log(ERROR, "Client authentication is not supported for this transport type.")
|
|
@@ -15,11 +15,8 @@
|
|
|
15
15
|
"""Contextmanager for a gRPC request-response channel to the Flower server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import random
|
|
19
|
-
import threading
|
|
20
18
|
from collections.abc import Iterator, Sequence
|
|
21
19
|
from contextlib import contextmanager
|
|
22
|
-
from copy import copy
|
|
23
20
|
from logging import ERROR
|
|
24
21
|
from pathlib import Path
|
|
25
22
|
from typing import Callable, Optional, Union, cast
|
|
@@ -27,19 +24,18 @@ from typing import Callable, Optional, Union, cast
|
|
|
27
24
|
import grpc
|
|
28
25
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
29
26
|
|
|
30
|
-
from flwr.client.heartbeat import start_ping_loop
|
|
31
|
-
from flwr.client.message_handler.message_handler import validate_out_message
|
|
32
27
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
33
|
-
from flwr.common.constant import
|
|
34
|
-
PING_BASE_MULTIPLIER,
|
|
35
|
-
PING_CALL_TIMEOUT,
|
|
36
|
-
PING_DEFAULT_INTERVAL,
|
|
37
|
-
PING_RANDOM_RANGE,
|
|
38
|
-
)
|
|
28
|
+
from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
|
|
39
29
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
30
|
+
from flwr.common.heartbeat import HeartbeatSender
|
|
31
|
+
from flwr.common.inflatable_protobuf_utils import (
|
|
32
|
+
make_confirm_message_received_fn_protobuf,
|
|
33
|
+
make_pull_object_fn_protobuf,
|
|
34
|
+
make_push_object_fn_protobuf,
|
|
35
|
+
)
|
|
40
36
|
from flwr.common.logger import log
|
|
41
|
-
from flwr.common.message import Message,
|
|
42
|
-
from flwr.common.retry_invoker import RetryInvoker
|
|
37
|
+
from flwr.common.message import Message, remove_content_from_message
|
|
38
|
+
from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
|
|
43
39
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
44
40
|
generate_key_pairs,
|
|
45
41
|
)
|
|
@@ -49,13 +45,17 @@ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=
|
|
|
49
45
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
50
46
|
CreateNodeRequest,
|
|
51
47
|
DeleteNodeRequest,
|
|
52
|
-
PingRequest,
|
|
53
|
-
PingResponse,
|
|
54
48
|
PullMessagesRequest,
|
|
55
49
|
PullMessagesResponse,
|
|
56
50
|
PushMessagesRequest,
|
|
51
|
+
PushMessagesResponse,
|
|
57
52
|
)
|
|
58
53
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
54
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
55
|
+
SendNodeHeartbeatRequest,
|
|
56
|
+
SendNodeHeartbeatResponse,
|
|
57
|
+
)
|
|
58
|
+
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
|
59
59
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
60
60
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
61
61
|
|
|
@@ -76,12 +76,15 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
76
76
|
adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None,
|
|
77
77
|
) -> Iterator[
|
|
78
78
|
tuple[
|
|
79
|
-
Callable[[], Optional[Message]],
|
|
80
|
-
Callable[[Message],
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
79
|
+
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
80
|
+
Callable[[Message, ObjectTree], set[str]],
|
|
81
|
+
Callable[[], Optional[int]],
|
|
82
|
+
Callable[[], None],
|
|
83
|
+
Callable[[int], Run],
|
|
84
|
+
Callable[[str, int], Fab],
|
|
85
|
+
Callable[[int, str], bytes],
|
|
86
|
+
Callable[[int, str, bytes], None],
|
|
87
|
+
Callable[[int, str], None],
|
|
85
88
|
]
|
|
86
89
|
]:
|
|
87
90
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -124,6 +127,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
124
127
|
create_node : Optional[Callable]
|
|
125
128
|
delete_node : Optional[Callable]
|
|
126
129
|
get_run : Optional[Callable]
|
|
130
|
+
pull_object : Callable[[int, str], bytes]
|
|
131
|
+
push_object : Callable[[int, str, bytes], None]
|
|
132
|
+
confirm_message_received : Callable[[int, str], None]
|
|
127
133
|
"""
|
|
128
134
|
if isinstance(root_certificates, str):
|
|
129
135
|
root_certificates = Path(root_certificates).read_bytes()
|
|
@@ -149,10 +155,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
149
155
|
if adapter_cls is None:
|
|
150
156
|
adapter_cls = FleetStub
|
|
151
157
|
stub = adapter_cls(channel)
|
|
152
|
-
metadata: Optional[Metadata] = None
|
|
153
158
|
node: Optional[Node] = None
|
|
154
|
-
ping_thread: Optional[threading.Thread] = None
|
|
155
|
-
ping_stop_event = threading.Event()
|
|
156
159
|
|
|
157
160
|
def _should_giveup_fn(e: Exception) -> bool:
|
|
158
161
|
if e.code() == grpc.StatusCode.PERMISSION_DENIED: # type: ignore
|
|
@@ -165,46 +168,58 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
165
168
|
# If the status code is PERMISSION_DENIED, additionally raise RunNotRunningException
|
|
166
169
|
retry_invoker.should_giveup = _should_giveup_fn
|
|
167
170
|
|
|
171
|
+
# Wrap stub
|
|
172
|
+
_wrap_stub(stub, retry_invoker)
|
|
168
173
|
###########################################################################
|
|
169
|
-
#
|
|
174
|
+
# send_node_heartbeat/create_node/delete_node/receive/send/get_run functions
|
|
170
175
|
###########################################################################
|
|
171
176
|
|
|
172
|
-
def
|
|
177
|
+
def send_node_heartbeat() -> bool:
|
|
173
178
|
# Get Node
|
|
174
179
|
if node is None:
|
|
175
180
|
log(ERROR, "Node instance missing")
|
|
176
|
-
return
|
|
181
|
+
return False
|
|
177
182
|
|
|
178
|
-
# Construct the
|
|
179
|
-
req =
|
|
183
|
+
# Construct the heartbeat request
|
|
184
|
+
req = SendNodeHeartbeatRequest(
|
|
185
|
+
node=node, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
|
|
186
|
+
)
|
|
180
187
|
|
|
181
188
|
# Call FleetAPI
|
|
182
|
-
|
|
189
|
+
try:
|
|
190
|
+
res: SendNodeHeartbeatResponse = stub.SendNodeHeartbeat(
|
|
191
|
+
req, timeout=HEARTBEAT_CALL_TIMEOUT
|
|
192
|
+
)
|
|
193
|
+
except grpc.RpcError as e:
|
|
194
|
+
status_code = e.code()
|
|
195
|
+
if status_code == grpc.StatusCode.UNAVAILABLE:
|
|
196
|
+
return False
|
|
197
|
+
if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
|
|
198
|
+
return False
|
|
199
|
+
raise
|
|
183
200
|
|
|
184
201
|
# Check if success
|
|
185
202
|
if not res.success:
|
|
186
|
-
raise RuntimeError(
|
|
203
|
+
raise RuntimeError(
|
|
204
|
+
"Heartbeat failed unexpectedly. The SuperLink does not "
|
|
205
|
+
"recognize this SuperNode."
|
|
206
|
+
)
|
|
207
|
+
return True
|
|
187
208
|
|
|
188
|
-
|
|
189
|
-
rd = random.uniform(*PING_RANDOM_RANGE)
|
|
190
|
-
next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
|
|
191
|
-
next_interval *= PING_BASE_MULTIPLIER + rd
|
|
192
|
-
if not ping_stop_event.is_set():
|
|
193
|
-
ping_stop_event.wait(next_interval)
|
|
209
|
+
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
194
210
|
|
|
195
211
|
def create_node() -> Optional[int]:
|
|
196
212
|
"""Set create_node."""
|
|
197
213
|
# Call FleetAPI
|
|
198
|
-
create_node_request = CreateNodeRequest(
|
|
199
|
-
|
|
200
|
-
stub.CreateNode,
|
|
201
|
-
request=create_node_request,
|
|
214
|
+
create_node_request = CreateNodeRequest(
|
|
215
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
|
|
202
216
|
)
|
|
217
|
+
create_node_response = stub.CreateNode(request=create_node_request)
|
|
203
218
|
|
|
204
|
-
# Remember the node and the
|
|
205
|
-
nonlocal node
|
|
219
|
+
# Remember the node and start the heartbeat sender
|
|
220
|
+
nonlocal node
|
|
206
221
|
node = cast(Node, create_node_response.node)
|
|
207
|
-
|
|
222
|
+
heartbeat_sender.start()
|
|
208
223
|
return node.node_id
|
|
209
224
|
|
|
210
225
|
def delete_node() -> None:
|
|
@@ -215,83 +230,67 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
215
230
|
log(ERROR, "Node instance missing")
|
|
216
231
|
return
|
|
217
232
|
|
|
218
|
-
# Stop the
|
|
219
|
-
|
|
233
|
+
# Stop the heartbeat sender
|
|
234
|
+
heartbeat_sender.stop()
|
|
220
235
|
|
|
221
236
|
# Call FleetAPI
|
|
222
237
|
delete_node_request = DeleteNodeRequest(node=node)
|
|
223
|
-
|
|
238
|
+
stub.DeleteNode(request=delete_node_request)
|
|
224
239
|
|
|
225
240
|
# Cleanup
|
|
226
241
|
node = None
|
|
227
242
|
|
|
228
|
-
def receive() -> Optional[Message]:
|
|
229
|
-
"""
|
|
243
|
+
def receive() -> Optional[tuple[Message, ObjectTree]]:
|
|
244
|
+
"""Pull a message with its ObjectTree from SuperLink."""
|
|
230
245
|
# Get Node
|
|
231
246
|
if node is None:
|
|
232
247
|
log(ERROR, "Node instance missing")
|
|
233
248
|
return None
|
|
234
249
|
|
|
235
|
-
#
|
|
250
|
+
# Try to pull a message with its object tree from SuperLink
|
|
236
251
|
request = PullMessagesRequest(node=node)
|
|
237
|
-
response: PullMessagesResponse =
|
|
238
|
-
stub.PullMessages, request=request
|
|
239
|
-
)
|
|
252
|
+
response: PullMessagesResponse = stub.PullMessages(request=request)
|
|
240
253
|
|
|
241
|
-
#
|
|
242
|
-
|
|
243
|
-
None
|
|
244
|
-
)
|
|
254
|
+
# If no messages are available, return None
|
|
255
|
+
if len(response.messages_list) == 0:
|
|
256
|
+
return None
|
|
245
257
|
|
|
246
|
-
#
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
):
|
|
250
|
-
message_proto = None
|
|
258
|
+
# Get the current Message and its object tree
|
|
259
|
+
message_proto = response.messages_list[0]
|
|
260
|
+
object_tree = response.message_object_trees[0]
|
|
251
261
|
|
|
252
262
|
# Construct the Message
|
|
253
|
-
in_message = message_from_proto(message_proto)
|
|
254
|
-
|
|
255
|
-
# Remember `metadata` of the in message
|
|
256
|
-
nonlocal metadata
|
|
257
|
-
metadata = copy(in_message.metadata) if in_message else None
|
|
263
|
+
in_message = message_from_proto(message_proto)
|
|
258
264
|
|
|
259
|
-
# Return the
|
|
260
|
-
return in_message
|
|
265
|
+
# Return the Message and its object tree
|
|
266
|
+
return in_message, object_tree
|
|
261
267
|
|
|
262
|
-
def send(message: Message) ->
|
|
263
|
-
"""Send message
|
|
268
|
+
def send(message: Message, object_tree: ObjectTree) -> set[str]:
|
|
269
|
+
"""Send the message with its ObjectTree to SuperLink."""
|
|
264
270
|
# Get Node
|
|
265
271
|
if node is None:
|
|
266
272
|
log(ERROR, "Node instance missing")
|
|
267
|
-
return
|
|
268
|
-
|
|
269
|
-
# Get the metadata of the incoming message
|
|
270
|
-
nonlocal metadata
|
|
271
|
-
if metadata is None:
|
|
272
|
-
log(ERROR, "No current message")
|
|
273
|
-
return
|
|
273
|
+
return set()
|
|
274
274
|
|
|
275
|
-
#
|
|
276
|
-
if
|
|
277
|
-
|
|
278
|
-
return
|
|
275
|
+
# Remove the content from the message if it has any
|
|
276
|
+
if message.has_content():
|
|
277
|
+
message = remove_content_from_message(message)
|
|
279
278
|
|
|
280
|
-
#
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
279
|
+
# Send the message with its ObjectTree to SuperLink
|
|
280
|
+
request = PushMessagesRequest(
|
|
281
|
+
node=node,
|
|
282
|
+
messages_list=[message_to_proto(message)],
|
|
283
|
+
message_object_trees=[object_tree],
|
|
284
|
+
)
|
|
285
|
+
response: PushMessagesResponse = stub.PushMessages(request=request)
|
|
284
286
|
|
|
285
|
-
#
|
|
286
|
-
|
|
287
|
+
# Get and return the object IDs to push
|
|
288
|
+
return set(response.objects_to_push)
|
|
287
289
|
|
|
288
290
|
def get_run(run_id: int) -> Run:
|
|
289
291
|
# Call FleetAPI
|
|
290
292
|
get_run_request = GetRunRequest(node=node, run_id=run_id)
|
|
291
|
-
get_run_response: GetRunResponse =
|
|
292
|
-
stub.GetRun,
|
|
293
|
-
request=get_run_request,
|
|
294
|
-
)
|
|
293
|
+
get_run_response: GetRunResponse = stub.GetRun(request=get_run_request)
|
|
295
294
|
|
|
296
295
|
# Return the Run built from the response
|
|
297
296
|
return run_from_proto(get_run_response.run)
|
|
@@ -299,16 +298,62 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
299
298
|
def get_fab(fab_hash: str, run_id: int) -> Fab:
|
|
300
299
|
# Call FleetAPI
|
|
301
300
|
get_fab_request = GetFabRequest(node=node, hash_str=fab_hash, run_id=run_id)
|
|
302
|
-
get_fab_response: GetFabResponse =
|
|
303
|
-
stub.GetFab,
|
|
304
|
-
request=get_fab_request,
|
|
305
|
-
)
|
|
301
|
+
get_fab_response: GetFabResponse = stub.GetFab(request=get_fab_request)
|
|
306
302
|
|
|
307
303
|
return Fab(get_fab_response.fab.hash_str, get_fab_response.fab.content)
|
|
308
304
|
|
|
305
|
+
def pull_object(run_id: int, object_id: str) -> bytes:
|
|
306
|
+
"""Pull the object from the SuperLink."""
|
|
307
|
+
# Check Node
|
|
308
|
+
if node is None:
|
|
309
|
+
raise RuntimeError("Node instance missing")
|
|
310
|
+
|
|
311
|
+
fn = make_pull_object_fn_protobuf(
|
|
312
|
+
pull_object_protobuf=stub.PullObject,
|
|
313
|
+
node=node,
|
|
314
|
+
run_id=run_id,
|
|
315
|
+
)
|
|
316
|
+
return fn(object_id)
|
|
317
|
+
|
|
318
|
+
def push_object(run_id: int, object_id: str, contents: bytes) -> None:
|
|
319
|
+
"""Push the object to the SuperLink."""
|
|
320
|
+
# Check Node
|
|
321
|
+
if node is None:
|
|
322
|
+
raise RuntimeError("Node instance missing")
|
|
323
|
+
|
|
324
|
+
fn = make_push_object_fn_protobuf(
|
|
325
|
+
push_object_protobuf=stub.PushObject,
|
|
326
|
+
node=node,
|
|
327
|
+
run_id=run_id,
|
|
328
|
+
)
|
|
329
|
+
fn(object_id, contents)
|
|
330
|
+
|
|
331
|
+
def confirm_message_received(run_id: int, object_id: str) -> None:
|
|
332
|
+
"""Confirm that the message has been received."""
|
|
333
|
+
# Check Node
|
|
334
|
+
if node is None:
|
|
335
|
+
raise RuntimeError("Node instance missing")
|
|
336
|
+
|
|
337
|
+
fn = make_confirm_message_received_fn_protobuf(
|
|
338
|
+
confirm_message_received_protobuf=stub.ConfirmMessageReceived,
|
|
339
|
+
node=node,
|
|
340
|
+
run_id=run_id,
|
|
341
|
+
)
|
|
342
|
+
fn(object_id)
|
|
343
|
+
|
|
309
344
|
try:
|
|
310
345
|
# Yield methods
|
|
311
|
-
yield (
|
|
346
|
+
yield (
|
|
347
|
+
receive,
|
|
348
|
+
send,
|
|
349
|
+
create_node,
|
|
350
|
+
delete_node,
|
|
351
|
+
get_run,
|
|
352
|
+
get_fab,
|
|
353
|
+
pull_object,
|
|
354
|
+
push_object,
|
|
355
|
+
confirm_message_received,
|
|
356
|
+
)
|
|
312
357
|
except Exception as exc: # pylint: disable=broad-except
|
|
313
358
|
log(ERROR, exc)
|
|
314
359
|
# Cleanup
|