flwr-nightly 1.10.0.dev20240709__py3-none-any.whl → 1.10.0.dev20240711__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

flwr/cli/config_utils.py CHANGED
@@ -108,6 +108,14 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
108
108
  return load_from_string(toml_file.read())
109
109
 
110
110
 
111
+ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None:
112
+ for key, value in config_dict.items():
113
+ if isinstance(value, dict):
114
+ _validate_run_config(config_dict[key], errors)
115
+ elif not isinstance(value, str):
116
+ errors.append(f"Config value of key {key} is not of type `str`.")
117
+
118
+
111
119
  # pylint: disable=too-many-branches
112
120
  def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
113
121
  """Validate pyproject.toml fields."""
@@ -133,6 +141,8 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
133
141
  else:
134
142
  if "publisher" not in config["flower"]:
135
143
  errors.append('Property "publisher" missing in [flower]')
144
+ if "config" in config["flower"]:
145
+ _validate_run_config(config["flower"]["config"], errors)
136
146
  if "components" not in config["flower"]:
137
147
  errors.append("Missing [flower.components] section")
138
148
  else:
flwr/cli/new/new.py CHANGED
@@ -264,9 +264,11 @@ def new(
264
264
  bold=True,
265
265
  )
266
266
  )
267
+
268
+ _add = " huggingface-cli login\n" if framework_str == "flowertune" else ""
267
269
  print(
268
270
  typer.style(
269
- f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
271
+ f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
270
272
  fg=typer.colors.BRIGHT_CYAN,
271
273
  bold=True,
272
274
  )
flwr/cli/run/run.py CHANGED
@@ -18,13 +18,14 @@ import sys
18
18
  from enum import Enum
19
19
  from logging import DEBUG
20
20
  from pathlib import Path
21
- from typing import Optional
21
+ from typing import Dict, Optional
22
22
 
23
23
  import typer
24
24
  from typing_extensions import Annotated
25
25
 
26
26
  from flwr.cli import config_utils
27
27
  from flwr.cli.build import build
28
+ from flwr.common.config import parse_config_args
28
29
  from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
29
30
  from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
30
31
  from flwr.common.logger import log
@@ -58,15 +59,20 @@ def run(
58
59
  Optional[Path],
59
60
  typer.Option(help="Path of the Flower project to run"),
60
61
  ] = None,
62
+ config_overrides: Annotated[
63
+ Optional[str],
64
+ typer.Option(
65
+ "--config",
66
+ "-c",
67
+ help="Override configuration key-value pairs",
68
+ ),
69
+ ] = None,
61
70
  ) -> None:
62
71
  """Run Flower project."""
63
- if use_superexec:
64
- _start_superexec_run(directory)
65
- return
66
-
67
72
  typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
68
73
 
69
- config, errors, warnings = config_utils.load_and_validate()
74
+ pyproject_path = directory / "pyproject.toml" if directory else None
75
+ config, errors, warnings = config_utils.load_and_validate(path=pyproject_path)
70
76
 
71
77
  if config is None:
72
78
  typer.secho(
@@ -88,6 +94,12 @@ def run(
88
94
 
89
95
  typer.secho("Success", fg=typer.colors.GREEN)
90
96
 
97
+ if use_superexec:
98
+ _start_superexec_run(
99
+ parse_config_args(config_overrides, separator=","), directory
100
+ )
101
+ return
102
+
91
103
  server_app_ref = config["flower"]["components"]["serverapp"]
92
104
  client_app_ref = config["flower"]["components"]["clientapp"]
93
105
 
@@ -115,7 +127,9 @@ def run(
115
127
  )
116
128
 
117
129
 
118
- def _start_superexec_run(directory: Optional[Path]) -> None:
130
+ def _start_superexec_run(
131
+ override_config: Dict[str, str], directory: Optional[Path]
132
+ ) -> None:
119
133
  def on_channel_state_change(channel_connectivity: str) -> None:
120
134
  """Log channel connectivity."""
121
135
  log(DEBUG, channel_connectivity)
@@ -132,6 +146,9 @@ def _start_superexec_run(directory: Optional[Path]) -> None:
132
146
 
133
147
  fab_path = build(directory)
134
148
 
135
- req = StartRunRequest(fab_file=Path(fab_path).read_bytes())
149
+ req = StartRunRequest(
150
+ fab_file=Path(fab_path).read_bytes(),
151
+ override_config=override_config,
152
+ )
136
153
  res = stub.StartRun(req)
137
154
  typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
flwr/client/app.py CHANGED
@@ -19,6 +19,7 @@ import sys
19
19
  import time
20
20
  from dataclasses import dataclass
21
21
  from logging import DEBUG, ERROR, INFO, WARN
22
+ from pathlib import Path
22
23
  from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
23
24
 
24
25
  from cryptography.hazmat.primitives.asymmetric import ec
@@ -193,6 +194,7 @@ def _start_client_internal(
193
194
  max_retries: Optional[int] = None,
194
195
  max_wait_time: Optional[float] = None,
195
196
  partition_id: Optional[int] = None,
197
+ flwr_dir: Optional[Path] = None,
196
198
  ) -> None:
197
199
  """Start a Flower client node which connects to a Flower server.
198
200
 
@@ -239,6 +241,8 @@ def _start_client_internal(
239
241
  partition_id: Optional[int] (default: None)
240
242
  The data partition index associated with this node. Better suited for
241
243
  prototyping purposes.
244
+ flwr_dir: Optional[Path] (default: None)
245
+ The fully resolved path containing installed Flower Apps.
242
246
  """
243
247
  if insecure is None:
244
248
  insecure = root_certificates is None
@@ -316,7 +320,7 @@ def _start_client_internal(
316
320
  )
317
321
 
318
322
  node_state = NodeState(partition_id=partition_id)
319
- run_info: Dict[int, Run] = {}
323
+ runs: Dict[int, Run] = {}
320
324
 
321
325
  while not app_state_tracker.interrupt:
322
326
  sleep_duration: int = 0
@@ -366,15 +370,17 @@ def _start_client_internal(
366
370
 
367
371
  # Get run info
368
372
  run_id = message.metadata.run_id
369
- if run_id not in run_info:
373
+ if run_id not in runs:
370
374
  if get_run is not None:
371
- run_info[run_id] = get_run(run_id)
375
+ runs[run_id] = get_run(run_id)
372
376
  # If get_run is None, i.e., in grpc-bidi mode
373
377
  else:
374
- run_info[run_id] = Run(run_id, "", "", {})
378
+ runs[run_id] = Run(run_id, "", "", {})
375
379
 
376
380
  # Register context for this run
377
- node_state.register_context(run_id=run_id)
381
+ node_state.register_context(
382
+ run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir
383
+ )
378
384
 
379
385
  # Retrieve context for this run
380
386
  context = node_state.retrieve_context(run_id=run_id)
@@ -388,7 +394,7 @@ def _start_client_internal(
388
394
  # Handle app loading and task message
389
395
  try:
390
396
  # Load ClientApp instance
391
- run: Run = run_info[run_id]
397
+ run: Run = runs[run_id]
392
398
  client_app: ClientApp = load_client_app_fn(
393
399
  run.fab_id, run.fab_version
394
400
  )
flwr/client/node_state.py CHANGED
@@ -15,9 +15,21 @@
15
15
  """Node state."""
16
16
 
17
17
 
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
18
20
  from typing import Any, Dict, Optional
19
21
 
20
22
  from flwr.common import Context, RecordSet
23
+ from flwr.common.config import get_fused_config
24
+ from flwr.common.typing import Run
25
+
26
+
27
+ @dataclass()
28
+ class RunInfo:
29
+ """Contains the Context and initial run_config of a Run."""
30
+
31
+ context: Context
32
+ initial_run_config: Dict[str, str]
21
33
 
22
34
 
23
35
  class NodeState:
@@ -25,20 +37,31 @@ class NodeState:
25
37
 
26
38
  def __init__(self, partition_id: Optional[int]) -> None:
27
39
  self._meta: Dict[str, Any] = {} # holds metadata about the node
28
- self.run_contexts: Dict[int, Context] = {}
40
+ self.run_infos: Dict[int, RunInfo] = {}
29
41
  self._partition_id = partition_id
30
42
 
31
- def register_context(self, run_id: int) -> None:
43
+ def register_context(
44
+ self,
45
+ run_id: int,
46
+ run: Optional[Run] = None,
47
+ flwr_dir: Optional[Path] = None,
48
+ ) -> None:
32
49
  """Register new run context for this node."""
33
- if run_id not in self.run_contexts:
34
- self.run_contexts[run_id] = Context(
35
- state=RecordSet(), run_config={}, partition_id=self._partition_id
50
+ if run_id not in self.run_infos:
51
+ initial_run_config = get_fused_config(run, flwr_dir) if run else {}
52
+ self.run_infos[run_id] = RunInfo(
53
+ initial_run_config=initial_run_config,
54
+ context=Context(
55
+ state=RecordSet(),
56
+ run_config=initial_run_config.copy(),
57
+ partition_id=self._partition_id,
58
+ ),
36
59
  )
37
60
 
38
61
  def retrieve_context(self, run_id: int) -> Context:
39
62
  """Get run context given a run_id."""
40
- if run_id in self.run_contexts:
41
- return self.run_contexts[run_id]
63
+ if run_id in self.run_infos:
64
+ return self.run_infos[run_id].context
42
65
 
43
66
  raise RuntimeError(
44
67
  f"Context for run_id={run_id} doesn't exist."
@@ -48,4 +71,9 @@ class NodeState:
48
71
 
49
72
  def update_context(self, run_id: int, context: Context) -> None:
50
73
  """Update run context."""
51
- self.run_contexts[run_id] = context
74
+ if context.run_config != self.run_infos[run_id].initial_run_config:
75
+ raise ValueError(
76
+ "The `run_config` field of the `Context` object cannot be "
77
+ f"modified (run_id: {run_id})."
78
+ )
79
+ self.run_infos[run_id].context = context
@@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
59
59
  node_state.update_context(run_id=run_id, context=updated_state)
60
60
 
61
61
  # Verify values
62
- for run_id, context in node_state.run_contexts.items():
62
+ for run_id, run_info in node_state.run_infos.items():
63
63
  assert (
64
- context.state.configs_records["counter"]["count"] == expected_values[run_id]
64
+ run_info.context.state.configs_records["counter"]["count"]
65
+ == expected_values[run_id]
65
66
  )
@@ -68,6 +68,7 @@ def run_supernode() -> None:
68
68
  max_retries=args.max_retries,
69
69
  max_wait_time=args.max_wait_time,
70
70
  partition_id=args.partition_id,
71
+ flwr_dir=get_flwr_dir(args.flwr_dir),
71
72
  )
72
73
 
73
74
  # Graceful shutdown
@@ -178,7 +179,7 @@ def _get_load_client_app_fn(
178
179
  else:
179
180
  flwr_dir = Path(args.flwr_dir).absolute()
180
181
 
181
- sys.path.insert(0, str(flwr_dir.absolute()))
182
+ inserted_path = None
182
183
 
183
184
  default_app_ref: str = getattr(args, "client-app")
184
185
 
@@ -188,6 +189,11 @@ def _get_load_client_app_fn(
188
189
  "Flower SuperNode will load and validate ClientApp `%s`",
189
190
  getattr(args, "client-app"),
190
191
  )
192
+ # Insert sys.path
193
+ dir_path = Path(args.dir).absolute()
194
+ sys.path.insert(0, str(dir_path))
195
+ inserted_path = str(dir_path)
196
+
191
197
  valid, error_msg = validate(default_app_ref)
192
198
  if not valid and error_msg:
193
199
  raise LoadClientAppError(error_msg) from None
@@ -196,7 +202,7 @@ def _get_load_client_app_fn(
196
202
  # If multi-app feature is disabled
197
203
  if not multi_app:
198
204
  # Get sys path to be inserted
199
- sys_path = Path(args.dir).absolute()
205
+ dir_path = Path(args.dir).absolute()
200
206
 
201
207
  # Set app reference
202
208
  client_app_ref = default_app_ref
@@ -209,7 +215,7 @@ def _get_load_client_app_fn(
209
215
 
210
216
  log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
211
217
  # Get sys path to be inserted
212
- sys_path = Path(args.dir).absolute()
218
+ dir_path = Path(args.dir).absolute()
213
219
 
214
220
  # Set app reference
215
221
  client_app_ref = default_app_ref
@@ -222,13 +228,21 @@ def _get_load_client_app_fn(
222
228
  raise LoadClientAppError("Failed to load ClientApp") from e
223
229
 
224
230
  # Get sys path to be inserted
225
- sys_path = Path(project_dir).absolute()
231
+ dir_path = Path(project_dir).absolute()
226
232
 
227
233
  # Set app reference
228
234
  client_app_ref = config["flower"]["components"]["clientapp"]
229
235
 
230
236
  # Set sys.path
231
- sys.path.insert(0, str(sys_path))
237
+ nonlocal inserted_path
238
+ if inserted_path != str(dir_path):
239
+ # Remove the previously inserted path
240
+ if inserted_path is not None:
241
+ sys.path.remove(inserted_path)
242
+ # Insert the new path
243
+ sys.path.insert(0, str(dir_path))
244
+
245
+ inserted_path = str(dir_path)
232
246
 
233
247
  # Load ClientApp
234
248
  log(
@@ -236,7 +250,7 @@ def _get_load_client_app_fn(
236
250
  "Loading ClientApp `%s`",
237
251
  client_app_ref,
238
252
  )
239
- client_app = load_app(client_app_ref, LoadClientAppError, sys_path)
253
+ client_app = load_app(client_app_ref, LoadClientAppError, dir_path)
240
254
 
241
255
  if not isinstance(client_app, ClientApp):
242
256
  raise LoadClientAppError(
flwr/common/logger.py CHANGED
@@ -197,6 +197,31 @@ def warn_deprecated_feature(name: str) -> None:
197
197
  )
198
198
 
199
199
 
200
+ def warn_deprecated_feature_with_example(
201
+ deprecation_message: str, example_message: str, code_example: str
202
+ ) -> None:
203
+ """Warn if a feature is deprecated and show code example."""
204
+ log(
205
+ WARN,
206
+ """DEPRECATED FEATURE: %s
207
+
208
+ Check the following `FEATURE UPDATE` warning message for the preferred
209
+ new mechanism to use this feature in Flower.
210
+ """,
211
+ deprecation_message,
212
+ )
213
+ log(
214
+ WARN,
215
+ """FEATURE UPDATE: %s
216
+ ------------------------------------------------------------
217
+ %s
218
+ ------------------------------------------------------------
219
+ """,
220
+ example_message,
221
+ code_example,
222
+ )
223
+
224
+
200
225
  def warn_unsupported_feature(name: str) -> None:
201
226
  """Warn the user when they use an unsupported feature."""
202
227
  log(
flwr/server/__init__.py CHANGED
@@ -28,6 +28,7 @@ from .run_serverapp import run_server_app as run_server_app
28
28
  from .server import Server as Server
29
29
  from .server_app import ServerApp as ServerApp
30
30
  from .server_config import ServerConfig as ServerConfig
31
+ from .serverapp_components import ServerAppComponents as ServerAppComponents
31
32
 
32
33
  __all__ = [
33
34
  "ClientManager",
@@ -36,6 +37,7 @@ __all__ = [
36
37
  "LegacyContext",
37
38
  "Server",
38
39
  "ServerApp",
40
+ "ServerAppComponents",
39
41
  "ServerConfig",
40
42
  "SimpleClientManager",
41
43
  "run_server_app",
flwr/server/server_app.py CHANGED
@@ -17,8 +17,11 @@
17
17
 
18
18
  from typing import Callable, Optional
19
19
 
20
- from flwr.common import Context, RecordSet
21
- from flwr.common.logger import warn_preview_feature
20
+ from flwr.common import Context
21
+ from flwr.common.logger import (
22
+ warn_deprecated_feature_with_example,
23
+ warn_preview_feature,
24
+ )
22
25
  from flwr.server.strategy import Strategy
23
26
 
24
27
  from .client_manager import ClientManager
@@ -26,7 +29,20 @@ from .compat import start_driver
26
29
  from .driver import Driver
27
30
  from .server import Server
28
31
  from .server_config import ServerConfig
29
- from .typing import ServerAppCallable
32
+ from .typing import ServerAppCallable, ServerFn
33
+
34
+ SERVER_FN_USAGE_EXAMPLE = """
35
+
36
+ def server_fn(context: Context):
37
+ server_config = ServerConfig(num_rounds=3)
38
+ strategy = FedAvg()
39
+ return ServerAppComponents(
40
+ strategy=strategy,
41
+ server_config=server_config,
42
+ )
43
+
44
+ app = ServerApp(server_fn=server_fn)
45
+ """
30
46
 
31
47
 
32
48
  class ServerApp:
@@ -36,13 +52,15 @@ class ServerApp:
36
52
  --------
37
53
  Use the `ServerApp` with an existing `Strategy`:
38
54
 
39
- >>> server_config = ServerConfig(num_rounds=3)
40
- >>> strategy = FedAvg()
55
+ >>> def server_fn(context: Context):
56
+ >>> server_config = ServerConfig(num_rounds=3)
57
+ >>> strategy = FedAvg()
58
+ >>> return ServerAppComponents(
59
+ >>> strategy=strategy,
60
+ >>> server_config=server_config,
61
+ >>> )
41
62
  >>>
42
- >>> app = ServerApp(
43
- >>> server_config=server_config,
44
- >>> strategy=strategy,
45
- >>> )
63
+ >>> app = ServerApp(server_fn=server_fn)
46
64
 
47
65
  Use the `ServerApp` with a custom main function:
48
66
 
@@ -53,23 +71,52 @@ class ServerApp:
53
71
  >>> print("ServerApp running")
54
72
  """
55
73
 
74
+ # pylint: disable=too-many-arguments
56
75
  def __init__(
57
76
  self,
58
77
  server: Optional[Server] = None,
59
78
  config: Optional[ServerConfig] = None,
60
79
  strategy: Optional[Strategy] = None,
61
80
  client_manager: Optional[ClientManager] = None,
81
+ server_fn: Optional[ServerFn] = None,
62
82
  ) -> None:
83
+ if any([server, config, strategy, client_manager]):
84
+ warn_deprecated_feature_with_example(
85
+ deprecation_message="Passing either `server`, `config`, `strategy` or "
86
+ "`client_manager` directly to the ServerApp "
87
+ "constructor is deprecated.",
88
+ example_message="Pass `ServerApp` arguments wrapped "
89
+ "in a `flwr.server.ServerAppComponents` object that gets "
90
+ "returned by a function passed as the `server_fn` argument "
91
+ "to the `ServerApp` constructor. For example: ",
92
+ code_example=SERVER_FN_USAGE_EXAMPLE,
93
+ )
94
+
95
+ if server_fn:
96
+ raise ValueError(
97
+ "Passing `server_fn` is incompatible with passing the "
98
+ "other arguments (now deprecated) to ServerApp. "
99
+ "Use `server_fn` exclusively."
100
+ )
101
+
63
102
  self._server = server
64
103
  self._config = config
65
104
  self._strategy = strategy
66
105
  self._client_manager = client_manager
106
+ self._server_fn = server_fn
67
107
  self._main: Optional[ServerAppCallable] = None
68
108
 
69
109
  def __call__(self, driver: Driver, context: Context) -> None:
70
110
  """Execute `ServerApp`."""
71
111
  # Compatibility mode
72
112
  if not self._main:
113
+ if self._server_fn:
114
+ # Execute server_fn()
115
+ components = self._server_fn(context)
116
+ self._server = components.server
117
+ self._config = components.config
118
+ self._strategy = components.strategy
119
+ self._client_manager = components.client_manager
73
120
  start_driver(
74
121
  server=self._server,
75
122
  config=self._config,
@@ -80,7 +127,6 @@ class ServerApp:
80
127
  return
81
128
 
82
129
  # New execution mode
83
- context = Context(state=RecordSet(), run_config={})
84
130
  self._main(driver, context)
85
131
 
86
132
  def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ServerAppComponents for the ServerApp."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ from .client_manager import ClientManager
22
+ from .server import Server
23
+ from .server_config import ServerConfig
24
+ from .strategy import Strategy
25
+
26
+
27
+ @dataclass
28
+ class ServerAppComponents: # pylint: disable=too-many-instance-attributes
29
+ """Components to construct a ServerApp.
30
+
31
+ Parameters
32
+ ----------
33
+ server : Optional[Server] (default: None)
34
+ A server implementation, either `flwr.server.Server` or a subclass
35
+ thereof. If no instance is provided, one will be created internally.
36
+ config : Optional[ServerConfig] (default: None)
37
+ Currently supported values are `num_rounds` (int, default: 1) and
38
+ `round_timeout` in seconds (float, default: None).
39
+ strategy : Optional[Strategy] (default: None)
40
+ An implementation of the abstract base class
41
+ `flwr.server.strategy.Strategy`. If no strategy is provided, then
42
+ `flwr.server.strategy.FedAvg` will be used.
43
+ client_manager : Optional[ClientManager] (default: None)
44
+ An implementation of the class `flwr.server.ClientManager`. If no
45
+ implementation is provided, then `flwr.server.SimpleClientManager`
46
+ will be used.
47
+ """
48
+
49
+ server: Optional[Server] = None
50
+ config: Optional[ServerConfig] = None
51
+ strategy: Optional[Strategy] = None
52
+ client_manager: Optional[ClientManager] = None
@@ -33,8 +33,8 @@ class Backend(ABC):
33
33
  """Construct a backend."""
34
34
 
35
35
  @abstractmethod
36
- async def build(self) -> None:
37
- """Build backend asynchronously.
36
+ def build(self) -> None:
37
+ """Build backend.
38
38
 
39
39
  Different components need to be in place before workers in a backend are ready
40
40
  to accept jobs. When this method finishes executing, the backend should be fully
@@ -54,11 +54,11 @@ class Backend(ABC):
54
54
  """Report whether a backend worker is idle and can therefore run a ClientApp."""
55
55
 
56
56
  @abstractmethod
57
- async def terminate(self) -> None:
57
+ def terminate(self) -> None:
58
58
  """Terminate backend."""
59
59
 
60
60
  @abstractmethod
61
- async def process_message(
61
+ def process_message(
62
62
  self,
63
63
  app: Callable[[], ClientApp],
64
64
  message: Message,
@@ -153,12 +153,12 @@ class RayBackend(Backend):
153
153
  """Report whether the pool has idle actors."""
154
154
  return self.pool.is_actor_available()
155
155
 
156
- async def build(self) -> None:
156
+ def build(self) -> None:
157
157
  """Build pool of Ray actors that this backend will submit jobs to."""
158
- await self.pool.add_actors_to_pool(self.pool.actors_capacity)
158
+ self.pool.add_actors_to_pool(self.pool.actors_capacity)
159
159
  log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
160
160
 
161
- async def process_message(
161
+ def process_message(
162
162
  self,
163
163
  app: Callable[[], ClientApp],
164
164
  message: Message,
@@ -172,17 +172,16 @@ class RayBackend(Backend):
172
172
 
173
173
  try:
174
174
  # Submit a task to the pool
175
- future = await self.pool.submit(
175
+ future = self.pool.submit(
176
176
  lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
177
177
  (app, message, str(partition_id), context),
178
178
  )
179
179
 
180
- await future
181
180
  # Fetch result
182
181
  (
183
182
  out_mssg,
184
183
  updated_context,
185
- ) = await self.pool.fetch_result_and_return_actor_to_pool(future)
184
+ ) = self.pool.fetch_result_and_return_actor_to_pool(future)
186
185
 
187
186
  return out_mssg, updated_context
188
187
 
@@ -193,11 +192,11 @@ class RayBackend(Backend):
193
192
  self.__class__.__name__,
194
193
  )
195
194
  # add actor back into pool
196
- await self.pool.add_actor_back_to_pool(future)
195
+ self.pool.add_actor_back_to_pool(future)
197
196
  raise ex
198
197
 
199
- async def terminate(self) -> None:
198
+ def terminate(self) -> None:
200
199
  """Terminate all actors in actor pool."""
201
- await self.pool.terminate_all_actors()
200
+ self.pool.terminate_all_actors()
202
201
  ray.shutdown()
203
202
  log(DEBUG, "Terminated %s", self.__class__.__name__)