flwr-nightly 1.11.0.dev20240822__py3-none-any.whl → 1.11.1.dev20240912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's registry page for more details.

Files changed (75)
  1. flwr/cli/app.py +0 -2
  2. flwr/cli/build.py +1 -1
  3. flwr/cli/new/new.py +41 -40
  4. flwr/cli/new/templates/app/LICENSE.tpl +202 -0
  5. flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
  6. flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
  7. flwr/cli/new/templates/app/README.md.tpl +7 -30
  8. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
  9. flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
  10. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
  12. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
  13. flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
  14. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
  15. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
  16. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
  17. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
  18. flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
  19. flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
  20. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
  21. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
  22. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
  23. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
  24. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
  25. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
  26. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
  27. flwr/cli/run/run.py +12 -2
  28. flwr/client/__init__.py +0 -4
  29. flwr/client/app.py +3 -4
  30. flwr/client/client.py +22 -1
  31. flwr/client/client_app.py +2 -2
  32. flwr/client/grpc_rere_client/client_interceptor.py +15 -7
  33. flwr/client/numpy_client.py +22 -1
  34. flwr/client/rest_client/connection.py +1 -1
  35. flwr/client/supernode/app.py +8 -7
  36. flwr/common/address.py +43 -0
  37. flwr/common/config.py +14 -11
  38. flwr/common/constant.py +12 -1
  39. flwr/common/record/recordset.py +1 -1
  40. flwr/common/record/typeddict.py +24 -1
  41. flwr/common/telemetry.py +36 -30
  42. flwr/server/__init__.py +0 -4
  43. flwr/server/app.py +27 -22
  44. flwr/server/compat/app.py +0 -5
  45. flwr/server/driver/grpc_driver.py +3 -6
  46. flwr/server/run_serverapp.py +20 -7
  47. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +15 -2
  48. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -0
  49. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
  50. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +13 -12
  51. flwr/server/superlink/fleet/rest_rere/rest_api.py +71 -122
  52. flwr/server/superlink/fleet/vce/backend/backend.py +1 -2
  53. flwr/server/superlink/fleet/vce/backend/raybackend.py +33 -15
  54. flwr/server/superlink/fleet/vce/vce_api.py +2 -6
  55. flwr/server/superlink/state/in_memory_state.py +15 -15
  56. flwr/server/superlink/state/sqlite_state.py +10 -10
  57. flwr/server/superlink/state/state.py +8 -8
  58. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
  59. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -0
  60. flwr/simulation/ray_transport/ray_actor.py +2 -2
  61. flwr/simulation/run_simulation.py +85 -25
  62. flwr/superexec/__init__.py +0 -6
  63. flwr/superexec/app.py +5 -3
  64. flwr/superexec/deployment.py +2 -2
  65. flwr/superexec/simulation.py +20 -1
  66. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/METADATA +3 -3
  67. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/RECORD +70 -62
  68. flwr_nightly-1.11.1.dev20240912.dist-info/entry_points.txt +10 -0
  69. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
  70. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
  71. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
  72. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
  73. flwr_nightly-1.11.0.dev20240822.dist-info/entry_points.txt +0 -10
  74. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/LICENSE +0 -0
  75. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/WHEEL +0 -0
@@ -23,6 +23,7 @@ from google.protobuf.message import Message as GrpcMessage
23
23
 
24
24
  from flwr.common.logger import log
25
25
  from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
26
+ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
26
27
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
27
28
  CreateNodeRequest,
28
29
  CreateNodeResponse,
@@ -37,6 +38,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
38
  )
38
39
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
39
40
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
41
+ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
40
42
  from flwr.server.superlink.fleet.message_handler import message_handler
41
43
  from flwr.server.superlink.state import StateFactory
42
44
 
@@ -60,10 +62,11 @@ def _handle(
60
62
  class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
61
63
  """Fleet API via GrpcAdapter servicer."""
62
64
 
63
- def __init__(self, state_factory: StateFactory) -> None:
65
+ def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
64
66
  self.state_factory = state_factory
67
+ self.ffs_factory = ffs_factory
65
68
 
66
- def SendReceive(
69
+ def SendReceive( # pylint: disable=too-many-return-statements
67
70
  self, request: MessageContainer, context: grpc.ServicerContext
68
71
  ) -> MessageContainer:
69
72
  """."""
@@ -80,6 +83,8 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
80
83
  return _handle(request, PushTaskResRequest, self._push_task_res)
81
84
  if request.grpc_message_name == GetRunRequest.__qualname__:
82
85
  return _handle(request, GetRunRequest, self._get_run)
86
+ if request.grpc_message_name == GetFabRequest.__qualname__:
87
+ return _handle(request, GetFabRequest, self._get_fab)
83
88
  raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
84
89
 
85
90
  def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
@@ -129,3 +134,11 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
129
134
  request=request,
130
135
  state=self.state_factory.state(),
131
136
  )
137
+
138
+ def _get_fab(self, request: GetFabRequest) -> GetFabResponse:
139
+ """Get FAB."""
140
+ log(INFO, "GrpcAdapter.GetFab")
141
+ return message_handler.get_fab(
142
+ request=request,
143
+ ffs=self.ffs_factory.ffs(),
144
+ )
@@ -23,6 +23,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union
23
23
  import grpc
24
24
 
25
25
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
26
+ from flwr.common.address import is_port_in_use
26
27
  from flwr.common.logger import log
27
28
  from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
28
29
  add_FlowerServiceServicer_to_server,
@@ -218,6 +219,10 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
218
219
  server : grpc.Server
219
220
  A non-running instance of a gRPC server.
220
221
  """
222
+ # Check if port is in use
223
+ if is_port_in_use(server_address):
224
+ sys.exit(f"Port in server address {server_address} is already in use.")
225
+
221
226
  # Deconstruct tuple into servicer and function
222
227
  servicer, add_servicer_to_server_fn = servicer_and_add_fn
223
228
 
@@ -51,19 +51,22 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
51
51
  self, request: CreateNodeRequest, context: grpc.ServicerContext
52
52
  ) -> CreateNodeResponse:
53
53
  """."""
54
- log(INFO, "FleetServicer.CreateNode")
54
+ log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
55
+ log(DEBUG, "[Fleet.CreateNode] Request: %s", request)
55
56
  response = message_handler.create_node(
56
57
  request=request,
57
58
  state=self.state_factory.state(),
58
59
  )
59
- log(INFO, "FleetServicer: Created node_id=%s", response.node.node_id)
60
+ log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
61
+ log(DEBUG, "[Fleet.CreateNode] Response: %s", response)
60
62
  return response
61
63
 
62
64
  def DeleteNode(
63
65
  self, request: DeleteNodeRequest, context: grpc.ServicerContext
64
66
  ) -> DeleteNodeResponse:
65
67
  """."""
66
- log(INFO, "FleetServicer.DeleteNode")
68
+ log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
69
+ log(DEBUG, "[Fleet.DeleteNode] Request: %s", request)
67
70
  return message_handler.delete_node(
68
71
  request=request,
69
72
  state=self.state_factory.state(),
@@ -71,7 +74,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
71
74
 
72
75
  def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
73
76
  """."""
74
- log(DEBUG, "FleetServicer.Ping")
77
+ log(DEBUG, "[Fleet.Ping] Request: %s", request)
75
78
  return message_handler.ping(
76
79
  request=request,
77
80
  state=self.state_factory.state(),
@@ -81,7 +84,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
81
84
  self, request: PullTaskInsRequest, context: grpc.ServicerContext
82
85
  ) -> PullTaskInsResponse:
83
86
  """Pull TaskIns."""
84
- log(INFO, "FleetServicer.PullTaskIns")
87
+ log(INFO, "[Fleet.PullTaskIns] node_id=%s", request.node.node_id)
88
+ log(DEBUG, "[Fleet.PullTaskIns] Request: %s", request)
85
89
  return message_handler.pull_task_ins(
86
90
  request=request,
87
91
  state=self.state_factory.state(),
@@ -91,7 +95,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
91
95
  self, request: PushTaskResRequest, context: grpc.ServicerContext
92
96
  ) -> PushTaskResResponse:
93
97
  """Push TaskRes."""
94
- log(INFO, "FleetServicer.PushTaskRes")
98
+ if request.task_res_list:
99
+ log(
100
+ INFO,
101
+ "[Fleet.PushTaskRes] Push results from node_id=%s",
102
+ request.task_res_list[0].task.producer.node_id,
103
+ )
104
+ else:
105
+ log(INFO, "[Fleet.PushTaskRes] No task results to push")
95
106
  return message_handler.push_task_res(
96
107
  request=request,
97
108
  state=self.state_factory.state(),
@@ -101,7 +112,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
101
112
  self, request: GetRunRequest, context: grpc.ServicerContext
102
113
  ) -> GetRunResponse:
103
114
  """Get run information."""
104
- log(INFO, "FleetServicer.GetRun")
115
+ log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
105
116
  return message_handler.get_run(
106
117
  request=request,
107
118
  state=self.state_factory.state(),
@@ -111,7 +122,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
111
122
  self, request: GetFabRequest, context: grpc.ServicerContext
112
123
  ) -> GetFabResponse:
113
124
  """Get FAB."""
114
- log(DEBUG, "DriverServicer.GetFab")
125
+ log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
115
126
  return message_handler.get_fab(
116
127
  request=request,
117
128
  ffs=self.ffs_factory.ffs(),
@@ -78,13 +78,13 @@ def _get_value_from_tuples(
78
78
 
79
79
 
80
80
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
81
- """Server interceptor for client authentication."""
81
+ """Server interceptor for node authentication."""
82
82
 
83
83
  def __init__(self, state: State):
84
84
  self.state = state
85
85
 
86
- self.client_public_keys = state.get_client_public_keys()
87
- if len(self.client_public_keys) == 0:
86
+ self.node_public_keys = state.get_node_public_keys()
87
+ if len(self.node_public_keys) == 0:
88
88
  log(WARNING, "Authentication enabled, but no known public keys configured")
89
89
 
90
90
  private_key = self.state.get_server_private_key()
@@ -103,9 +103,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
103
103
  ) -> grpc.RpcMethodHandler:
104
104
  """Flower server interceptor authentication logic.
105
105
 
106
- Intercept all unary calls from clients and authenticate clients by validating
107
- auth metadata sent by the client. Continue RPC call if client is authenticated,
108
- else, terminate RPC call by setting context to abort.
106
+ Intercept all unary calls from nodes and authenticate nodes by validating auth
107
+ metadata sent by the node. Continue RPC call if node is authenticated, else,
108
+ terminate RPC call by setting context to abort.
109
109
  """
110
110
  # One of the method handlers in
111
111
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
@@ -119,17 +119,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
119
119
  request: Request,
120
120
  context: grpc.ServicerContext,
121
121
  ) -> Response:
122
- client_public_key_bytes = base64.urlsafe_b64decode(
122
+ node_public_key_bytes = base64.urlsafe_b64decode(
123
123
  _get_value_from_tuples(
124
124
  _PUBLIC_KEY_HEADER, context.invocation_metadata()
125
125
  )
126
126
  )
127
- if client_public_key_bytes not in self.client_public_keys:
127
+ if node_public_key_bytes not in self.node_public_keys:
128
128
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
129
129
 
130
130
  if isinstance(request, CreateNodeRequest):
131
131
  response = self._create_authenticated_node(
132
- client_public_key_bytes, request, context
132
+ node_public_key_bytes, request, context
133
133
  )
134
134
  log(
135
135
  INFO,
@@ -144,13 +144,13 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
144
144
  _AUTH_TOKEN_HEADER, context.invocation_metadata()
145
145
  )
146
146
  )
147
- public_key = bytes_to_public_key(client_public_key_bytes)
147
+ public_key = bytes_to_public_key(node_public_key_bytes)
148
148
 
149
149
  if not self._verify_hmac(public_key, request, hmac_value):
150
150
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
151
151
 
152
152
  # Verify node_id
153
- node_id = self.state.get_node_id(client_public_key_bytes)
153
+ node_id = self.state.get_node_id(node_public_key_bytes)
154
154
 
155
155
  if not self._verify_node_id(node_id, request):
156
156
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
@@ -188,7 +188,8 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
188
188
  self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
189
189
  ) -> bool:
190
190
  shared_secret = generate_shared_key(self.server_private_key, public_key)
191
- return verify_hmac(shared_secret, request.SerializeToString(True), hmac_value)
191
+ message_bytes = request.SerializeToString(deterministic=True)
192
+ return verify_hmac(shared_secret, message_bytes, hmac_value)
192
193
 
193
194
  def _create_authenticated_node(
194
195
  self,
@@ -15,17 +15,29 @@
15
15
  """Experimental REST API server."""
16
16
 
17
17
 
18
+ from __future__ import annotations
19
+
18
20
  import sys
21
+ from typing import Awaitable, Callable, TypeVar
22
+
23
+ from google.protobuf.message import Message as GrpcMessage
19
24
 
20
25
  from flwr.common.constant import MISSING_EXTRA_REST
26
+ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
21
27
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
22
28
  CreateNodeRequest,
29
+ CreateNodeResponse,
23
30
  DeleteNodeRequest,
31
+ DeleteNodeResponse,
24
32
  PingRequest,
33
+ PingResponse,
25
34
  PullTaskInsRequest,
35
+ PullTaskInsResponse,
26
36
  PushTaskResRequest,
37
+ PushTaskResResponse,
27
38
  )
28
- from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
39
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
+ from flwr.server.superlink.ffs.ffs import Ffs
29
41
  from flwr.server.superlink.fleet.message_handler import message_handler
30
42
  from flwr.server.superlink.state import State
31
43
 
@@ -40,172 +52,108 @@ except ModuleNotFoundError:
40
52
  sys.exit(MISSING_EXTRA_REST)
41
53
 
42
54
 
43
- async def create_node(request: Request) -> Response:
44
- """Create Node."""
45
- _check_headers(request.headers)
55
+ GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
56
+ GrpcResponse = TypeVar("GrpcResponse", bound=GrpcMessage)
46
57
 
47
- # Get the request body as raw bytes
48
- create_node_request_bytes: bytes = await request.body()
58
+ GrpcAsyncFunction = Callable[[GrpcRequest], Awaitable[GrpcResponse]]
59
+ RestEndPoint = Callable[[Request], Awaitable[Response]]
49
60
 
50
- # Deserialize ProtoBuf
51
- create_node_request_proto = CreateNodeRequest()
52
- create_node_request_proto.ParseFromString(create_node_request_bytes)
53
61
 
54
- # Get state from app
55
- state: State = app.state.STATE_FACTORY.state()
62
+ def rest_request_response(
63
+ grpc_request_type: type[GrpcRequest],
64
+ ) -> Callable[[GrpcAsyncFunction[GrpcRequest, GrpcResponse]], RestEndPoint]:
65
+ """Convert an async gRPC-based function into a RESTful HTTP endpoint."""
56
66
 
57
- # Handle message
58
- create_node_response_proto = message_handler.create_node(
59
- request=create_node_request_proto, state=state
60
- )
67
+ def decorator(func: GrpcAsyncFunction[GrpcRequest, GrpcResponse]) -> RestEndPoint:
68
+ async def wrapper(request: Request) -> Response:
69
+ _check_headers(request.headers)
61
70
 
62
- # Return serialized ProtoBuf
63
- create_node_response_bytes = create_node_response_proto.SerializeToString()
64
- return Response(
65
- status_code=200,
66
- content=create_node_response_bytes,
67
- headers={"Content-Type": "application/protobuf"},
68
- )
71
+ # Get the request body as raw bytes
72
+ grpc_req_bytes: bytes = await request.body()
69
73
 
74
+ # Deserialize ProtoBuf
75
+ grpc_req = grpc_request_type.FromString(grpc_req_bytes)
76
+ grpc_res = await func(grpc_req)
77
+ return Response(
78
+ status_code=200,
79
+ content=grpc_res.SerializeToString(),
80
+ headers={"Content-Type": "application/protobuf"},
81
+ )
70
82
 
71
- async def delete_node(request: Request) -> Response:
72
- """Delete Node Id."""
73
- _check_headers(request.headers)
83
+ return wrapper
74
84
 
75
- # Get the request body as raw bytes
76
- delete_node_request_bytes: bytes = await request.body()
85
+ return decorator
77
86
 
78
- # Deserialize ProtoBuf
79
- delete_node_request_proto = DeleteNodeRequest()
80
- delete_node_request_proto.ParseFromString(delete_node_request_bytes)
81
87
 
88
+ @rest_request_response(CreateNodeRequest)
89
+ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
90
+ """Create Node."""
82
91
  # Get state from app
83
92
  state: State = app.state.STATE_FACTORY.state()
84
93
 
85
94
  # Handle message
86
- delete_node_response_proto = message_handler.delete_node(
87
- request=delete_node_request_proto, state=state
88
- )
95
+ return message_handler.create_node(request=request, state=state)
89
96
 
90
- # Return serialized ProtoBuf
91
- delete_node_response_bytes = delete_node_response_proto.SerializeToString()
92
- return Response(
93
- status_code=200,
94
- content=delete_node_response_bytes,
95
- headers={"Content-Type": "application/protobuf"},
96
- )
97
97
 
98
+ @rest_request_response(DeleteNodeRequest)
99
+ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
100
+ """Delete Node Id."""
101
+ # Get state from app
102
+ state: State = app.state.STATE_FACTORY.state()
98
103
 
99
- async def pull_task_ins(request: Request) -> Response:
100
- """Pull TaskIns."""
101
- _check_headers(request.headers)
102
-
103
- # Get the request body as raw bytes
104
- pull_task_ins_request_bytes: bytes = await request.body()
104
+ # Handle message
105
+ return message_handler.delete_node(request=request, state=state)
105
106
 
106
- # Deserialize ProtoBuf
107
- pull_task_ins_request_proto = PullTaskInsRequest()
108
- pull_task_ins_request_proto.ParseFromString(pull_task_ins_request_bytes)
109
107
 
108
+ @rest_request_response(PullTaskInsRequest)
109
+ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
110
+ """Pull TaskIns."""
110
111
  # Get state from app
111
112
  state: State = app.state.STATE_FACTORY.state()
112
113
 
113
114
  # Handle message
114
- pull_task_ins_response_proto = message_handler.pull_task_ins(
115
- request=pull_task_ins_request_proto,
116
- state=state,
117
- )
118
-
119
- # Return serialized ProtoBuf
120
- pull_task_ins_response_bytes = pull_task_ins_response_proto.SerializeToString()
121
- return Response(
122
- status_code=200,
123
- content=pull_task_ins_response_bytes,
124
- headers={"Content-Type": "application/protobuf"},
125
- )
115
+ return message_handler.pull_task_ins(request=request, state=state)
126
116
 
127
117
 
128
- async def push_task_res(request: Request) -> Response: # Check if token is needed here
118
+ # Check if token is needed here
119
+ @rest_request_response(PushTaskResRequest)
120
+ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
129
121
  """Push TaskRes."""
130
- _check_headers(request.headers)
131
-
132
- # Get the request body as raw bytes
133
- push_task_res_request_bytes: bytes = await request.body()
134
-
135
- # Deserialize ProtoBuf
136
- push_task_res_request_proto = PushTaskResRequest()
137
- push_task_res_request_proto.ParseFromString(push_task_res_request_bytes)
138
-
139
122
  # Get state from app
140
123
  state: State = app.state.STATE_FACTORY.state()
141
124
 
142
125
  # Handle message
143
- push_task_res_response_proto = message_handler.push_task_res(
144
- request=push_task_res_request_proto,
145
- state=state,
146
- )
147
-
148
- # Return serialized ProtoBuf
149
- push_task_res_response_bytes = push_task_res_response_proto.SerializeToString()
150
- return Response(
151
- status_code=200,
152
- content=push_task_res_response_bytes,
153
- headers={"Content-Type": "application/protobuf"},
154
- )
126
+ return message_handler.push_task_res(request=request, state=state)
155
127
 
156
128
 
157
- async def ping(request: Request) -> Response:
129
+ @rest_request_response(PingRequest)
130
+ async def ping(request: PingRequest) -> PingResponse:
158
131
  """Ping."""
159
- _check_headers(request.headers)
160
-
161
- # Get the request body as raw bytes
162
- ping_request_bytes: bytes = await request.body()
163
-
164
- # Deserialize ProtoBuf
165
- ping_request_proto = PingRequest()
166
- ping_request_proto.ParseFromString(ping_request_bytes)
167
-
168
132
  # Get state from app
169
133
  state: State = app.state.STATE_FACTORY.state()
170
134
 
171
135
  # Handle message
172
- ping_response_proto = message_handler.ping(request=ping_request_proto, state=state)
173
-
174
- # Return serialized ProtoBuf
175
- ping_response_bytes = ping_response_proto.SerializeToString()
176
- return Response(
177
- status_code=200,
178
- content=ping_response_bytes,
179
- headers={"Content-Type": "application/protobuf"},
180
- )
136
+ return message_handler.ping(request=request, state=state)
181
137
 
182
138
 
183
- async def get_run(request: Request) -> Response:
139
+ @rest_request_response(GetRunRequest)
140
+ async def get_run(request: GetRunRequest) -> GetRunResponse:
184
141
  """GetRun."""
185
- _check_headers(request.headers)
186
-
187
- # Get the request body as raw bytes
188
- get_run_request_bytes: bytes = await request.body()
189
-
190
- # Deserialize ProtoBuf
191
- get_run_request_proto = GetRunRequest()
192
- get_run_request_proto.ParseFromString(get_run_request_bytes)
193
-
194
142
  # Get state from app
195
143
  state: State = app.state.STATE_FACTORY.state()
196
144
 
197
145
  # Handle message
198
- get_run_response_proto = message_handler.get_run(
199
- request=get_run_request_proto, state=state
200
- )
146
+ return message_handler.get_run(request=request, state=state)
147
+
201
148
 
202
- # Return serialized ProtoBuf
203
- get_run_response_bytes = get_run_response_proto.SerializeToString()
204
- return Response(
205
- status_code=200,
206
- content=get_run_response_bytes,
207
- headers={"Content-Type": "application/protobuf"},
208
- )
149
+ @rest_request_response(GetFabRequest)
150
+ async def get_fab(request: GetFabRequest) -> GetFabResponse:
151
+ """GetRun."""
152
+ # Get ffs from app
153
+ ffs: Ffs = app.state.FFS_FACTORY.state()
154
+
155
+ # Handle message
156
+ return message_handler.get_fab(request=request, ffs=ffs)
209
157
 
210
158
 
211
159
  routes = [
@@ -215,6 +163,7 @@ routes = [
215
163
  Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
216
164
  Route("/api/v0/fleet/ping", ping, methods=["POST"]),
217
165
  Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
166
+ Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
218
167
  ]
219
168
 
220
169
  app: Starlette = Starlette(
@@ -33,7 +33,7 @@ class Backend(ABC):
33
33
  """Construct a backend."""
34
34
 
35
35
  @abstractmethod
36
- def build(self) -> None:
36
+ def build(self, app_fn: Callable[[], ClientApp]) -> None:
37
37
  """Build backend.
38
38
 
39
39
  Different components need to be in place before workers in a backend are ready
@@ -60,7 +60,6 @@ class Backend(ABC):
60
60
  @abstractmethod
61
61
  def process_message(
62
62
  self,
63
- app: Callable[[], ClientApp],
64
63
  message: Message,
65
64
  context: Context,
66
65
  ) -> Tuple[Message, Context]:
@@ -16,7 +16,7 @@
16
16
 
17
17
  import sys
18
18
  from logging import DEBUG, ERROR
19
- from typing import Callable, Dict, Tuple, Union
19
+ from typing import Callable, Dict, Optional, Tuple, Union
20
20
 
21
21
  import ray
22
22
 
@@ -52,16 +52,13 @@ class RayBackend(Backend):
52
52
 
53
53
  # Validate client resources
54
54
  self.client_resources_key = "client_resources"
55
- client_resources = self._validate_client_resources(config=backend_config)
55
+ self.client_resources = self._validate_client_resources(config=backend_config)
56
56
 
57
- # Create actor pool
58
- actor_kwargs = self._validate_actor_arguments(config=backend_config)
57
+ # Valide actor resources
58
+ self.actor_kwargs = self._validate_actor_arguments(config=backend_config)
59
+ self.pool: Optional[BasicActorPool] = None
59
60
 
60
- self.pool = BasicActorPool(
61
- actor_type=ClientAppActor,
62
- client_resources=client_resources,
63
- actor_kwargs=actor_kwargs,
64
- )
61
+ self.app_fn: Optional[Callable[[], ClientApp]] = None
65
62
 
66
63
  def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
67
64
  client_resources_config = config.get(self.client_resources_key)
@@ -120,20 +117,31 @@ class RayBackend(Backend):
120
117
  @property
121
118
  def num_workers(self) -> int:
122
119
  """Return number of actors in pool."""
123
- return self.pool.num_actors
120
+ return self.pool.num_actors if self.pool else 0
124
121
 
125
122
  def is_worker_idle(self) -> bool:
126
123
  """Report whether the pool has idle actors."""
127
- return self.pool.is_actor_available()
124
+ return self.pool.is_actor_available() if self.pool else False
128
125
 
129
- def build(self) -> None:
126
+ def build(self, app_fn: Callable[[], ClientApp]) -> None:
130
127
  """Build pool of Ray actors that this backend will submit jobs to."""
128
+ # Create Actor Pool
129
+ try:
130
+ self.pool = BasicActorPool(
131
+ actor_type=ClientAppActor,
132
+ client_resources=self.client_resources,
133
+ actor_kwargs=self.actor_kwargs,
134
+ )
135
+ except Exception as ex:
136
+ raise ex
137
+
131
138
  self.pool.add_actors_to_pool(self.pool.actors_capacity)
139
+ # Set ClientApp callable that ray actors will use
140
+ self.app_fn = app_fn
132
141
  log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
133
142
 
134
143
  def process_message(
135
144
  self,
136
- app: Callable[[], ClientApp],
137
145
  message: Message,
138
146
  context: Context,
139
147
  ) -> Tuple[Message, Context]:
@@ -143,11 +151,20 @@ class RayBackend(Backend):
143
151
  """
144
152
  partition_id = context.node_config[PARTITION_ID_KEY]
145
153
 
154
+ if self.pool is None:
155
+ raise ValueError("The actor pool is empty, unfit to process messages.")
156
+
157
+ if self.app_fn is None:
158
+ raise ValueError(
159
+ "Unspecified function to load a `ClientApp`. "
160
+ "Call the backend's `build()` method before processing messages."
161
+ )
162
+
146
163
  try:
147
164
  # Submit a task to the pool
148
165
  future = self.pool.submit(
149
166
  lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
150
- (app, message, str(partition_id), context),
167
+ (self.app_fn, message, str(partition_id), context),
151
168
  )
152
169
 
153
170
  # Fetch result
@@ -170,6 +187,7 @@ class RayBackend(Backend):
170
187
 
171
188
  def terminate(self) -> None:
172
189
  """Terminate all actors in actor pool."""
173
- self.pool.terminate_all_actors()
190
+ if self.pool:
191
+ self.pool.terminate_all_actors()
174
192
  ray.shutdown()
175
193
  log(DEBUG, "Terminated %s", self.__class__.__name__)
@@ -87,7 +87,6 @@ def _register_node_states(
87
87
 
88
88
  # pylint: disable=too-many-arguments,too-many-locals
89
89
  def worker(
90
- app_fn: Callable[[], ClientApp],
91
90
  taskins_queue: "Queue[TaskIns]",
92
91
  taskres_queue: "Queue[TaskRes]",
93
92
  node_states: Dict[int, NodeState],
@@ -110,9 +109,7 @@ def worker(
110
109
  message = message_from_taskins(task_ins)
111
110
 
112
111
  # Let backend process message
113
- out_mssg, updated_context = backend.process_message(
114
- app_fn, message, context
115
- )
112
+ out_mssg, updated_context = backend.process_message(message, context)
116
113
 
117
114
  # Update Context
118
115
  node_states[node_id].update_context(
@@ -193,7 +190,7 @@ def run_api(
193
190
  backend = backend_fn()
194
191
 
195
192
  # Build backend
196
- backend.build()
193
+ backend.build(app_fn)
197
194
 
198
195
  # Add workers (they submit Messages to Backend)
199
196
  state = state_factory.state()
@@ -223,7 +220,6 @@ def run_api(
223
220
  _ = [
224
221
  executor.submit(
225
222
  worker,
226
- app_fn,
227
223
  taskins_queue,
228
224
  taskres_queue,
229
225
  node_states,