flwr-nightly 1.11.0.dev20240822__py3-none-any.whl → 1.11.1.dev20240912__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +0 -2
- flwr/cli/build.py +1 -1
- flwr/cli/new/new.py +41 -40
- flwr/cli/new/templates/app/LICENSE.tpl +202 -0
- flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
- flwr/cli/new/templates/app/README.md.tpl +7 -30
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
- flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +12 -2
- flwr/client/__init__.py +0 -4
- flwr/client/app.py +3 -4
- flwr/client/client.py +22 -1
- flwr/client/client_app.py +2 -2
- flwr/client/grpc_rere_client/client_interceptor.py +15 -7
- flwr/client/numpy_client.py +22 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +8 -7
- flwr/common/address.py +43 -0
- flwr/common/config.py +14 -11
- flwr/common/constant.py +12 -1
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +24 -1
- flwr/common/telemetry.py +36 -30
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +27 -22
- flwr/server/compat/app.py +0 -5
- flwr/server/driver/grpc_driver.py +3 -6
- flwr/server/run_serverapp.py +20 -7
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +15 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +13 -12
- flwr/server/superlink/fleet/rest_rere/rest_api.py +71 -122
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +33 -15
- flwr/server/superlink/fleet/vce/vce_api.py +2 -6
- flwr/server/superlink/state/in_memory_state.py +15 -15
- flwr/server/superlink/state/sqlite_state.py +10 -10
- flwr/server/superlink/state/state.py +8 -8
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -0
- flwr/simulation/ray_transport/ray_actor.py +2 -2
- flwr/simulation/run_simulation.py +85 -25
- flwr/superexec/__init__.py +0 -6
- flwr/superexec/app.py +5 -3
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/simulation.py +20 -1
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/METADATA +3 -3
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/RECORD +70 -62
- flwr_nightly-1.11.1.dev20240912.dist-info/entry_points.txt +10 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
- flwr_nightly-1.11.0.dev20240822.dist-info/entry_points.txt +0 -10
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/WHEEL +0 -0
|
@@ -23,6 +23,7 @@ from google.protobuf.message import Message as GrpcMessage
|
|
|
23
23
|
|
|
24
24
|
from flwr.common.logger import log
|
|
25
25
|
from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
|
|
26
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
26
27
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
27
28
|
CreateNodeRequest,
|
|
28
29
|
CreateNodeResponse,
|
|
@@ -37,6 +38,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
37
38
|
)
|
|
38
39
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
39
40
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
41
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
40
42
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
41
43
|
from flwr.server.superlink.state import StateFactory
|
|
42
44
|
|
|
@@ -60,10 +62,11 @@ def _handle(
|
|
|
60
62
|
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
|
61
63
|
"""Fleet API via GrpcAdapter servicer."""
|
|
62
64
|
|
|
63
|
-
def __init__(self, state_factory: StateFactory) -> None:
|
|
65
|
+
def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
|
|
64
66
|
self.state_factory = state_factory
|
|
67
|
+
self.ffs_factory = ffs_factory
|
|
65
68
|
|
|
66
|
-
def SendReceive(
|
|
69
|
+
def SendReceive( # pylint: disable=too-many-return-statements
|
|
67
70
|
self, request: MessageContainer, context: grpc.ServicerContext
|
|
68
71
|
) -> MessageContainer:
|
|
69
72
|
"""."""
|
|
@@ -80,6 +83,8 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
|
|
80
83
|
return _handle(request, PushTaskResRequest, self._push_task_res)
|
|
81
84
|
if request.grpc_message_name == GetRunRequest.__qualname__:
|
|
82
85
|
return _handle(request, GetRunRequest, self._get_run)
|
|
86
|
+
if request.grpc_message_name == GetFabRequest.__qualname__:
|
|
87
|
+
return _handle(request, GetFabRequest, self._get_fab)
|
|
83
88
|
raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
|
|
84
89
|
|
|
85
90
|
def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
|
|
@@ -129,3 +134,11 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
|
|
129
134
|
request=request,
|
|
130
135
|
state=self.state_factory.state(),
|
|
131
136
|
)
|
|
137
|
+
|
|
138
|
+
def _get_fab(self, request: GetFabRequest) -> GetFabResponse:
|
|
139
|
+
"""Get FAB."""
|
|
140
|
+
log(INFO, "GrpcAdapter.GetFab")
|
|
141
|
+
return message_handler.get_fab(
|
|
142
|
+
request=request,
|
|
143
|
+
ffs=self.ffs_factory.ffs(),
|
|
144
|
+
)
|
|
@@ -23,6 +23,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
25
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
26
|
+
from flwr.common.address import is_port_in_use
|
|
26
27
|
from flwr.common.logger import log
|
|
27
28
|
from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
|
|
28
29
|
add_FlowerServiceServicer_to_server,
|
|
@@ -218,6 +219,10 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
218
219
|
server : grpc.Server
|
|
219
220
|
A non-running instance of a gRPC server.
|
|
220
221
|
"""
|
|
222
|
+
# Check if port is in use
|
|
223
|
+
if is_port_in_use(server_address):
|
|
224
|
+
sys.exit(f"Port in server address {server_address} is already in use.")
|
|
225
|
+
|
|
221
226
|
# Deconstruct tuple into servicer and function
|
|
222
227
|
servicer, add_servicer_to_server_fn = servicer_and_add_fn
|
|
223
228
|
|
|
@@ -51,19 +51,22 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
51
51
|
self, request: CreateNodeRequest, context: grpc.ServicerContext
|
|
52
52
|
) -> CreateNodeResponse:
|
|
53
53
|
"""."""
|
|
54
|
-
log(INFO, "
|
|
54
|
+
log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
|
|
55
|
+
log(DEBUG, "[Fleet.CreateNode] Request: %s", request)
|
|
55
56
|
response = message_handler.create_node(
|
|
56
57
|
request=request,
|
|
57
58
|
state=self.state_factory.state(),
|
|
58
59
|
)
|
|
59
|
-
log(INFO, "
|
|
60
|
+
log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
|
|
61
|
+
log(DEBUG, "[Fleet.CreateNode] Response: %s", response)
|
|
60
62
|
return response
|
|
61
63
|
|
|
62
64
|
def DeleteNode(
|
|
63
65
|
self, request: DeleteNodeRequest, context: grpc.ServicerContext
|
|
64
66
|
) -> DeleteNodeResponse:
|
|
65
67
|
"""."""
|
|
66
|
-
log(INFO, "
|
|
68
|
+
log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
|
|
69
|
+
log(DEBUG, "[Fleet.DeleteNode] Request: %s", request)
|
|
67
70
|
return message_handler.delete_node(
|
|
68
71
|
request=request,
|
|
69
72
|
state=self.state_factory.state(),
|
|
@@ -71,7 +74,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
71
74
|
|
|
72
75
|
def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
|
|
73
76
|
"""."""
|
|
74
|
-
log(DEBUG, "
|
|
77
|
+
log(DEBUG, "[Fleet.Ping] Request: %s", request)
|
|
75
78
|
return message_handler.ping(
|
|
76
79
|
request=request,
|
|
77
80
|
state=self.state_factory.state(),
|
|
@@ -81,7 +84,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
81
84
|
self, request: PullTaskInsRequest, context: grpc.ServicerContext
|
|
82
85
|
) -> PullTaskInsResponse:
|
|
83
86
|
"""Pull TaskIns."""
|
|
84
|
-
log(INFO, "
|
|
87
|
+
log(INFO, "[Fleet.PullTaskIns] node_id=%s", request.node.node_id)
|
|
88
|
+
log(DEBUG, "[Fleet.PullTaskIns] Request: %s", request)
|
|
85
89
|
return message_handler.pull_task_ins(
|
|
86
90
|
request=request,
|
|
87
91
|
state=self.state_factory.state(),
|
|
@@ -91,7 +95,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
91
95
|
self, request: PushTaskResRequest, context: grpc.ServicerContext
|
|
92
96
|
) -> PushTaskResResponse:
|
|
93
97
|
"""Push TaskRes."""
|
|
94
|
-
|
|
98
|
+
if request.task_res_list:
|
|
99
|
+
log(
|
|
100
|
+
INFO,
|
|
101
|
+
"[Fleet.PushTaskRes] Push results from node_id=%s",
|
|
102
|
+
request.task_res_list[0].task.producer.node_id,
|
|
103
|
+
)
|
|
104
|
+
else:
|
|
105
|
+
log(INFO, "[Fleet.PushTaskRes] No task results to push")
|
|
95
106
|
return message_handler.push_task_res(
|
|
96
107
|
request=request,
|
|
97
108
|
state=self.state_factory.state(),
|
|
@@ -101,7 +112,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
101
112
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
102
113
|
) -> GetRunResponse:
|
|
103
114
|
"""Get run information."""
|
|
104
|
-
log(INFO, "
|
|
115
|
+
log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
|
|
105
116
|
return message_handler.get_run(
|
|
106
117
|
request=request,
|
|
107
118
|
state=self.state_factory.state(),
|
|
@@ -111,7 +122,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
111
122
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
112
123
|
) -> GetFabResponse:
|
|
113
124
|
"""Get FAB."""
|
|
114
|
-
log(
|
|
125
|
+
log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
|
|
115
126
|
return message_handler.get_fab(
|
|
116
127
|
request=request,
|
|
117
128
|
ffs=self.ffs_factory.ffs(),
|
|
@@ -78,13 +78,13 @@ def _get_value_from_tuples(
|
|
|
78
78
|
|
|
79
79
|
|
|
80
80
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
81
|
-
"""Server interceptor for
|
|
81
|
+
"""Server interceptor for node authentication."""
|
|
82
82
|
|
|
83
83
|
def __init__(self, state: State):
|
|
84
84
|
self.state = state
|
|
85
85
|
|
|
86
|
-
self.
|
|
87
|
-
if len(self.
|
|
86
|
+
self.node_public_keys = state.get_node_public_keys()
|
|
87
|
+
if len(self.node_public_keys) == 0:
|
|
88
88
|
log(WARNING, "Authentication enabled, but no known public keys configured")
|
|
89
89
|
|
|
90
90
|
private_key = self.state.get_server_private_key()
|
|
@@ -103,9 +103,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
103
103
|
) -> grpc.RpcMethodHandler:
|
|
104
104
|
"""Flower server interceptor authentication logic.
|
|
105
105
|
|
|
106
|
-
Intercept all unary calls from
|
|
107
|
-
|
|
108
|
-
|
|
106
|
+
Intercept all unary calls from nodes and authenticate nodes by validating auth
|
|
107
|
+
metadata sent by the node. Continue RPC call if node is authenticated, else,
|
|
108
|
+
terminate RPC call by setting context to abort.
|
|
109
109
|
"""
|
|
110
110
|
# One of the method handlers in
|
|
111
111
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
|
@@ -119,17 +119,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
119
119
|
request: Request,
|
|
120
120
|
context: grpc.ServicerContext,
|
|
121
121
|
) -> Response:
|
|
122
|
-
|
|
122
|
+
node_public_key_bytes = base64.urlsafe_b64decode(
|
|
123
123
|
_get_value_from_tuples(
|
|
124
124
|
_PUBLIC_KEY_HEADER, context.invocation_metadata()
|
|
125
125
|
)
|
|
126
126
|
)
|
|
127
|
-
if
|
|
127
|
+
if node_public_key_bytes not in self.node_public_keys:
|
|
128
128
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
129
129
|
|
|
130
130
|
if isinstance(request, CreateNodeRequest):
|
|
131
131
|
response = self._create_authenticated_node(
|
|
132
|
-
|
|
132
|
+
node_public_key_bytes, request, context
|
|
133
133
|
)
|
|
134
134
|
log(
|
|
135
135
|
INFO,
|
|
@@ -144,13 +144,13 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
144
144
|
_AUTH_TOKEN_HEADER, context.invocation_metadata()
|
|
145
145
|
)
|
|
146
146
|
)
|
|
147
|
-
public_key = bytes_to_public_key(
|
|
147
|
+
public_key = bytes_to_public_key(node_public_key_bytes)
|
|
148
148
|
|
|
149
149
|
if not self._verify_hmac(public_key, request, hmac_value):
|
|
150
150
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
151
151
|
|
|
152
152
|
# Verify node_id
|
|
153
|
-
node_id = self.state.get_node_id(
|
|
153
|
+
node_id = self.state.get_node_id(node_public_key_bytes)
|
|
154
154
|
|
|
155
155
|
if not self._verify_node_id(node_id, request):
|
|
156
156
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
@@ -188,7 +188,8 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
188
188
|
self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
|
|
189
189
|
) -> bool:
|
|
190
190
|
shared_secret = generate_shared_key(self.server_private_key, public_key)
|
|
191
|
-
|
|
191
|
+
message_bytes = request.SerializeToString(deterministic=True)
|
|
192
|
+
return verify_hmac(shared_secret, message_bytes, hmac_value)
|
|
192
193
|
|
|
193
194
|
def _create_authenticated_node(
|
|
194
195
|
self,
|
|
@@ -15,17 +15,29 @@
|
|
|
15
15
|
"""Experimental REST API server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
import sys
|
|
21
|
+
from typing import Awaitable, Callable, TypeVar
|
|
22
|
+
|
|
23
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
19
24
|
|
|
20
25
|
from flwr.common.constant import MISSING_EXTRA_REST
|
|
26
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
21
27
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
22
28
|
CreateNodeRequest,
|
|
29
|
+
CreateNodeResponse,
|
|
23
30
|
DeleteNodeRequest,
|
|
31
|
+
DeleteNodeResponse,
|
|
24
32
|
PingRequest,
|
|
33
|
+
PingResponse,
|
|
25
34
|
PullTaskInsRequest,
|
|
35
|
+
PullTaskInsResponse,
|
|
26
36
|
PushTaskResRequest,
|
|
37
|
+
PushTaskResResponse,
|
|
27
38
|
)
|
|
28
|
-
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
39
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
|
+
from flwr.server.superlink.ffs.ffs import Ffs
|
|
29
41
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
30
42
|
from flwr.server.superlink.state import State
|
|
31
43
|
|
|
@@ -40,172 +52,108 @@ except ModuleNotFoundError:
|
|
|
40
52
|
sys.exit(MISSING_EXTRA_REST)
|
|
41
53
|
|
|
42
54
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
_check_headers(request.headers)
|
|
55
|
+
GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
|
|
56
|
+
GrpcResponse = TypeVar("GrpcResponse", bound=GrpcMessage)
|
|
46
57
|
|
|
47
|
-
|
|
48
|
-
|
|
58
|
+
GrpcAsyncFunction = Callable[[GrpcRequest], Awaitable[GrpcResponse]]
|
|
59
|
+
RestEndPoint = Callable[[Request], Awaitable[Response]]
|
|
49
60
|
|
|
50
|
-
# Deserialize ProtoBuf
|
|
51
|
-
create_node_request_proto = CreateNodeRequest()
|
|
52
|
-
create_node_request_proto.ParseFromString(create_node_request_bytes)
|
|
53
61
|
|
|
54
|
-
|
|
55
|
-
|
|
62
|
+
def rest_request_response(
|
|
63
|
+
grpc_request_type: type[GrpcRequest],
|
|
64
|
+
) -> Callable[[GrpcAsyncFunction[GrpcRequest, GrpcResponse]], RestEndPoint]:
|
|
65
|
+
"""Convert an async gRPC-based function into a RESTful HTTP endpoint."""
|
|
56
66
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
)
|
|
67
|
+
def decorator(func: GrpcAsyncFunction[GrpcRequest, GrpcResponse]) -> RestEndPoint:
|
|
68
|
+
async def wrapper(request: Request) -> Response:
|
|
69
|
+
_check_headers(request.headers)
|
|
61
70
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
return Response(
|
|
65
|
-
status_code=200,
|
|
66
|
-
content=create_node_response_bytes,
|
|
67
|
-
headers={"Content-Type": "application/protobuf"},
|
|
68
|
-
)
|
|
71
|
+
# Get the request body as raw bytes
|
|
72
|
+
grpc_req_bytes: bytes = await request.body()
|
|
69
73
|
|
|
74
|
+
# Deserialize ProtoBuf
|
|
75
|
+
grpc_req = grpc_request_type.FromString(grpc_req_bytes)
|
|
76
|
+
grpc_res = await func(grpc_req)
|
|
77
|
+
return Response(
|
|
78
|
+
status_code=200,
|
|
79
|
+
content=grpc_res.SerializeToString(),
|
|
80
|
+
headers={"Content-Type": "application/protobuf"},
|
|
81
|
+
)
|
|
70
82
|
|
|
71
|
-
|
|
72
|
-
"""Delete Node Id."""
|
|
73
|
-
_check_headers(request.headers)
|
|
83
|
+
return wrapper
|
|
74
84
|
|
|
75
|
-
|
|
76
|
-
delete_node_request_bytes: bytes = await request.body()
|
|
85
|
+
return decorator
|
|
77
86
|
|
|
78
|
-
# Deserialize ProtoBuf
|
|
79
|
-
delete_node_request_proto = DeleteNodeRequest()
|
|
80
|
-
delete_node_request_proto.ParseFromString(delete_node_request_bytes)
|
|
81
87
|
|
|
88
|
+
@rest_request_response(CreateNodeRequest)
|
|
89
|
+
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
90
|
+
"""Create Node."""
|
|
82
91
|
# Get state from app
|
|
83
92
|
state: State = app.state.STATE_FACTORY.state()
|
|
84
93
|
|
|
85
94
|
# Handle message
|
|
86
|
-
|
|
87
|
-
request=delete_node_request_proto, state=state
|
|
88
|
-
)
|
|
95
|
+
return message_handler.create_node(request=request, state=state)
|
|
89
96
|
|
|
90
|
-
# Return serialized ProtoBuf
|
|
91
|
-
delete_node_response_bytes = delete_node_response_proto.SerializeToString()
|
|
92
|
-
return Response(
|
|
93
|
-
status_code=200,
|
|
94
|
-
content=delete_node_response_bytes,
|
|
95
|
-
headers={"Content-Type": "application/protobuf"},
|
|
96
|
-
)
|
|
97
97
|
|
|
98
|
+
@rest_request_response(DeleteNodeRequest)
|
|
99
|
+
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
100
|
+
"""Delete Node Id."""
|
|
101
|
+
# Get state from app
|
|
102
|
+
state: State = app.state.STATE_FACTORY.state()
|
|
98
103
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
_check_headers(request.headers)
|
|
102
|
-
|
|
103
|
-
# Get the request body as raw bytes
|
|
104
|
-
pull_task_ins_request_bytes: bytes = await request.body()
|
|
104
|
+
# Handle message
|
|
105
|
+
return message_handler.delete_node(request=request, state=state)
|
|
105
106
|
|
|
106
|
-
# Deserialize ProtoBuf
|
|
107
|
-
pull_task_ins_request_proto = PullTaskInsRequest()
|
|
108
|
-
pull_task_ins_request_proto.ParseFromString(pull_task_ins_request_bytes)
|
|
109
107
|
|
|
108
|
+
@rest_request_response(PullTaskInsRequest)
|
|
109
|
+
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
110
|
+
"""Pull TaskIns."""
|
|
110
111
|
# Get state from app
|
|
111
112
|
state: State = app.state.STATE_FACTORY.state()
|
|
112
113
|
|
|
113
114
|
# Handle message
|
|
114
|
-
|
|
115
|
-
request=pull_task_ins_request_proto,
|
|
116
|
-
state=state,
|
|
117
|
-
)
|
|
118
|
-
|
|
119
|
-
# Return serialized ProtoBuf
|
|
120
|
-
pull_task_ins_response_bytes = pull_task_ins_response_proto.SerializeToString()
|
|
121
|
-
return Response(
|
|
122
|
-
status_code=200,
|
|
123
|
-
content=pull_task_ins_response_bytes,
|
|
124
|
-
headers={"Content-Type": "application/protobuf"},
|
|
125
|
-
)
|
|
115
|
+
return message_handler.pull_task_ins(request=request, state=state)
|
|
126
116
|
|
|
127
117
|
|
|
128
|
-
|
|
118
|
+
# Check if token is needed here
|
|
119
|
+
@rest_request_response(PushTaskResRequest)
|
|
120
|
+
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
129
121
|
"""Push TaskRes."""
|
|
130
|
-
_check_headers(request.headers)
|
|
131
|
-
|
|
132
|
-
# Get the request body as raw bytes
|
|
133
|
-
push_task_res_request_bytes: bytes = await request.body()
|
|
134
|
-
|
|
135
|
-
# Deserialize ProtoBuf
|
|
136
|
-
push_task_res_request_proto = PushTaskResRequest()
|
|
137
|
-
push_task_res_request_proto.ParseFromString(push_task_res_request_bytes)
|
|
138
|
-
|
|
139
122
|
# Get state from app
|
|
140
123
|
state: State = app.state.STATE_FACTORY.state()
|
|
141
124
|
|
|
142
125
|
# Handle message
|
|
143
|
-
|
|
144
|
-
request=push_task_res_request_proto,
|
|
145
|
-
state=state,
|
|
146
|
-
)
|
|
147
|
-
|
|
148
|
-
# Return serialized ProtoBuf
|
|
149
|
-
push_task_res_response_bytes = push_task_res_response_proto.SerializeToString()
|
|
150
|
-
return Response(
|
|
151
|
-
status_code=200,
|
|
152
|
-
content=push_task_res_response_bytes,
|
|
153
|
-
headers={"Content-Type": "application/protobuf"},
|
|
154
|
-
)
|
|
126
|
+
return message_handler.push_task_res(request=request, state=state)
|
|
155
127
|
|
|
156
128
|
|
|
157
|
-
|
|
129
|
+
@rest_request_response(PingRequest)
|
|
130
|
+
async def ping(request: PingRequest) -> PingResponse:
|
|
158
131
|
"""Ping."""
|
|
159
|
-
_check_headers(request.headers)
|
|
160
|
-
|
|
161
|
-
# Get the request body as raw bytes
|
|
162
|
-
ping_request_bytes: bytes = await request.body()
|
|
163
|
-
|
|
164
|
-
# Deserialize ProtoBuf
|
|
165
|
-
ping_request_proto = PingRequest()
|
|
166
|
-
ping_request_proto.ParseFromString(ping_request_bytes)
|
|
167
|
-
|
|
168
132
|
# Get state from app
|
|
169
133
|
state: State = app.state.STATE_FACTORY.state()
|
|
170
134
|
|
|
171
135
|
# Handle message
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
# Return serialized ProtoBuf
|
|
175
|
-
ping_response_bytes = ping_response_proto.SerializeToString()
|
|
176
|
-
return Response(
|
|
177
|
-
status_code=200,
|
|
178
|
-
content=ping_response_bytes,
|
|
179
|
-
headers={"Content-Type": "application/protobuf"},
|
|
180
|
-
)
|
|
136
|
+
return message_handler.ping(request=request, state=state)
|
|
181
137
|
|
|
182
138
|
|
|
183
|
-
|
|
139
|
+
@rest_request_response(GetRunRequest)
|
|
140
|
+
async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
184
141
|
"""GetRun."""
|
|
185
|
-
_check_headers(request.headers)
|
|
186
|
-
|
|
187
|
-
# Get the request body as raw bytes
|
|
188
|
-
get_run_request_bytes: bytes = await request.body()
|
|
189
|
-
|
|
190
|
-
# Deserialize ProtoBuf
|
|
191
|
-
get_run_request_proto = GetRunRequest()
|
|
192
|
-
get_run_request_proto.ParseFromString(get_run_request_bytes)
|
|
193
|
-
|
|
194
142
|
# Get state from app
|
|
195
143
|
state: State = app.state.STATE_FACTORY.state()
|
|
196
144
|
|
|
197
145
|
# Handle message
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
)
|
|
146
|
+
return message_handler.get_run(request=request, state=state)
|
|
147
|
+
|
|
201
148
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
149
|
+
@rest_request_response(GetFabRequest)
|
|
150
|
+
async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
151
|
+
"""GetRun."""
|
|
152
|
+
# Get ffs from app
|
|
153
|
+
ffs: Ffs = app.state.FFS_FACTORY.state()
|
|
154
|
+
|
|
155
|
+
# Handle message
|
|
156
|
+
return message_handler.get_fab(request=request, ffs=ffs)
|
|
209
157
|
|
|
210
158
|
|
|
211
159
|
routes = [
|
|
@@ -215,6 +163,7 @@ routes = [
|
|
|
215
163
|
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
|
|
216
164
|
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
|
217
165
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
|
166
|
+
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
|
218
167
|
]
|
|
219
168
|
|
|
220
169
|
app: Starlette = Starlette(
|
|
@@ -33,7 +33,7 @@ class Backend(ABC):
|
|
|
33
33
|
"""Construct a backend."""
|
|
34
34
|
|
|
35
35
|
@abstractmethod
|
|
36
|
-
def build(self) -> None:
|
|
36
|
+
def build(self, app_fn: Callable[[], ClientApp]) -> None:
|
|
37
37
|
"""Build backend.
|
|
38
38
|
|
|
39
39
|
Different components need to be in place before workers in a backend are ready
|
|
@@ -60,7 +60,6 @@ class Backend(ABC):
|
|
|
60
60
|
@abstractmethod
|
|
61
61
|
def process_message(
|
|
62
62
|
self,
|
|
63
|
-
app: Callable[[], ClientApp],
|
|
64
63
|
message: Message,
|
|
65
64
|
context: Context,
|
|
66
65
|
) -> Tuple[Message, Context]:
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
import sys
|
|
18
18
|
from logging import DEBUG, ERROR
|
|
19
|
-
from typing import Callable, Dict, Tuple, Union
|
|
19
|
+
from typing import Callable, Dict, Optional, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import ray
|
|
22
22
|
|
|
@@ -52,16 +52,13 @@ class RayBackend(Backend):
|
|
|
52
52
|
|
|
53
53
|
# Validate client resources
|
|
54
54
|
self.client_resources_key = "client_resources"
|
|
55
|
-
client_resources = self._validate_client_resources(config=backend_config)
|
|
55
|
+
self.client_resources = self._validate_client_resources(config=backend_config)
|
|
56
56
|
|
|
57
|
-
#
|
|
58
|
-
actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
|
57
|
+
# Valide actor resources
|
|
58
|
+
self.actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
|
59
|
+
self.pool: Optional[BasicActorPool] = None
|
|
59
60
|
|
|
60
|
-
self.
|
|
61
|
-
actor_type=ClientAppActor,
|
|
62
|
-
client_resources=client_resources,
|
|
63
|
-
actor_kwargs=actor_kwargs,
|
|
64
|
-
)
|
|
61
|
+
self.app_fn: Optional[Callable[[], ClientApp]] = None
|
|
65
62
|
|
|
66
63
|
def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
|
|
67
64
|
client_resources_config = config.get(self.client_resources_key)
|
|
@@ -120,20 +117,31 @@ class RayBackend(Backend):
|
|
|
120
117
|
@property
|
|
121
118
|
def num_workers(self) -> int:
|
|
122
119
|
"""Return number of actors in pool."""
|
|
123
|
-
return self.pool.num_actors
|
|
120
|
+
return self.pool.num_actors if self.pool else 0
|
|
124
121
|
|
|
125
122
|
def is_worker_idle(self) -> bool:
|
|
126
123
|
"""Report whether the pool has idle actors."""
|
|
127
|
-
return self.pool.is_actor_available()
|
|
124
|
+
return self.pool.is_actor_available() if self.pool else False
|
|
128
125
|
|
|
129
|
-
def build(self) -> None:
|
|
126
|
+
def build(self, app_fn: Callable[[], ClientApp]) -> None:
|
|
130
127
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
|
128
|
+
# Create Actor Pool
|
|
129
|
+
try:
|
|
130
|
+
self.pool = BasicActorPool(
|
|
131
|
+
actor_type=ClientAppActor,
|
|
132
|
+
client_resources=self.client_resources,
|
|
133
|
+
actor_kwargs=self.actor_kwargs,
|
|
134
|
+
)
|
|
135
|
+
except Exception as ex:
|
|
136
|
+
raise ex
|
|
137
|
+
|
|
131
138
|
self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
|
139
|
+
# Set ClientApp callable that ray actors will use
|
|
140
|
+
self.app_fn = app_fn
|
|
132
141
|
log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
|
|
133
142
|
|
|
134
143
|
def process_message(
|
|
135
144
|
self,
|
|
136
|
-
app: Callable[[], ClientApp],
|
|
137
145
|
message: Message,
|
|
138
146
|
context: Context,
|
|
139
147
|
) -> Tuple[Message, Context]:
|
|
@@ -143,11 +151,20 @@ class RayBackend(Backend):
|
|
|
143
151
|
"""
|
|
144
152
|
partition_id = context.node_config[PARTITION_ID_KEY]
|
|
145
153
|
|
|
154
|
+
if self.pool is None:
|
|
155
|
+
raise ValueError("The actor pool is empty, unfit to process messages.")
|
|
156
|
+
|
|
157
|
+
if self.app_fn is None:
|
|
158
|
+
raise ValueError(
|
|
159
|
+
"Unspecified function to load a `ClientApp`. "
|
|
160
|
+
"Call the backend's `build()` method before processing messages."
|
|
161
|
+
)
|
|
162
|
+
|
|
146
163
|
try:
|
|
147
164
|
# Submit a task to the pool
|
|
148
165
|
future = self.pool.submit(
|
|
149
166
|
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
|
150
|
-
(
|
|
167
|
+
(self.app_fn, message, str(partition_id), context),
|
|
151
168
|
)
|
|
152
169
|
|
|
153
170
|
# Fetch result
|
|
@@ -170,6 +187,7 @@ class RayBackend(Backend):
|
|
|
170
187
|
|
|
171
188
|
def terminate(self) -> None:
|
|
172
189
|
"""Terminate all actors in actor pool."""
|
|
173
|
-
self.pool
|
|
190
|
+
if self.pool:
|
|
191
|
+
self.pool.terminate_all_actors()
|
|
174
192
|
ray.shutdown()
|
|
175
193
|
log(DEBUG, "Terminated %s", self.__class__.__name__)
|
|
@@ -87,7 +87,6 @@ def _register_node_states(
|
|
|
87
87
|
|
|
88
88
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
89
89
|
def worker(
|
|
90
|
-
app_fn: Callable[[], ClientApp],
|
|
91
90
|
taskins_queue: "Queue[TaskIns]",
|
|
92
91
|
taskres_queue: "Queue[TaskRes]",
|
|
93
92
|
node_states: Dict[int, NodeState],
|
|
@@ -110,9 +109,7 @@ def worker(
|
|
|
110
109
|
message = message_from_taskins(task_ins)
|
|
111
110
|
|
|
112
111
|
# Let backend process message
|
|
113
|
-
out_mssg, updated_context = backend.process_message(
|
|
114
|
-
app_fn, message, context
|
|
115
|
-
)
|
|
112
|
+
out_mssg, updated_context = backend.process_message(message, context)
|
|
116
113
|
|
|
117
114
|
# Update Context
|
|
118
115
|
node_states[node_id].update_context(
|
|
@@ -193,7 +190,7 @@ def run_api(
|
|
|
193
190
|
backend = backend_fn()
|
|
194
191
|
|
|
195
192
|
# Build backend
|
|
196
|
-
backend.build()
|
|
193
|
+
backend.build(app_fn)
|
|
197
194
|
|
|
198
195
|
# Add workers (they submit Messages to Backend)
|
|
199
196
|
state = state_factory.state()
|
|
@@ -223,7 +220,6 @@ def run_api(
|
|
|
223
220
|
_ = [
|
|
224
221
|
executor.submit(
|
|
225
222
|
worker,
|
|
226
|
-
app_fn,
|
|
227
223
|
taskins_queue,
|
|
228
224
|
taskres_queue,
|
|
229
225
|
node_states,
|