flwr-nightly 1.10.0.dev20240714__py3-none-any.whl → 1.10.0.dev20240716__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. (The source page links to further details.)

Files changed (32)
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +23 -15
  3. flwr/cli/install.py +1 -1
  4. flwr/cli/new/templates/app/code/server.hf.py.tpl +4 -1
  5. flwr/cli/new/templates/app/code/server.jax.py.tpl +4 -1
  6. flwr/cli/new/templates/app/code/server.mlx.py.tpl +4 -1
  7. flwr/cli/new/templates/app/code/server.numpy.py.tpl +4 -1
  8. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +4 -1
  9. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +4 -1
  10. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +4 -1
  11. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  12. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +10 -10
  13. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +12 -6
  14. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +10 -10
  15. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +10 -10
  16. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +10 -10
  17. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +10 -10
  18. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +10 -10
  19. flwr/cli/run/run.py +110 -57
  20. flwr/client/app.py +3 -3
  21. flwr/client/node_state.py +17 -3
  22. flwr/client/supernode/app.py +26 -15
  23. flwr/common/config.py +13 -4
  24. flwr/server/run_serverapp.py +1 -1
  25. flwr/server/superlink/fleet/vce/vce_api.py +52 -28
  26. flwr/simulation/run_simulation.py +184 -33
  27. flwr/superexec/simulation.py +157 -0
  28. {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/METADATA +2 -1
  29. {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/RECORD +32 -31
  30. {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/LICENSE +0 -0
  31. {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/WHEEL +0 -0
  32. {flwr_nightly-1.10.0.dev20240714.dist-info → flwr_nightly-1.10.0.dev20240716.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py CHANGED
@@ -14,50 +14,33 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `run` command."""
16
16
 
17
+ import subprocess
17
18
  import sys
18
- from enum import Enum
19
19
  from logging import DEBUG
20
20
  from pathlib import Path
21
- from typing import Dict, Optional
21
+ from typing import Any, Dict, Optional
22
22
 
23
23
  import typer
24
24
  from typing_extensions import Annotated
25
25
 
26
- from flwr.cli import config_utils
27
26
  from flwr.cli.build import build
27
+ from flwr.cli.config_utils import load_and_validate
28
28
  from flwr.common.config import parse_config_args
29
- from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
30
29
  from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
31
30
  from flwr.common.logger import log
32
31
  from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
33
32
  from flwr.proto.exec_pb2_grpc import ExecStub
34
- from flwr.simulation.run_simulation import _run_simulation
35
-
36
-
37
- class Engine(str, Enum):
38
- """Enum defining the engine to run on."""
39
-
40
- SIMULATION = "simulation"
41
33
 
42
34
 
43
35
  # pylint: disable-next=too-many-locals
44
36
  def run(
45
- engine: Annotated[
46
- Optional[Engine],
47
- typer.Option(
48
- case_sensitive=False,
49
- help="The engine to run FL with (currently only simulation is supported).",
50
- ),
51
- ] = None,
52
- use_superexec: Annotated[
53
- bool,
54
- typer.Option(
55
- case_sensitive=False, help="Use this flag to use the new SuperExec API"
56
- ),
57
- ] = False,
58
37
  directory: Annotated[
59
- Optional[Path],
60
- typer.Option(help="Path of the Flower project to run"),
38
+ Path,
39
+ typer.Argument(help="Path of the Flower project to run"),
40
+ ] = Path("."),
41
+ federation_name: Annotated[
42
+ Optional[str],
43
+ typer.Argument(help="Name of the federation to run the app on"),
61
44
  ] = None,
62
45
  config_overrides: Annotated[
63
46
  Optional[str],
@@ -72,7 +55,7 @@ def run(
72
55
  typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
73
56
 
74
57
  pyproject_path = directory / "pyproject.toml" if directory else None
75
- config, errors, warnings = config_utils.load_and_validate(path=pyproject_path)
58
+ config, errors, warnings = load_and_validate(path=pyproject_path)
76
59
 
77
60
  if config is None:
78
61
  typer.secho(
@@ -94,50 +77,81 @@ def run(
94
77
 
95
78
  typer.secho("Success", fg=typer.colors.GREEN)
96
79
 
97
- if use_superexec:
98
- _start_superexec_run(
99
- parse_config_args(config_overrides, separator=","), directory
100
- )
101
- return
102
-
103
- server_app_ref = config["flower"]["components"]["serverapp"]
104
- client_app_ref = config["flower"]["components"]["clientapp"]
105
-
106
- if engine is None:
107
- engine = config["flower"]["engine"]["name"]
80
+ federation_name = federation_name or config["tool"]["flwr"]["federations"].get(
81
+ "default"
82
+ )
108
83
 
109
- if engine == Engine.SIMULATION:
110
- num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
111
- backend_config = config["flower"]["engine"]["simulation"].get(
112
- "backend_config", None
84
+ if federation_name is None:
85
+ typer.secho(
86
+ "❌ No federation name was provided and the project's `pyproject.toml` "
87
+ "doesn't declare a default federation (with a SuperExec address or an "
88
+ "`options.num-supernodes` value).",
89
+ fg=typer.colors.RED,
90
+ bold=True,
113
91
  )
92
+ raise typer.Exit(code=1)
114
93
 
115
- typer.secho("Starting run... ", fg=typer.colors.BLUE)
116
- _run_simulation(
117
- server_app_attr=server_app_ref,
118
- client_app_attr=client_app_ref,
119
- num_supernodes=num_supernodes,
120
- backend_config=backend_config,
121
- )
122
- else:
94
+ # Validate the federation exists in the configuration
95
+ federation = config["tool"]["flwr"]["federations"].get(federation_name)
96
+ if federation is None:
97
+ available_feds = list(config["tool"]["flwr"]["federations"])
123
98
  typer.secho(
124
- f"Engine '{engine}' is not yet supported in `flwr run`",
99
+ f" There is no `{federation_name}` federation declared in the "
100
+ "`pyproject.toml`.\n The following federations were found:\n\n"
101
+ "\n".join(available_feds) + "\n\n",
125
102
  fg=typer.colors.RED,
126
103
  bold=True,
127
104
  )
105
+ raise typer.Exit(code=1)
106
+
107
+ if "address" in federation:
108
+ _run_with_superexec(federation, directory, config_overrides)
109
+ else:
110
+ _run_without_superexec(directory, federation, federation_name, config_overrides)
128
111
 
129
112
 
130
- def _start_superexec_run(
131
- override_config: Dict[str, str], directory: Optional[Path]
113
+ def _run_with_superexec(
114
+ federation: Dict[str, str],
115
+ directory: Optional[Path],
116
+ config_overrides: Optional[str],
132
117
  ) -> None:
118
+
133
119
  def on_channel_state_change(channel_connectivity: str) -> None:
134
120
  """Log channel connectivity."""
135
121
  log(DEBUG, channel_connectivity)
136
122
 
123
+ insecure_str = federation.get("insecure")
124
+ if root_certificates := federation.get("root-certificates"):
125
+ root_certificates_bytes = Path(root_certificates).read_bytes()
126
+ if insecure := bool(insecure_str):
127
+ typer.secho(
128
+ "❌ `root_certificates` were provided but the `insecure` parameter"
129
+ "is set to `True`.",
130
+ fg=typer.colors.RED,
131
+ bold=True,
132
+ )
133
+ raise typer.Exit(code=1)
134
+ else:
135
+ root_certificates_bytes = None
136
+ if insecure_str is None:
137
+ typer.secho(
138
+ "❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
139
+ fg=typer.colors.RED,
140
+ bold=True,
141
+ )
142
+ raise typer.Exit(code=1)
143
+ if not (insecure := bool(insecure_str)):
144
+ typer.secho(
145
+ "❌ No certificate were given yet `insecure` is set to `False`.",
146
+ fg=typer.colors.RED,
147
+ bold=True,
148
+ )
149
+ raise typer.Exit(code=1)
150
+
137
151
  channel = create_channel(
138
- server_address=SUPEREXEC_DEFAULT_ADDRESS,
139
- insecure=True,
140
- root_certificates=None,
152
+ server_address=federation["address"],
153
+ insecure=insecure,
154
+ root_certificates=root_certificates_bytes,
141
155
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
142
156
  interceptors=None,
143
157
  )
@@ -148,7 +162,46 @@ def _start_superexec_run(
148
162
 
149
163
  req = StartRunRequest(
150
164
  fab_file=Path(fab_path).read_bytes(),
151
- override_config=override_config,
165
+ override_config=parse_config_args(config_overrides, separator=","),
152
166
  )
153
167
  res = stub.StartRun(req)
154
168
  typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
169
+
170
+
171
+ def _run_without_superexec(
172
+ app_path: Optional[Path],
173
+ federation: Dict[str, Any],
174
+ federation_name: str,
175
+ config_overrides: Optional[str],
176
+ ) -> None:
177
+ try:
178
+ num_supernodes = federation["options"]["num-supernodes"]
179
+ except KeyError as err:
180
+ typer.secho(
181
+ "❌ The project's `pyproject.toml` needs to declare the number of"
182
+ " SuperNodes in the simulation. To simulate 10 SuperNodes,"
183
+ " use the following notation:\n\n"
184
+ f"[tool.flwr.federations.{federation_name}]\n"
185
+ "options.num-supernodes = 10\n",
186
+ fg=typer.colors.RED,
187
+ bold=True,
188
+ )
189
+ raise typer.Exit(code=1) from err
190
+
191
+ command = [
192
+ "flower-simulation",
193
+ "--app",
194
+ f"{app_path}",
195
+ "--num-supernodes",
196
+ f"{num_supernodes}",
197
+ ]
198
+
199
+ if config_overrides:
200
+ command.extend(["--run-config", f"{config_overrides}"])
201
+
202
+ # Run the simulation
203
+ subprocess.run(
204
+ command,
205
+ check=True,
206
+ text=True,
207
+ )
flwr/client/app.py CHANGED
@@ -195,7 +195,7 @@ def _start_client_internal(
195
195
  ] = None,
196
196
  max_retries: Optional[int] = None,
197
197
  max_wait_time: Optional[float] = None,
198
- flwr_dir: Optional[Path] = None,
198
+ flwr_path: Optional[Path] = None,
199
199
  ) -> None:
200
200
  """Start a Flower client node which connects to a Flower server.
201
201
 
@@ -241,7 +241,7 @@ def _start_client_internal(
241
241
  The maximum duration before the client stops trying to
242
242
  connect to the server in case of connection error.
243
243
  If set to None, there is no limit to the total time.
244
- flwr_dir: Optional[Path] (default: None)
244
+ flwr_path: Optional[Path] (default: None)
245
245
  The fully resolved path containing installed Flower Apps.
246
246
  """
247
247
  if insecure is None:
@@ -402,7 +402,7 @@ def _start_client_internal(
402
402
 
403
403
  # Register context for this run
404
404
  node_state.register_context(
405
- run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir
405
+ run_id=run_id, run=runs[run_id], flwr_path=flwr_path
406
406
  )
407
407
 
408
408
  # Retrieve context for this run
flwr/client/node_state.py CHANGED
@@ -20,7 +20,7 @@ from pathlib import Path
20
20
  from typing import Dict, Optional
21
21
 
22
22
  from flwr.common import Context, RecordSet
23
- from flwr.common.config import get_fused_config
23
+ from flwr.common.config import get_fused_config, get_fused_config_from_dir
24
24
  from flwr.common.typing import Run
25
25
 
26
26
 
@@ -48,11 +48,25 @@ class NodeState:
48
48
  self,
49
49
  run_id: int,
50
50
  run: Optional[Run] = None,
51
- flwr_dir: Optional[Path] = None,
51
+ flwr_path: Optional[Path] = None,
52
+ app_dir: Optional[str] = None,
52
53
  ) -> None:
53
54
  """Register new run context for this node."""
54
55
  if run_id not in self.run_infos:
55
- initial_run_config = get_fused_config(run, flwr_dir) if run else {}
56
+ initial_run_config = {}
57
+ if app_dir:
58
+ # Load from app directory
59
+ app_path = Path(app_dir)
60
+ if app_path.is_dir():
61
+ override_config = run.override_config if run else {}
62
+ initial_run_config = get_fused_config_from_dir(
63
+ app_path, override_config
64
+ )
65
+ else:
66
+ raise ValueError("The specified `app_dir` must be a directory.")
67
+ else:
68
+ # Load from .fab
69
+ initial_run_config = get_fused_config(run, flwr_path) if run else {}
56
70
  self.run_infos[run_id] = RunInfo(
57
71
  initial_run_config=initial_run_config,
58
72
  context=Context(
flwr/client/supernode/app.py CHANGED
@@ -60,7 +60,12 @@ def run_supernode() -> None:
60
60
  _warn_deprecated_server_arg(args)
61
61
 
62
62
  root_certificates = _get_certificates(args)
63
- load_fn = _get_load_client_app_fn(args, multi_app=True)
63
+ load_fn = _get_load_client_app_fn(
64
+ default_app_ref=getattr(args, "client-app"),
65
+ dir_arg=args.dir,
66
+ flwr_dir_arg=args.flwr_dir,
67
+ multi_app=True,
68
+ )
64
69
  authentication_keys = _try_setup_client_authentication(args)
65
70
 
66
71
  _start_client_internal(
@@ -73,7 +78,7 @@ def run_supernode() -> None:
73
78
  max_retries=args.max_retries,
74
79
  max_wait_time=args.max_wait_time,
75
80
  node_config=parse_config_args(args.node_config),
76
- flwr_dir=get_flwr_dir(args.flwr_dir),
81
+ flwr_path=get_flwr_dir(args.flwr_dir),
77
82
  )
78
83
 
79
84
  # Graceful shutdown
@@ -93,7 +98,11 @@ def run_client_app() -> None:
93
98
  _warn_deprecated_server_arg(args)
94
99
 
95
100
  root_certificates = _get_certificates(args)
96
- load_fn = _get_load_client_app_fn(args, multi_app=False)
101
+ load_fn = _get_load_client_app_fn(
102
+ default_app_ref=getattr(args, "client-app"),
103
+ dir_arg=args.dir,
104
+ multi_app=False,
105
+ )
97
106
  authentication_keys = _try_setup_client_authentication(args)
98
107
 
99
108
  _start_client_internal(
@@ -166,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
166
175
 
167
176
 
168
177
  def _get_load_client_app_fn(
169
- args: argparse.Namespace, multi_app: bool
178
+ default_app_ref: str,
179
+ dir_arg: str,
180
+ multi_app: bool,
181
+ flwr_dir_arg: Optional[str] = None,
170
182
  ) -> Callable[[str, str], ClientApp]:
171
183
  """Get the load_client_app_fn function.
172
184
 
@@ -178,25 +190,24 @@ def _get_load_client_app_fn(
178
190
  loads a default ClientApp.
179
191
  """
180
192
  # Find the Flower directory containing Flower Apps (only for multi-app)
181
- flwr_dir = Path("")
182
- if "flwr_dir" in args:
183
- if args.flwr_dir is None:
193
+ if not multi_app:
194
+ flwr_dir = Path("")
195
+ else:
196
+ if flwr_dir_arg is None:
184
197
  flwr_dir = get_flwr_dir()
185
198
  else:
186
- flwr_dir = Path(args.flwr_dir).absolute()
199
+ flwr_dir = Path(flwr_dir_arg).absolute()
187
200
 
188
201
  inserted_path = None
189
202
 
190
- default_app_ref: str = getattr(args, "client-app")
191
-
192
203
  if not multi_app:
193
204
  log(
194
205
  DEBUG,
195
206
  "Flower SuperNode will load and validate ClientApp `%s`",
196
- getattr(args, "client-app"),
207
+ default_app_ref,
197
208
  )
198
209
  # Insert sys.path
199
- dir_path = Path(args.dir).absolute()
210
+ dir_path = Path(dir_arg).absolute()
200
211
  sys.path.insert(0, str(dir_path))
201
212
  inserted_path = str(dir_path)
202
213
 
@@ -208,7 +219,7 @@ def _get_load_client_app_fn(
208
219
  # If multi-app feature is disabled
209
220
  if not multi_app:
210
221
  # Get sys path to be inserted
211
- dir_path = Path(args.dir).absolute()
222
+ dir_path = Path(dir_arg).absolute()
212
223
 
213
224
  # Set app reference
214
225
  client_app_ref = default_app_ref
@@ -221,7 +232,7 @@ def _get_load_client_app_fn(
221
232
 
222
233
  log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
223
234
  # Get sys path to be inserted
224
- dir_path = Path(args.dir).absolute()
235
+ dir_path = Path(dir_arg).absolute()
225
236
 
226
237
  # Set app reference
227
238
  client_app_ref = default_app_ref
@@ -237,7 +248,7 @@ def _get_load_client_app_fn(
237
248
  dir_path = Path(project_dir).absolute()
238
249
 
239
250
  # Set app reference
240
- client_app_ref = config["flower"]["components"]["clientapp"]
251
+ client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
241
252
 
242
253
  # Set sys.path
243
254
  nonlocal inserted_path
flwr/common/config.py CHANGED
@@ -86,6 +86,18 @@ def _fuse_dicts(
86
86
  return fused_dict
87
87
 
88
88
 
89
+ def get_fused_config_from_dir(
90
+ project_dir: Path, override_config: Dict[str, str]
91
+ ) -> Dict[str, str]:
92
+ """Merge the overrides from a given dict with the config from a Flower App."""
93
+ default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
94
+ "config", {}
95
+ )
96
+ flat_default_config = flatten_dict(default_config)
97
+
98
+ return _fuse_dicts(flat_default_config, override_config)
99
+
100
+
89
101
  def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]:
90
102
  """Merge the overrides from a `Run` with the config from a FAB.
91
103
 
@@ -97,10 +109,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]:
97
109
 
98
110
  project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
99
111
 
100
- default_config = get_project_config(project_dir)["flower"].get("config", {})
101
- flat_default_config = flatten_dict(default_config)
102
-
103
- return _fuse_dicts(flat_default_config, run.override_config)
112
+ return get_fused_config_from_dir(project_dir, run.override_config)
104
113
 
105
114
 
106
115
  def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]:
flwr/server/run_serverapp.py CHANGED
@@ -186,7 +186,7 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
186
186
  run_ = driver.run
187
187
  server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
188
188
  config = get_project_config(server_app_dir)
189
- server_app_attr = config["flower"]["components"]["serverapp"]
189
+ server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
190
190
  server_app_run_config = get_fused_config(run_, flwr_dir)
191
191
  else:
192
192
  # User provided `server-app`, but not `--run-id`
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import json
19
- import sys
20
19
  import threading
21
20
  import time
22
21
  import traceback
@@ -29,6 +28,7 @@ from typing import Callable, Dict, Optional
29
28
 
30
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
31
30
  from flwr.client.node_state import NodeState
31
+ from flwr.client.supernode.app import _get_load_client_app_fn
32
32
  from flwr.common.constant import (
33
33
  NUM_PARTITIONS_KEY,
34
34
  PARTITION_ID_KEY,
@@ -37,8 +37,8 @@ from flwr.common.constant import (
37
37
  )
38
38
  from flwr.common.logger import log
39
39
  from flwr.common.message import Error
40
- from flwr.common.object_ref import load_app
41
40
  from flwr.common.serde import message_from_taskins, message_to_taskres
41
+ from flwr.common.typing import Run
42
42
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
43
  from flwr.server.superlink.state import State, StateFactory
44
44
 
@@ -60,6 +60,31 @@ def _register_nodes(
60
60
  return nodes_mapping
61
61
 
62
62
 
63
+ def _register_node_states(
64
+ nodes_mapping: NodeToPartitionMapping,
65
+ run: Run,
66
+ app_dir: Optional[str] = None,
67
+ ) -> Dict[int, NodeState]:
68
+ """Create NodeState objects and pre-register the context for the run."""
69
+ node_states: Dict[int, NodeState] = {}
70
+ num_partitions = len(set(nodes_mapping.values()))
71
+ for node_id, partition_id in nodes_mapping.items():
72
+ node_states[node_id] = NodeState(
73
+ node_id=node_id,
74
+ node_config={
75
+ PARTITION_ID_KEY: str(partition_id),
76
+ NUM_PARTITIONS_KEY: str(num_partitions),
77
+ },
78
+ )
79
+
80
+ # Pre-register Context objects
81
+ node_states[node_id].register_context(
82
+ run_id=run.run_id, run=run, app_dir=app_dir
83
+ )
84
+
85
+ return node_states
86
+
87
+
63
88
  # pylint: disable=too-many-arguments,too-many-locals
64
89
  def worker(
65
90
  app_fn: Callable[[], ClientApp],
@@ -78,8 +103,7 @@ def worker(
78
103
  task_ins: TaskIns = taskins_queue.get(timeout=1.0)
79
104
  node_id = task_ins.task.consumer.node_id
80
105
 
81
- # Register and retrieve context
82
- node_states[node_id].register_context(run_id=task_ins.run_id)
106
+ # Retrieve context
83
107
  context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
84
108
 
85
109
  # Convert TaskIns to Message
@@ -151,7 +175,7 @@ def put_taskres_into_state(
151
175
  pass
152
176
 
153
177
 
154
- def run(
178
+ def run_api(
155
179
  app_fn: Callable[[], ClientApp],
156
180
  backend_fn: Callable[[], Backend],
157
181
  nodes_mapping: NodeToPartitionMapping,
@@ -236,7 +260,10 @@ def start_vce(
236
260
  backend_name: str,
237
261
  backend_config_json_stream: str,
238
262
  app_dir: str,
263
+ is_app: bool,
239
264
  f_stop: threading.Event,
265
+ run: Run,
266
+ flwr_dir: Optional[str] = None,
240
267
  client_app: Optional[ClientApp] = None,
241
268
  client_app_attr: Optional[str] = None,
242
269
  num_supernodes: Optional[int] = None,
@@ -287,17 +314,9 @@ def start_vce(
287
314
  )
288
315
 
289
316
  # Construct mapping of NodeStates
290
- node_states: Dict[int, NodeState] = {}
291
- # Number of unique partitions
292
- num_partitions = len(set(nodes_mapping.values()))
293
- for node_id, partition_id in nodes_mapping.items():
294
- node_states[node_id] = NodeState(
295
- node_id=node_id,
296
- node_config={
297
- PARTITION_ID_KEY: str(partition_id),
298
- NUM_PARTITIONS_KEY: str(num_partitions),
299
- },
300
- )
317
+ node_states = _register_node_states(
318
+ nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
319
+ )
301
320
 
302
321
  # Load backend config
303
322
  log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
@@ -326,16 +345,12 @@ def start_vce(
326
345
  def _load() -> ClientApp:
327
346
 
328
347
  if client_app_attr:
329
-
330
- if app_dir is not None:
331
- sys.path.insert(0, app_dir)
332
-
333
- app: ClientApp = load_app(client_app_attr, LoadClientAppError, app_dir)
334
-
335
- if not isinstance(app, ClientApp):
336
- raise LoadClientAppError(
337
- f"Attribute {client_app_attr} is not of type {ClientApp}",
338
- ) from None
348
+ app = _get_load_client_app_fn(
349
+ default_app_ref=client_app_attr,
350
+ dir_arg=app_dir,
351
+ flwr_dir_arg=flwr_dir,
352
+ multi_app=True,
353
+ )(run.fab_id, run.fab_version)
339
354
 
340
355
  if client_app:
341
356
  app = client_app
@@ -345,10 +360,19 @@ def start_vce(
345
360
 
346
361
  try:
347
362
  # Test if ClientApp can be loaded
348
- _ = app_fn()
363
+ client_app = app_fn()
364
+
365
+ # Cache `ClientApp`
366
+ if client_app_attr:
367
+ # Now wrap the loaded ClientApp in a dummy function
368
+ # this prevent unnecesary low-level loading of ClientApp
369
+ def _load_client_app() -> ClientApp:
370
+ return client_app
371
+
372
+ app_fn = _load_client_app
349
373
 
350
374
  # Run main simulation loop
351
- run(
375
+ run_api(
352
376
  app_fn,
353
377
  backend_fn,
354
378
  nodes_mapping,