flwr-nightly 1.13.0.dev20241106__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (45)
  1. flwr/cli/run/run.py +16 -5
  2. flwr/client/app.py +10 -6
  3. flwr/client/clientapp/app.py +21 -16
  4. flwr/client/nodestate/__init__.py +25 -0
  5. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  6. flwr/client/nodestate/nodestate.py +30 -0
  7. flwr/client/nodestate/nodestate_factory.py +37 -0
  8. flwr/common/args.py +83 -0
  9. flwr/common/config.py +10 -0
  10. flwr/common/constant.py +0 -1
  11. flwr/common/logger.py +6 -2
  12. flwr/common/object_ref.py +47 -16
  13. flwr/common/typing.py +1 -1
  14. flwr/proto/exec_pb2.py +14 -17
  15. flwr/proto/exec_pb2.pyi +6 -20
  16. flwr/proto/run_pb2.py +32 -27
  17. flwr/proto/run_pb2.pyi +26 -0
  18. flwr/proto/simulationio_pb2.py +2 -2
  19. flwr/proto/simulationio_pb2_grpc.py +34 -0
  20. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  21. flwr/server/app.py +45 -20
  22. flwr/server/driver/driver.py +1 -1
  23. flwr/server/driver/grpc_driver.py +2 -6
  24. flwr/server/driver/inmemory_driver.py +1 -3
  25. flwr/server/run_serverapp.py +2 -2
  26. flwr/server/serverapp/app.py +16 -72
  27. flwr/server/strategy/aggregate.py +4 -4
  28. flwr/server/superlink/linkstate/in_memory_linkstate.py +5 -16
  29. flwr/server/superlink/linkstate/linkstate.py +5 -4
  30. flwr/server/superlink/linkstate/sqlite_linkstate.py +6 -15
  31. flwr/server/superlink/linkstate/utils.py +2 -33
  32. flwr/server/superlink/simulation/simulationio_servicer.py +22 -1
  33. flwr/simulation/__init__.py +3 -1
  34. flwr/simulation/app.py +273 -345
  35. flwr/simulation/legacy_app.py +382 -0
  36. flwr/simulation/run_simulation.py +1 -1
  37. flwr/superexec/deployment.py +1 -1
  38. flwr/superexec/exec_servicer.py +2 -2
  39. flwr/superexec/executor.py +4 -3
  40. flwr/superexec/simulation.py +44 -102
  41. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
  42. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +45 -39
  43. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +1 -0
  44. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
  45. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0
flwr/cli/run/run.py CHANGED
@@ -29,10 +29,18 @@ from flwr.cli.config_utils import (
29
29
  validate_federation_in_project_config,
30
30
  validate_project_config,
31
31
  )
32
- from flwr.common.config import flatten_dict, parse_config_args
32
+ from flwr.common.config import (
33
+ flatten_dict,
34
+ parse_config_args,
35
+ user_config_to_configsrecord,
36
+ )
33
37
  from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
34
38
  from flwr.common.logger import log
35
- from flwr.common.serde import fab_to_proto, user_config_to_proto
39
+ from flwr.common.serde import (
40
+ configs_record_to_proto,
41
+ fab_to_proto,
42
+ user_config_to_proto,
43
+ )
36
44
  from flwr.common.typing import Fab
37
45
  from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
38
46
  from flwr.proto.exec_pb2_grpc import ExecStub
@@ -94,6 +102,7 @@ def run(
94
102
  _run_without_exec_api(app, federation_config, config_overrides, federation)
95
103
 
96
104
 
105
+ # pylint: disable-next=too-many-locals
97
106
  def _run_with_exec_api(
98
107
  app: Path,
99
108
  federation_config: dict[str, Any],
@@ -118,12 +127,14 @@ def _run_with_exec_api(
118
127
  content = Path(fab_path).read_bytes()
119
128
  fab = Fab(fab_hash, content)
120
129
 
130
+ # Construct a `ConfigsRecord` out of a flattened `UserConfig`
131
+ fed_conf = flatten_dict(federation_config.get("options", {}))
132
+ c_record = user_config_to_configsrecord(fed_conf)
133
+
121
134
  req = StartRunRequest(
122
135
  fab=fab_to_proto(fab),
123
136
  override_config=user_config_to_proto(parse_config_args(config_overrides)),
124
- federation_config=user_config_to_proto(
125
- flatten_dict(federation_config.get("options"))
126
- ),
137
+ federation_options=configs_record_to_proto(c_record),
127
138
  )
128
139
  res = stub.StartRun(req)
129
140
 
flwr/client/app.py CHANGED
@@ -32,6 +32,7 @@ from flwr.cli.config_utils import get_fab_metadata
32
32
  from flwr.cli.install import install_from_fab
33
33
  from flwr.client.client import Client
34
34
  from flwr.client.client_app import ClientApp, LoadClientAppError
35
+ from flwr.client.nodestate.nodestate_factory import NodeStateFactory
35
36
  from flwr.client.typing import ClientFnExt
36
37
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
37
38
  from flwr.common.address import parse_address
@@ -365,6 +366,8 @@ def start_client_internal(
365
366
 
366
367
  # DeprecatedRunInfoStore gets initialized when the first connection is established
367
368
  run_info_store: Optional[DeprecatedRunInfoStore] = None
369
+ state_factory = NodeStateFactory()
370
+ state = state_factory.state()
368
371
 
369
372
  runs: dict[int, Run] = {}
370
373
 
@@ -396,13 +399,14 @@ def start_client_internal(
396
399
  )
397
400
  else:
398
401
  # Call create_node fn to register node
399
- node_id: Optional[int] = ( # pylint: disable=assignment-from-none
400
- create_node()
401
- ) # pylint: disable=not-callable
402
- if node_id is None:
403
- raise ValueError("Node registration failed")
402
+ # and store node_id in state
403
+ if (node_id := create_node()) is None:
404
+ raise ValueError(
405
+ "Failed to register SuperNode with the SuperLink"
406
+ )
407
+ state.set_node_id(node_id)
404
408
  run_info_store = DeprecatedRunInfoStore(
405
- node_id=node_id,
409
+ node_id=state.get_node_id(),
406
410
  node_config=node_config,
407
411
  )
408
412
 
flwr/client/clientapp/app.py CHANGED
@@ -24,6 +24,8 @@ import grpc
24
24
  from flwr.cli.install import install_from_fab
25
25
  from flwr.client.client_app import ClientApp, LoadClientAppError
26
26
  from flwr.common import Context, Message
27
+ from flwr.common.args import add_args_flwr_app_common
28
+ from flwr.common.config import get_flwr_dir
27
29
  from flwr.common.constant import ErrorCode
28
30
  from flwr.common.grpc import create_channel
29
31
  from flwr.common.logger import log
@@ -60,7 +62,7 @@ def flwr_clientapp() -> None:
60
62
  parser.add_argument(
61
63
  "--supernode",
62
64
  type=str,
63
- help="Address of SuperNode ClientAppIo gRPC servicer",
65
+ help="Address of SuperNode's ClientAppIo API",
64
66
  )
65
67
  parser.add_argument(
66
68
  "--token",
@@ -68,17 +70,24 @@ def flwr_clientapp() -> None:
68
70
  required=False,
69
71
  help="Unique token generated by SuperNode for each ClientApp execution",
70
72
  )
73
+ add_args_flwr_app_common(parser=parser)
71
74
  args = parser.parse_args()
72
75
 
73
76
  log(INFO, "Starting Flower ClientApp")
77
+
74
78
  log(
75
79
  DEBUG,
76
- "Staring isolated `ClientApp` connected to SuperNode ClientAppIo at %s "
80
+ "Starting isolated `ClientApp` connected to SuperNode's ClientAppIo API at %s "
77
81
  "with token %s",
78
82
  args.supernode,
79
83
  args.token,
80
84
  )
81
- run_clientapp(supernode=args.supernode, token=args.token)
85
+ run_clientapp(
86
+ supernode=args.supernode,
87
+ run_once=(args.token is not None),
88
+ token=args.token,
89
+ flwr_dir=args.flwr_dir,
90
+ )
82
91
 
83
92
 
84
93
  def on_channel_state_change(channel_connectivity: str) -> None:
@@ -88,27 +97,23 @@ def on_channel_state_change(channel_connectivity: str) -> None:
88
97
 
89
98
  def run_clientapp( # pylint: disable=R0914
90
99
  supernode: str,
100
+ run_once: bool,
91
101
  token: Optional[int] = None,
102
+ flwr_dir: Optional[str] = None,
92
103
  ) -> None:
93
- """Run Flower ClientApp process.
94
-
95
- Parameters
96
- ----------
97
- supernode : str
98
- Address of SuperNode
99
- token : Optional[int] (default: None)
100
- Unique SuperNode token for ClientApp-SuperNode authentication
101
- """
104
+ """Run Flower ClientApp process."""
102
105
  channel = create_channel(
103
106
  server_address=supernode,
104
107
  insecure=True,
105
108
  )
106
109
  channel.subscribe(on_channel_state_change)
107
110
 
111
+ # Resolve directory where FABs are installed
112
+ flwr_dir_ = get_flwr_dir(flwr_dir)
113
+
108
114
  try:
109
115
  stub = ClientAppIoStub(channel)
110
116
 
111
- only_once = token is not None
112
117
  while True:
113
118
  # If token is not set, loop until token is received from SuperNode
114
119
  while token is None:
@@ -121,13 +126,13 @@ def run_clientapp( # pylint: disable=R0914
121
126
  # Install FAB, if provided
122
127
  if fab:
123
128
  log(DEBUG, "Flower ClientApp starts FAB installation.")
124
- install_from_fab(fab.content, flwr_dir=None, skip_prompt=True)
129
+ install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
125
130
 
126
131
  load_client_app_fn = get_load_client_app_fn(
127
132
  default_app_ref="",
128
133
  app_path=None,
129
134
  multi_app=True,
130
- flwr_dir=None,
135
+ flwr_dir=str(flwr_dir_),
131
136
  )
132
137
 
133
138
  try:
@@ -169,7 +174,7 @@ def run_clientapp( # pylint: disable=R0914
169
174
 
170
175
  # Stop the loop if `flwr-clientapp` is expected to process only a single
171
176
  # message
172
- if only_once:
177
+ if run_once:
173
178
  break
174
179
 
175
180
  except KeyboardInterrupt:
flwr/client/nodestate/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower NodeState."""
16
+
17
+ from .in_memory_nodestate import InMemoryNodeState as InMemoryNodeState
18
+ from .nodestate import NodeState as NodeState
19
+ from .nodestate_factory import NodeStateFactory as NodeStateFactory
20
+
21
+ __all__ = [
22
+ "InMemoryNodeState",
23
+ "NodeState",
24
+ "NodeStateFactory",
25
+ ]
flwr/client/nodestate/in_memory_nodestate.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """In-memory NodeState implementation."""
16
+
17
+
18
+ from typing import Optional
19
+
20
+ from flwr.client.nodestate.nodestate import NodeState
21
+
22
+
23
class InMemoryNodeState(NodeState):
    """NodeState implementation that keeps its data in process memory."""

    def __init__(self) -> None:
        # No node ID is known until `set_node_id` is called
        self.node_id: Optional[int] = None

    def set_node_id(self, node_id: Optional[int]) -> None:
        """Store `node_id`, replacing any previously stored value."""
        self.node_id = node_id

    def get_node_id(self) -> int:
        """Return the stored node ID.

        Raises
        ------
        ValueError
            If no node ID has been stored (or None was stored).
        """
        current = self.node_id
        if current is not None:
            return current
        raise ValueError("Node ID not set")
flwr/client/nodestate/nodestate.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Abstract base class NodeState."""
16
+
17
+ import abc
18
+ from typing import Optional
19
+
20
+
21
class NodeState(abc.ABC):
    """Abstract interface for the state kept by a Flower node.

    Subclasses must implement both node-ID accessors below.
    """

    @abc.abstractmethod
    def set_node_id(self, node_id: Optional[int]) -> None:
        """Record `node_id` as this node's ID."""

    @abc.abstractmethod
    def get_node_id(self) -> int:
        """Return the node ID previously recorded via `set_node_id`."""
flwr/client/nodestate/nodestate_factory.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Factory class that creates NodeState instances."""
16
+
17
+ import threading
18
+ from typing import Optional
19
+
20
+ from .in_memory_nodestate import InMemoryNodeState
21
+ from .nodestate import NodeState
22
+
23
+
24
class NodeStateFactory:
    """Lazily create one shared NodeState and hand it out on request."""

    def __init__(self) -> None:
        # Created on the first call to `state()` and reused afterwards
        self.state_instance: Optional[NodeState] = None
        # Re-entrant lock guarding the lazy creation of `state_instance`
        self.lock = threading.RLock()

    def state(self) -> NodeState:
        """Return the shared NodeState, creating an InMemoryNodeState on first use."""
        with self.lock:
            # Checking and creating under the lock keeps every caller on the
            # same instance
            if self.state_instance is not None:
                return self.state_instance
            self.state_instance = InMemoryNodeState()
            return self.state_instance
flwr/common/args.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Common Flower arguments."""
16
+
17
+ import argparse
18
+ import sys
19
+ from logging import DEBUG, WARN
20
+ from os.path import isfile
21
+ from pathlib import Path
22
+ from typing import Optional
23
+
24
+ from flwr.common.logger import log
25
+
26
+
27
def add_args_flwr_app_common(parser: argparse.ArgumentParser) -> None:
    """Add common Flower arguments for flwr-*app to the provided parser.

    Registers three options on `parser` (in place):

    - ``--flwr-dir``: directory containing installed Flower Apps
      (default None; the help text lists how the effective directory is
      derived when it is not given).
    - ``--insecure``: boolean flag requesting a plain HTTP channel.
    - ``--root-certificates``: path to a PEM-encoded root certificate file.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        The parser to extend.
    """
    parser.add_argument(
        "--flwr-dir",
        default=None,
        # argparse's default formatter collapses whitespace, so the
        # indentation inside this literal does not affect the rendered help
        help="""The path containing installed Flower Apps.
    By default, this value is equal to:

    - `$FLWR_HOME/` if `$FLWR_HOME` is defined
    - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
    - `$HOME/.flwr/` in all other cases
    """,
    )
    parser.add_argument(
        "--insecure",
        action="store_true",
        # `try_obtain_certificates` rejects `--insecure` together with
        # `--root-certificates`, so the help must not claim they can be combined
        help="Run without HTTPS, using an insecure HTTP channel. This flag "
        "cannot be combined with `--root-certificates`. By default, HTTPS is "
        "enabled. Use this flag only if you understand the risks.",
    )
    parser.add_argument(
        "--root-certificates",
        metavar="ROOT_CERT",
        type=str,
        help="Specifies the path to the PEM-encoded root certificate file for "
        "establishing secure HTTPS connections.",
    )
54
+
55
+
56
def try_obtain_certificates(
    args: argparse.Namespace,
) -> Optional[bytes]:
    """Validate the TLS-related arguments and return the root certificates.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments providing ``insecure`` (bool) and
        ``root_certificates`` (path string or None), e.g. as registered by
        `add_args_flwr_app_common`.

    Returns
    -------
    Optional[bytes]
        The contents of the ``--root-certificates`` file. Returns None when
        ``--insecure`` is set, or when no certificate path was given; in the
        latter case None is intended to signal that the system's default
        root certificates should be used.

    Raises
    ------
    SystemExit
        If ``--insecure`` and ``--root-certificates`` are both set, or if the
        given certificate path does not point to a file.
    """
    root_cert_path: Optional[str] = args.root_certificates

    if args.insecure:
        if root_cert_path is not None:
            sys.exit(
                "Conflicting options: The '--insecure' flag disables HTTPS, "
                "but '--root-certificates' was also specified. Please remove "
                "the '--root-certificates' option when running in insecure mode, "
                "or omit '--insecure' to use HTTPS."
            )
        log(
            WARN,
            "Option `--insecure` was set. Starting insecure HTTP channel.",
        )
        return None

    # Secure mode without an explicit certificate file: fall back to the
    # system certificates. Guarding here also avoids `isfile(None)`, which
    # raises TypeError instead of returning False.
    if root_cert_path is None:
        return None

    if not isfile(root_cert_path):
        sys.exit("Path argument `--root-certificates` does not point to a file.")
    root_certificates = Path(root_cert_path).read_bytes()
    log(
        DEBUG,
        "Starting secure HTTPS channel with the following certificates: %s.",
        root_cert_path,
    )
    return root_certificates
flwr/common/config.py CHANGED
@@ -22,6 +22,7 @@ from typing import Any, Optional, Union, cast, get_args
22
22
  import tomli
23
23
 
24
24
  from flwr.cli.config_utils import get_fab_config, validate_fields
25
+ from flwr.common import ConfigsRecord
25
26
  from flwr.common.constant import (
26
27
  APP_DIR,
27
28
  FAB_CONFIG_FILE,
@@ -229,3 +230,12 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
229
230
  config["project"]["version"],
230
231
  f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}",
231
232
  )
233
+
234
+
235
def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord:
    """Return a new `ConfigsRecord` holding every key/value pair of `config`.

    Entries are copied one by one, in the iteration order of `config`.
    """
    record = ConfigsRecord()
    for key in config:
        record[key] = config[key]
    return record
flwr/common/constant.py CHANGED
@@ -134,7 +134,6 @@ class ErrorCode:
134
134
  UNKNOWN = 0
135
135
  LOAD_CLIENT_APP_EXCEPTION = 1
136
136
  CLIENT_APP_RAISED_EXCEPTION = 2
137
- NODE_UNAVAILABLE = 3
138
137
 
139
138
  def __new__(cls) -> ErrorCode:
140
139
  """Prevent instantiation."""
flwr/common/logger.py CHANGED
@@ -22,13 +22,14 @@ import time
22
22
  from logging import WARN, LogRecord
23
23
  from logging.handlers import HTTPHandler
24
24
  from queue import Empty, Queue
25
- from typing import TYPE_CHECKING, Any, Optional, TextIO
25
+ from typing import TYPE_CHECKING, Any, Optional, TextIO, Union
26
26
 
27
27
  import grpc
28
28
 
29
29
  from flwr.proto.log_pb2 import PushLogsRequest # pylint: disable=E0611
30
30
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
31
31
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
32
+ from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
32
33
 
33
34
  from .constant import LOG_UPLOAD_INTERVAL
34
35
 
@@ -346,7 +347,10 @@ def _log_uploader(
346
347
 
347
348
 
348
349
  def start_log_uploader(
349
- log_queue: Queue[Optional[str]], node_id: int, run_id: int, stub: ServerAppIoStub
350
+ log_queue: Queue[Optional[str]],
351
+ node_id: int,
352
+ run_id: int,
353
+ stub: Union[ServerAppIoStub, SimulationIoStub],
350
354
  ) -> threading.Thread:
351
355
  """Start the log uploader thread and return it."""
352
356
  thread = threading.Thread(
flwr/common/object_ref.py CHANGED
@@ -55,8 +55,8 @@ def validate(
55
55
  specified attribute within it.
56
56
  project_dir : Optional[Union[str, Path]] (default: None)
57
57
  The directory containing the module. If None, the current working directory
58
- is used. If `check_module` is True, the `project_dir` will be inserted into
59
- the system path, and the previously inserted `project_dir` will be removed.
58
+ is used. If `check_module` is True, the `project_dir` will be temporarily
59
+ inserted into the system path and then removed after the validation is complete.
60
60
 
61
61
  Returns
62
62
  -------
@@ -66,8 +66,8 @@ def validate(
66
66
 
67
67
  Note
68
68
  ----
69
- This function will modify `sys.path` by inserting the provided `project_dir`
70
- and removing the previously inserted `project_dir`.
69
+ This function will temporarily modify `sys.path` by inserting the provided
70
+ `project_dir`, which will be removed after the validation is complete.
71
71
  """
72
72
  module_str, _, attributes_str = module_attribute_str.partition(":")
73
73
  if not module_str:
@@ -82,11 +82,19 @@ def validate(
82
82
  )
83
83
 
84
84
  if check_module:
85
+ if project_dir is None:
86
+ project_dir = Path.cwd()
87
+ project_dir = Path(project_dir).absolute()
85
88
  # Set the system path
86
- _set_sys_path(project_dir)
89
+ sys.path.insert(0, str(project_dir))
87
90
 
88
91
  # Load module
89
92
  module = find_spec(module_str)
93
+
94
+ # Unset the system path
95
+ sys.path.remove(str(project_dir))
96
+
97
+ # Check if the module and the attribute exist
90
98
  if module and module.origin:
91
99
  if not _find_attribute_in_module(module.origin, attributes_str):
92
100
  return (
@@ -133,8 +141,10 @@ def load_app( # pylint: disable= too-many-branches
133
141
 
134
142
  Note
135
143
  ----
136
- This function will modify `sys.path` by inserting the provided `project_dir`
137
- and removing the previously inserted `project_dir`.
144
+ - This function will unload all modules in the previously provided `project_dir`,
145
+ if it is invoked again.
146
+ - This function will modify `sys.path` by inserting the provided `project_dir`
147
+ and removing the previously inserted `project_dir`.
138
148
  """
139
149
  valid, error_msg = validate(module_attribute_str, check_module=False)
140
150
  if not valid and error_msg:
@@ -143,8 +153,19 @@ def load_app( # pylint: disable= too-many-branches
143
153
  module_str, _, attributes_str = module_attribute_str.partition(":")
144
154
 
145
155
  try:
156
+ # Initialize project path
157
+ if project_dir is None:
158
+ project_dir = Path.cwd()
159
+ project_dir = Path(project_dir).absolute()
160
+
161
+ # Unload modules if the project directory has changed
162
+ if _current_sys_path and _current_sys_path != str(project_dir):
163
+ _unload_modules(Path(_current_sys_path))
164
+
165
+ # Set the system path
146
166
  _set_sys_path(project_dir)
147
167
 
168
+ # Import the module
148
169
  if module_str not in sys.modules:
149
170
  module = importlib.import_module(module_str)
150
171
  # Hack: `tabnet` does not work with `importlib.reload`
@@ -160,15 +181,7 @@ def load_app( # pylint: disable= too-many-branches
160
181
  module = sys.modules[module_str]
161
182
  else:
162
183
  module = sys.modules[module_str]
163
-
164
- if project_dir is None:
165
- project_dir = Path.cwd()
166
-
167
- # Reload cached modules in the project directory
168
- for m in list(sys.modules.values()):
169
- path: Optional[str] = getattr(m, "__file__", None)
170
- if path is not None and path.startswith(str(project_dir)):
171
- importlib.reload(m)
184
+ _reload_modules(project_dir)
172
185
 
173
186
  except ModuleNotFoundError as err:
174
187
  raise error_type(
@@ -189,6 +202,24 @@ def load_app( # pylint: disable= too-many-branches
189
202
  return attribute
190
203
 
191
204
 
205
def _is_project_file(path: Optional[str], root: Path) -> bool:
    """Return True if `path` lies inside the (absolute) directory `root`."""
    # Compare path components instead of raw string prefixes: with a plain
    # `startswith`, `/a/app2/mod.py` would wrongly be treated as inside `/a/app`
    return isinstance(path, str) and Path(path).is_relative_to(root)


def _unload_modules(project_dir: Path) -> None:
    """Unload modules from the project directory.

    Every entry of `sys.modules` whose `__file__` is located inside
    `project_dir` is removed, so that a later import re-executes it.
    """
    root = project_dir.absolute()
    for name, m in list(sys.modules.items()):
        if _is_project_file(getattr(m, "__file__", None), root):
            del sys.modules[name]


def _reload_modules(project_dir: Path) -> None:
    """Reload modules from the project directory.

    Calls `importlib.reload` on every loaded module whose `__file__` is
    located inside `project_dir`. A snapshot of `sys.modules` is iterated
    because reloading may import further modules.
    """
    root = project_dir.absolute()
    for m in list(sys.modules.values()):
        if _is_project_file(getattr(m, "__file__", None), root):
            importlib.reload(m)
221
+
222
+
192
223
  def _set_sys_path(directory: Optional[Union[str, Path]]) -> None:
193
224
  """Set the system path."""
194
225
  if directory is None:
flwr/common/typing.py CHANGED
@@ -24,7 +24,7 @@ import numpy.typing as npt
24
24
 
25
25
  NDArray = npt.NDArray[Any]
26
26
  NDArrayInt = npt.NDArray[np.int_]
27
- NDArrayFloat = npt.NDArray[np.float_]
27
+ NDArrayFloat = npt.NDArray[np.float64]
28
28
  NDArrays = list[NDArray]
29
29
 
30
30
  # The following union type contains Python types corresponding to ProtoBuf types that
flwr/proto/exec_pb2.py CHANGED
@@ -14,9 +14,10 @@ _sym_db = _symbol_database.Default()
14
14
 
15
15
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
16
16
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
17
+ from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
18
 
18
19
 
19
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xdf\x02\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12L\n\x11\x66\x65\x64\x65ration_config\x18\x03 \x03(\x0b\x32\x31.flwr.proto.StartRunRequest.FederationConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1aK\n\x15\x46\x65\x64\x65rationConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\x32\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3')
20
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\x32\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3')
20
21
 
21
22
  _globals = globals()
22
23
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -25,20 +26,16 @@ if _descriptor._USE_C_DESCRIPTORS == False:
25
26
  DESCRIPTOR._options = None
26
27
  _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None
27
28
  _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001'
28
- _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._options = None
29
- _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_options = b'8\001'
30
- _globals['_STARTRUNREQUEST']._serialized_start=88
31
- _globals['_STARTRUNREQUEST']._serialized_end=439
32
- _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=289
33
- _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=362
34
- _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_start=364
35
- _globals['_STARTRUNREQUEST_FEDERATIONCONFIGENTRY']._serialized_end=439
36
- _globals['_STARTRUNRESPONSE']._serialized_start=441
37
- _globals['_STARTRUNRESPONSE']._serialized_end=475
38
- _globals['_STREAMLOGSREQUEST']._serialized_start=477
39
- _globals['_STREAMLOGSREQUEST']._serialized_end=537
40
- _globals['_STREAMLOGSRESPONSE']._serialized_start=539
41
- _globals['_STREAMLOGSRESPONSE']._serialized_end=605
42
- _globals['_EXEC']._serialized_start=608
43
- _globals['_EXEC']._serialized_end=768
29
+ _globals['_STARTRUNREQUEST']._serialized_start=116
30
+ _globals['_STARTRUNREQUEST']._serialized_end=367
31
+ _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=294
32
+ _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=367
33
+ _globals['_STARTRUNRESPONSE']._serialized_start=369
34
+ _globals['_STARTRUNRESPONSE']._serialized_end=403
35
+ _globals['_STREAMLOGSREQUEST']._serialized_start=405
36
+ _globals['_STREAMLOGSREQUEST']._serialized_end=465
37
+ _globals['_STREAMLOGSRESPONSE']._serialized_start=467
38
+ _globals['_STREAMLOGSRESPONSE']._serialized_end=533
39
+ _globals['_EXEC']._serialized_start=536
40
+ _globals['_EXEC']._serialized_end=696
44
41
  # @@protoc_insertion_point(module_scope)