flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.15.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/cli/app.py +16 -2
- flwr/cli/build.py +181 -0
- flwr/cli/cli_user_auth_interceptor.py +90 -0
- flwr/cli/config_utils.py +343 -0
- flwr/cli/example.py +4 -1
- flwr/cli/install.py +253 -0
- flwr/cli/log.py +182 -0
- flwr/{server/superlink/state → cli/login}/__init__.py +4 -10
- flwr/cli/login/login.py +88 -0
- flwr/cli/ls.py +327 -0
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +210 -66
- flwr/cli/new/templates/app/.gitignore.tpl +163 -0
- flwr/cli/new/templates/app/LICENSE.tpl +202 -0
- flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +66 -0
- flwr/cli/new/templates/app/README.md.tpl +16 -32
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +50 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +73 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +7 -7
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +30 -21
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +63 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +57 -1
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +87 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +78 -0
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +94 -0
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +38 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +26 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +22 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +36 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +7 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +29 -24
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +67 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +68 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +46 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +35 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +35 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +212 -34
- flwr/cli/stop.py +130 -0
- flwr/cli/utils.py +240 -5
- flwr/client/__init__.py +3 -2
- flwr/client/app.py +432 -255
- flwr/client/client.py +1 -11
- flwr/client/client_app.py +74 -13
- flwr/client/clientapp/__init__.py +22 -0
- flwr/client/clientapp/app.py +259 -0
- flwr/client/clientapp/clientappio_servicer.py +244 -0
- flwr/client/clientapp/utils.py +115 -0
- flwr/client/dpfedavg_numpy_client.py +7 -8
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +98 -0
- flwr/client/grpc_client/connection.py +21 -7
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +176 -0
- flwr/client/grpc_rere_client/connection.py +163 -56
- flwr/client/grpc_rere_client/grpc_adapter.py +167 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +10 -11
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +5 -4
- flwr/client/mod/localdp_mod.py +10 -5
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
- flwr/client/mod/utils.py +2 -4
- flwr/client/nodestate/__init__.py +26 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +31 -0
- flwr/client/nodestate/nodestate_factory.py +38 -0
- flwr/client/numpy_client.py +8 -31
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +199 -176
- flwr/client/run_info_store.py +112 -0
- flwr/client/supernode/__init__.py +24 -0
- flwr/client/supernode/app.py +321 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +17 -11
- flwr/common/address.py +47 -3
- flwr/common/args.py +153 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +121 -0
- flwr/common/config.py +243 -0
- flwr/common/constant.py +132 -1
- flwr/common/context.py +32 -2
- flwr/common/date.py +22 -4
- flwr/common/differential_privacy.py +2 -2
- flwr/common/dp.py +2 -4
- flwr/common/exit_handlers.py +3 -3
- flwr/common/grpc.py +164 -5
- flwr/common/logger.py +230 -12
- flwr/common/message.py +191 -106
- flwr/common/object_ref.py +179 -44
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/configsrecord.py +58 -18
- flwr/common/record/metricsrecord.py +57 -17
- flwr/common/record/parametersrecord.py +88 -20
- flwr/common/record/recordset.py +153 -30
- flwr/common/record/typeddict.py +30 -55
- flwr/common/recordset_compat.py +31 -12
- flwr/common/retry_invoker.py +123 -30
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +11 -11
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +68 -4
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +17 -17
- flwr/common/secure_aggregation/quantization.py +8 -8
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +10 -12
- flwr/common/serde.py +298 -19
- flwr/common/telemetry.py +65 -29
- flwr/common/typing.py +120 -19
- flwr/common/version.py +17 -3
- flwr/proto/clientappio_pb2.py +45 -0
- flwr/proto/clientappio_pb2.pyi +132 -0
- flwr/proto/clientappio_pb2_grpc.py +135 -0
- flwr/proto/clientappio_pb2_grpc.pyi +53 -0
- flwr/proto/exec_pb2.py +62 -0
- flwr/proto/exec_pb2.pyi +212 -0
- flwr/proto/exec_pb2_grpc.py +237 -0
- flwr/proto/exec_pb2_grpc.pyi +93 -0
- flwr/proto/fab_pb2.py +31 -0
- flwr/proto/fab_pb2.pyi +65 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +42 -23
- flwr/proto/fleet_pb2.pyi +123 -1
- flwr/proto/fleet_pb2_grpc.py +170 -0
- flwr/proto/fleet_pb2_grpc.pyi +61 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +128 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/node_pb2.py +1 -1
- flwr/proto/recordset_pb2.py +35 -33
- flwr/proto/recordset_pb2.pyi +40 -14
- flwr/proto/run_pb2.py +64 -0
- flwr/proto/run_pb2.pyi +268 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +62 -20
- flwr/proto/serverappio_pb2_grpc.py +410 -0
- flwr/proto/serverappio_pb2_grpc.pyi +160 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +239 -0
- flwr/proto/simulationio_pb2_grpc.pyi +94 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/proto/transport_pb2.py +8 -8
- flwr/proto/transport_pb2.pyi +9 -6
- flwr/server/__init__.py +2 -10
- flwr/server/app.py +579 -402
- flwr/server/client_manager.py +8 -6
- flwr/server/compat/app.py +6 -62
- flwr/server/compat/app_utils.py +14 -8
- flwr/server/compat/driver_client_proxy.py +25 -58
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +36 -131
- flwr/server/driver/grpc_driver.py +217 -81
- flwr/server/driver/inmemory_driver.py +182 -0
- flwr/server/history.py +28 -29
- flwr/server/run_serverapp.py +15 -126
- flwr/server/server.py +50 -44
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp/__init__.py +22 -0
- flwr/server/serverapp/app.py +256 -0
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/aggregate.py +37 -23
- flwr/server/strategy/bulyan.py +9 -9
- flwr/server/strategy/dp_adaptive_clipping.py +25 -25
- flwr/server/strategy/dp_fixed_clipping.py +23 -22
- flwr/server/strategy/dpfedavg_adaptive.py +8 -8
- flwr/server/strategy/dpfedavg_fixed.py +13 -12
- flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
- flwr/server/strategy/fedadagrad.py +9 -9
- flwr/server/strategy/fedadam.py +20 -10
- flwr/server/strategy/fedavg.py +16 -16
- flwr/server/strategy/fedavg_android.py +17 -17
- flwr/server/strategy/fedavgm.py +9 -9
- flwr/server/strategy/fedmedian.py +5 -5
- flwr/server/strategy/fedopt.py +6 -6
- flwr/server/strategy/fedprox.py +7 -7
- flwr/server/strategy/fedtrimmedavg.py +8 -8
- flwr/server/strategy/fedxgb_bagging.py +12 -12
- flwr/server/strategy/fedxgb_cyclic.py +10 -10
- flwr/server/strategy/fedxgb_nn_avg.py +6 -6
- flwr/server/strategy/fedyogi.py +9 -9
- flwr/server/strategy/krum.py +9 -9
- flwr/server/strategy/qfedavg.py +16 -16
- flwr/server/strategy/strategy.py +10 -10
- flwr/server/superlink/driver/__init__.py +2 -2
- flwr/server/superlink/driver/serverappio_grpc.py +61 -0
- flwr/server/superlink/driver/serverappio_servicer.py +363 -0
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +108 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/ffs/ffs_factory.py +47 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +162 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -154
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +120 -13
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +228 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +153 -9
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +119 -81
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +4 -4
- flwr/server/superlink/fleet/vce/backend/backend.py +8 -9
- flwr/server/superlink/fleet/vce/backend/raybackend.py +87 -68
- flwr/server/superlink/fleet/vce/vce_api.py +208 -146
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +581 -0
- flwr/server/superlink/linkstate/linkstate.py +389 -0
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +19 -10
- flwr/server/superlink/linkstate/sqlite_linkstate.py +1236 -0
- flwr/server/superlink/linkstate/utils.py +389 -0
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +186 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +5 -5
- flwr/server/utils/validator.py +31 -11
- flwr/server/workflow/default_workflows.py +70 -26
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +40 -27
- flwr/simulation/__init__.py +12 -5
- flwr/simulation/app.py +247 -315
- flwr/simulation/legacy_app.py +402 -0
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +42 -67
- flwr/simulation/ray_transport/ray_client_proxy.py +37 -17
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +306 -163
- flwr/simulation/simulationio_connection.py +89 -0
- flwr/superexec/__init__.py +15 -0
- flwr/superexec/app.py +59 -0
- flwr/superexec/deployment.py +188 -0
- flwr/superexec/exec_grpc.py +80 -0
- flwr/superexec/exec_servicer.py +231 -0
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +96 -0
- flwr/superexec/simulation.py +124 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/METADATA +33 -26
- flwr_nightly-1.15.0.dev20250114.dist-info/RECORD +328 -0
- flwr_nightly-1.15.0.dev20250114.dist-info/entry_points.txt +12 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr/client/node_state.py +0 -48
- flwr/client/node_state_tests.py +0 -65
- flwr/proto/driver_pb2.py +0 -44
- flwr/proto/driver_pb2_grpc.py +0 -169
- flwr/proto/driver_pb2_grpc.pyi +0 -66
- flwr/server/superlink/driver/driver_grpc.py +0 -54
- flwr/server/superlink/driver/driver_servicer.py +0 -129
- flwr/server/superlink/state/in_memory_state.py +0 -230
- flwr/server/superlink/state/sqlite_state.py +0 -630
- flwr/server/superlink/state/state.py +0 -154
- flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -15,17 +15,37 @@
|
|
15
15
|
"""Experimental REST API server."""
|
16
16
|
|
17
17
|
|
18
|
+
from __future__ import annotations
|
19
|
+
|
18
20
|
import sys
|
21
|
+
from collections.abc import Awaitable
|
22
|
+
from typing import Callable, TypeVar, cast
|
23
|
+
|
24
|
+
from google.protobuf.message import Message as GrpcMessage
|
19
25
|
|
20
26
|
from flwr.common.constant import MISSING_EXTRA_REST
|
27
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
21
28
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
22
29
|
CreateNodeRequest,
|
30
|
+
CreateNodeResponse,
|
23
31
|
DeleteNodeRequest,
|
32
|
+
DeleteNodeResponse,
|
33
|
+
PingRequest,
|
34
|
+
PingResponse,
|
35
|
+
PullMessagesRequest,
|
36
|
+
PullMessagesResponse,
|
24
37
|
PullTaskInsRequest,
|
38
|
+
PullTaskInsResponse,
|
39
|
+
PushMessagesRequest,
|
40
|
+
PushMessagesResponse,
|
25
41
|
PushTaskResRequest,
|
42
|
+
PushTaskResResponse,
|
26
43
|
)
|
44
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
45
|
+
from flwr.server.superlink.ffs.ffs import Ffs
|
46
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
27
47
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
28
|
-
from flwr.server.superlink.
|
48
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
29
49
|
|
30
50
|
try:
|
31
51
|
from starlette.applications import Starlette
|
@@ -38,125 +58,143 @@ except ModuleNotFoundError:
|
|
38
58
|
sys.exit(MISSING_EXTRA_REST)
|
39
59
|
|
40
60
|
|
41
|
-
|
42
|
-
|
43
|
-
_check_headers(request.headers)
|
61
|
+
GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
|
62
|
+
GrpcResponse = TypeVar("GrpcResponse", bound=GrpcMessage)
|
44
63
|
|
45
|
-
|
46
|
-
|
64
|
+
GrpcAsyncFunction = Callable[[GrpcRequest], Awaitable[GrpcResponse]]
|
65
|
+
RestEndPoint = Callable[[Request], Awaitable[Response]]
|
47
66
|
|
48
|
-
# Deserialize ProtoBuf
|
49
|
-
create_node_request_proto = CreateNodeRequest()
|
50
|
-
create_node_request_proto.ParseFromString(create_node_request_bytes)
|
51
67
|
|
52
|
-
|
53
|
-
|
68
|
+
def rest_request_response(
|
69
|
+
grpc_request_type: type[GrpcRequest],
|
70
|
+
) -> Callable[[GrpcAsyncFunction[GrpcRequest, GrpcResponse]], RestEndPoint]:
|
71
|
+
"""Convert an async gRPC-based function into a RESTful HTTP endpoint."""
|
54
72
|
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
)
|
73
|
+
def decorator(func: GrpcAsyncFunction[GrpcRequest, GrpcResponse]) -> RestEndPoint:
|
74
|
+
async def wrapper(request: Request) -> Response:
|
75
|
+
_check_headers(request.headers)
|
59
76
|
|
60
|
-
|
61
|
-
|
62
|
-
return Response(
|
63
|
-
status_code=200,
|
64
|
-
content=create_node_response_bytes,
|
65
|
-
headers={"Content-Type": "application/protobuf"},
|
66
|
-
)
|
77
|
+
# Get the request body as raw bytes
|
78
|
+
grpc_req_bytes: bytes = await request.body()
|
67
79
|
|
80
|
+
# Deserialize ProtoBuf
|
81
|
+
grpc_req = grpc_request_type.FromString(grpc_req_bytes)
|
82
|
+
grpc_res = await func(grpc_req)
|
83
|
+
return Response(
|
84
|
+
status_code=200,
|
85
|
+
content=grpc_res.SerializeToString(),
|
86
|
+
headers={"Content-Type": "application/protobuf"},
|
87
|
+
)
|
68
88
|
|
69
|
-
|
70
|
-
"""Delete Node Id."""
|
71
|
-
_check_headers(request.headers)
|
89
|
+
return wrapper
|
72
90
|
|
73
|
-
|
74
|
-
delete_node_request_bytes: bytes = await request.body()
|
91
|
+
return decorator
|
75
92
|
|
76
|
-
# Deserialize ProtoBuf
|
77
|
-
delete_node_request_proto = DeleteNodeRequest()
|
78
|
-
delete_node_request_proto.ParseFromString(delete_node_request_bytes)
|
79
93
|
|
94
|
+
@rest_request_response(CreateNodeRequest)
|
95
|
+
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
96
|
+
"""Create Node."""
|
80
97
|
# Get state from app
|
81
|
-
state:
|
98
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
82
99
|
|
83
100
|
# Handle message
|
84
|
-
|
85
|
-
|
86
|
-
)
|
101
|
+
return message_handler.create_node(request=request, state=state)
|
102
|
+
|
87
103
|
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
104
|
+
@rest_request_response(DeleteNodeRequest)
|
105
|
+
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
106
|
+
"""Delete Node Id."""
|
107
|
+
# Get state from app
|
108
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
109
|
+
|
110
|
+
# Handle message
|
111
|
+
return message_handler.delete_node(request=request, state=state)
|
95
112
|
|
96
113
|
|
97
|
-
|
114
|
+
@rest_request_response(PullTaskInsRequest)
|
115
|
+
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
98
116
|
"""Pull TaskIns."""
|
99
|
-
|
117
|
+
# Get state from app
|
118
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
100
119
|
|
101
|
-
#
|
102
|
-
|
120
|
+
# Handle message
|
121
|
+
return message_handler.pull_task_ins(request=request, state=state)
|
103
122
|
|
104
|
-
# Deserialize ProtoBuf
|
105
|
-
pull_task_ins_request_proto = PullTaskInsRequest()
|
106
|
-
pull_task_ins_request_proto.ParseFromString(pull_task_ins_request_bytes)
|
107
123
|
|
124
|
+
@rest_request_response(PullMessagesRequest)
|
125
|
+
async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
126
|
+
"""Pull PullMessages."""
|
108
127
|
# Get state from app
|
109
|
-
state:
|
128
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
110
129
|
|
111
130
|
# Handle message
|
112
|
-
|
113
|
-
request=pull_task_ins_request_proto,
|
114
|
-
state=state,
|
115
|
-
)
|
131
|
+
return message_handler.pull_messages(request=request, state=state)
|
116
132
|
|
117
|
-
# Return serialized ProtoBuf
|
118
|
-
pull_task_ins_response_bytes = pull_task_ins_response_proto.SerializeToString()
|
119
|
-
return Response(
|
120
|
-
status_code=200,
|
121
|
-
content=pull_task_ins_response_bytes,
|
122
|
-
headers={"Content-Type": "application/protobuf"},
|
123
|
-
)
|
124
133
|
|
125
|
-
|
126
|
-
|
134
|
+
# Check if token is needed here
|
135
|
+
@rest_request_response(PushTaskResRequest)
|
136
|
+
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
127
137
|
"""Push TaskRes."""
|
128
|
-
|
138
|
+
# Get state from app
|
139
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
129
140
|
|
130
|
-
#
|
131
|
-
|
141
|
+
# Handle message
|
142
|
+
return message_handler.push_task_res(request=request, state=state)
|
143
|
+
|
144
|
+
|
145
|
+
@rest_request_response(PushMessagesRequest)
|
146
|
+
async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
147
|
+
"""Push PushMessages."""
|
148
|
+
# Get state from app
|
149
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
150
|
+
|
151
|
+
# Handle message
|
152
|
+
return message_handler.push_messages(request=request, state=state)
|
153
|
+
|
154
|
+
|
155
|
+
@rest_request_response(PingRequest)
|
156
|
+
async def ping(request: PingRequest) -> PingResponse:
|
157
|
+
"""Ping."""
|
158
|
+
# Get state from app
|
159
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
160
|
+
|
161
|
+
# Handle message
|
162
|
+
return message_handler.ping(request=request, state=state)
|
132
163
|
|
133
|
-
# Deserialize ProtoBuf
|
134
|
-
push_task_res_request_proto = PushTaskResRequest()
|
135
|
-
push_task_res_request_proto.ParseFromString(push_task_res_request_bytes)
|
136
164
|
|
165
|
+
@rest_request_response(GetRunRequest)
|
166
|
+
async def get_run(request: GetRunRequest) -> GetRunResponse:
|
167
|
+
"""GetRun."""
|
137
168
|
# Get state from app
|
138
|
-
state:
|
169
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
139
170
|
|
140
171
|
# Handle message
|
141
|
-
|
142
|
-
|
143
|
-
state=state,
|
144
|
-
)
|
172
|
+
return message_handler.get_run(request=request, state=state)
|
173
|
+
|
145
174
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
175
|
+
@rest_request_response(GetFabRequest)
|
176
|
+
async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
177
|
+
"""GetFab."""
|
178
|
+
# Get ffs from app
|
179
|
+
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
|
180
|
+
|
181
|
+
# Get state from app
|
182
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
183
|
+
|
184
|
+
# Handle message
|
185
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state)
|
153
186
|
|
154
187
|
|
155
188
|
routes = [
|
156
189
|
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
|
157
190
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
158
191
|
Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
|
192
|
+
Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
|
159
193
|
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
|
194
|
+
Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
|
195
|
+
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
196
|
+
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
197
|
+
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
160
198
|
]
|
161
199
|
|
162
200
|
app: Starlette = Starlette(
|
@@ -14,18 +14,18 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Simulation Engine Backends."""
|
16
16
|
|
17
|
+
|
17
18
|
import importlib
|
18
|
-
from typing import Dict, Type
|
19
19
|
|
20
20
|
from .backend import Backend, BackendConfig
|
21
21
|
|
22
22
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
23
23
|
|
24
24
|
# Mapping of supported backends
|
25
|
-
supported_backends:
|
25
|
+
supported_backends: dict[str, type[Backend]] = {}
|
26
26
|
|
27
27
|
# To log backend-specific error message when chosen backend isn't available
|
28
|
-
error_messages_backends:
|
28
|
+
error_messages_backends: dict[str, str] = {}
|
29
29
|
|
30
30
|
if is_ray_installed:
|
31
31
|
from .raybackend import RayBackend
|
@@ -38,7 +38,7 @@ else:
|
|
38
38
|
|
39
39
|
To install the necessary dependencies, install `flwr` with the `simulation` extra:
|
40
40
|
|
41
|
-
pip install -U flwr[
|
41
|
+
pip install -U "flwr[simulation]"
|
42
42
|
"""
|
43
43
|
|
44
44
|
|
@@ -16,25 +16,25 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from abc import ABC, abstractmethod
|
19
|
-
from typing import Callable
|
19
|
+
from typing import Callable
|
20
20
|
|
21
21
|
from flwr.client.client_app import ClientApp
|
22
22
|
from flwr.common.context import Context
|
23
23
|
from flwr.common.message import Message
|
24
24
|
from flwr.common.typing import ConfigsRecordValues
|
25
25
|
|
26
|
-
BackendConfig =
|
26
|
+
BackendConfig = dict[str, dict[str, ConfigsRecordValues]]
|
27
27
|
|
28
28
|
|
29
29
|
class Backend(ABC):
|
30
30
|
"""Abstract base class for a Simulation Engine Backend."""
|
31
31
|
|
32
|
-
def __init__(self, backend_config: BackendConfig
|
32
|
+
def __init__(self, backend_config: BackendConfig) -> None:
|
33
33
|
"""Construct a backend."""
|
34
34
|
|
35
35
|
@abstractmethod
|
36
|
-
|
37
|
-
"""Build backend
|
36
|
+
def build(self, app_fn: Callable[[], ClientApp]) -> None:
|
37
|
+
"""Build backend.
|
38
38
|
|
39
39
|
Different components need to be in place before workers in a backend are ready
|
40
40
|
to accept jobs. When this method finishes executing, the backend should be fully
|
@@ -54,14 +54,13 @@ class Backend(ABC):
|
|
54
54
|
"""Report whether a backend worker is idle and can therefore run a ClientApp."""
|
55
55
|
|
56
56
|
@abstractmethod
|
57
|
-
|
57
|
+
def terminate(self) -> None:
|
58
58
|
"""Terminate backend."""
|
59
59
|
|
60
60
|
@abstractmethod
|
61
|
-
|
61
|
+
def process_message(
|
62
62
|
self,
|
63
|
-
app: Callable[[], ClientApp],
|
64
63
|
message: Message,
|
65
64
|
context: Context,
|
66
|
-
) ->
|
65
|
+
) -> tuple[Message, Context]:
|
67
66
|
"""Submit a job to the backend."""
|
@@ -14,26 +14,26 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Ray backend for the Fleet API using the Simulation Engine."""
|
16
16
|
|
17
|
-
|
18
|
-
|
19
|
-
from
|
17
|
+
|
18
|
+
import sys
|
19
|
+
from logging import DEBUG, ERROR
|
20
|
+
from typing import Callable, Optional, Union
|
20
21
|
|
21
22
|
import ray
|
22
23
|
|
23
|
-
from flwr.client.client_app import ClientApp
|
24
|
+
from flwr.client.client_app import ClientApp
|
25
|
+
from flwr.common.constant import PARTITION_ID_KEY
|
24
26
|
from flwr.common.context import Context
|
25
27
|
from flwr.common.logger import log
|
26
28
|
from flwr.common.message import Message
|
27
|
-
from flwr.
|
28
|
-
|
29
|
-
ClientAppActor,
|
30
|
-
init_ray,
|
31
|
-
)
|
29
|
+
from flwr.common.typing import ConfigsRecordValues
|
30
|
+
from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
|
32
31
|
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
33
32
|
|
34
33
|
from .backend import Backend, BackendConfig
|
35
34
|
|
36
|
-
ClientResourcesDict =
|
35
|
+
ClientResourcesDict = dict[str, Union[int, float]]
|
36
|
+
ActorArgsDict = dict[str, Union[int, float, Callable[[], None]]]
|
37
37
|
|
38
38
|
|
39
39
|
class RayBackend(Backend):
|
@@ -42,51 +42,24 @@ class RayBackend(Backend):
|
|
42
42
|
def __init__(
|
43
43
|
self,
|
44
44
|
backend_config: BackendConfig,
|
45
|
-
work_dir: str,
|
46
45
|
) -> None:
|
47
46
|
"""Prepare RayBackend by initialising Ray and creating the ActorPool."""
|
48
|
-
log(
|
49
|
-
log(
|
50
|
-
|
51
|
-
if not pathlib.Path(work_dir).exists():
|
52
|
-
raise ValueError(f"Specified work_dir {work_dir} does not exist.")
|
47
|
+
log(DEBUG, "Initialising: %s", self.__class__.__name__)
|
48
|
+
log(DEBUG, "Backend config: %s", backend_config)
|
53
49
|
|
54
|
-
#
|
55
|
-
|
56
|
-
|
57
|
-
)
|
58
|
-
init_ray(runtime_env=runtime_env)
|
50
|
+
# Initialise ray
|
51
|
+
self.init_args_key = "init_args"
|
52
|
+
self.init_ray(backend_config)
|
59
53
|
|
60
54
|
# Validate client resources
|
61
55
|
self.client_resources_key = "client_resources"
|
56
|
+
self.client_resources = self._validate_client_resources(config=backend_config)
|
62
57
|
|
63
|
-
#
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
client_resources = self._validate_client_resources(config=backend_config)
|
68
|
-
self.pool = BasicActorPool(
|
69
|
-
actor_type=ClientAppActor,
|
70
|
-
client_resources=client_resources,
|
71
|
-
actor_kwargs=actor_kwargs,
|
72
|
-
)
|
73
|
-
|
74
|
-
def _configure_runtime_env(self, work_dir: str) -> Dict[str, Union[str, List[str]]]:
|
75
|
-
"""Return list of files/subdirectories to exclude relative to work_dir.
|
58
|
+
# Validate actor arguments
|
59
|
+
self.actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
60
|
+
self.pool: Optional[BasicActorPool] = None
|
76
61
|
|
77
|
-
|
78
|
-
"""
|
79
|
-
runtime_env: Dict[str, Union[str, List[str]]] = {"working_dir": work_dir}
|
80
|
-
|
81
|
-
excludes = []
|
82
|
-
path = pathlib.Path(work_dir)
|
83
|
-
for p in path.rglob("*"):
|
84
|
-
# Exclude files need to be relative to the working_dir
|
85
|
-
if p.is_file() and not str(p).endswith(".py"):
|
86
|
-
excludes.append(str(p.relative_to(path)))
|
87
|
-
runtime_env["excludes"] = excludes
|
88
|
-
|
89
|
-
return runtime_env
|
62
|
+
self.app_fn: Optional[Callable[[], ClientApp]] = None
|
90
63
|
|
91
64
|
def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
|
92
65
|
client_resources_config = config.get(self.client_resources_key)
|
@@ -109,7 +82,7 @@ class RayBackend(Backend):
|
|
109
82
|
else:
|
110
83
|
client_resources = {"num_cpus": 2, "num_gpus": 0.0}
|
111
84
|
log(
|
112
|
-
|
85
|
+
DEBUG,
|
113
86
|
"`%s` not specified in backend config. Applying default setting: %s",
|
114
87
|
self.client_resources_key,
|
115
88
|
client_resources,
|
@@ -117,59 +90,105 @@ class RayBackend(Backend):
|
|
117
90
|
|
118
91
|
return client_resources
|
119
92
|
|
93
|
+
def _validate_actor_arguments(self, config: BackendConfig) -> ActorArgsDict:
|
94
|
+
actor_args_config = config.get("actor", False)
|
95
|
+
actor_args: ActorArgsDict = {}
|
96
|
+
if actor_args_config:
|
97
|
+
use_tf = actor_args.get("tensorflow", False)
|
98
|
+
if use_tf:
|
99
|
+
actor_args["on_actor_init_fn"] = enable_tf_gpu_growth
|
100
|
+
return actor_args
|
101
|
+
|
102
|
+
def init_ray(self, backend_config: BackendConfig) -> None:
|
103
|
+
"""Initialises Ray if not already initialised."""
|
104
|
+
if not ray.is_initialized():
|
105
|
+
ray_init_args: dict[
|
106
|
+
str,
|
107
|
+
ConfigsRecordValues,
|
108
|
+
] = {}
|
109
|
+
|
110
|
+
if backend_config.get(self.init_args_key):
|
111
|
+
for k, v in backend_config[self.init_args_key].items():
|
112
|
+
ray_init_args[k] = v
|
113
|
+
ray.init(
|
114
|
+
runtime_env={"env_vars": {"PYTHONPATH": ":".join(sys.path)}},
|
115
|
+
**ray_init_args,
|
116
|
+
)
|
117
|
+
|
120
118
|
@property
|
121
119
|
def num_workers(self) -> int:
|
122
120
|
"""Return number of actors in pool."""
|
123
|
-
return self.pool.num_actors
|
121
|
+
return self.pool.num_actors if self.pool else 0
|
124
122
|
|
125
123
|
def is_worker_idle(self) -> bool:
|
126
124
|
"""Report whether the pool has idle actors."""
|
127
|
-
return self.pool.is_actor_available()
|
125
|
+
return self.pool.is_actor_available() if self.pool else False
|
128
126
|
|
129
|
-
|
127
|
+
def build(self, app_fn: Callable[[], ClientApp]) -> None:
|
130
128
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
131
|
-
|
132
|
-
|
129
|
+
# Create Actor Pool
|
130
|
+
try:
|
131
|
+
self.pool = BasicActorPool(
|
132
|
+
actor_type=ClientAppActor,
|
133
|
+
client_resources=self.client_resources,
|
134
|
+
actor_kwargs=self.actor_kwargs,
|
135
|
+
)
|
136
|
+
except Exception as ex:
|
137
|
+
raise ex
|
138
|
+
|
139
|
+
self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
140
|
+
# Set ClientApp callable that ray actors will use
|
141
|
+
self.app_fn = app_fn
|
142
|
+
log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
|
133
143
|
|
134
|
-
|
144
|
+
def process_message(
|
135
145
|
self,
|
136
|
-
app: Callable[[], ClientApp],
|
137
146
|
message: Message,
|
138
147
|
context: Context,
|
139
|
-
) ->
|
148
|
+
) -> tuple[Message, Context]:
|
140
149
|
"""Run ClientApp that processes a given message.
|
141
150
|
|
142
151
|
Return output message and updated context.
|
143
152
|
"""
|
144
|
-
partition_id =
|
153
|
+
partition_id = context.node_config[PARTITION_ID_KEY]
|
154
|
+
|
155
|
+
if self.pool is None:
|
156
|
+
raise ValueError("The actor pool is empty, unfit to process messages.")
|
157
|
+
|
158
|
+
if self.app_fn is None:
|
159
|
+
raise ValueError(
|
160
|
+
"Unspecified function to load a `ClientApp`. "
|
161
|
+
"Call the backend's `build()` method before processing messages."
|
162
|
+
)
|
145
163
|
|
146
164
|
try:
|
147
|
-
#
|
148
|
-
future =
|
165
|
+
# Submit a task to the pool
|
166
|
+
future = self.pool.submit(
|
149
167
|
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
150
|
-
(
|
168
|
+
(self.app_fn, message, str(partition_id), context),
|
151
169
|
)
|
152
170
|
|
153
|
-
await future
|
154
|
-
|
155
171
|
# Fetch result
|
156
172
|
(
|
157
173
|
out_mssg,
|
158
174
|
updated_context,
|
159
|
-
) =
|
175
|
+
) = self.pool.fetch_result_and_return_actor_to_pool(future)
|
160
176
|
|
161
177
|
return out_mssg, updated_context
|
162
178
|
|
163
|
-
except
|
179
|
+
except Exception as ex:
|
164
180
|
log(
|
165
181
|
ERROR,
|
166
182
|
"An exception was raised when processing a message by %s",
|
167
183
|
self.__class__.__name__,
|
168
184
|
)
|
169
|
-
|
185
|
+
# add actor back into pool
|
186
|
+
self.pool.add_actor_back_to_pool(future)
|
187
|
+
raise ex
|
170
188
|
|
171
|
-
|
189
|
+
def terminate(self) -> None:
|
172
190
|
"""Terminate all actors in actor pool."""
|
173
|
-
|
191
|
+
if self.pool:
|
192
|
+
self.pool.terminate_all_actors()
|
174
193
|
ray.shutdown()
|
175
|
-
log(
|
194
|
+
log(DEBUG, "Terminated %s", self.__class__.__name__)
|