flwr-nightly 1.12.0.dev20241008__py3-none-any.whl → 1.13.0.dev20241015__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +60 -29
- flwr/cli/config_utils.py +10 -0
- flwr/cli/install.py +60 -20
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -5
- flwr/client/app.py +13 -3
- flwr/client/clientapp/app.py +3 -1
- flwr/client/clientapp/utils.py +11 -5
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +4 -1
- flwr/client/node_state.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +12 -5
- flwr/common/config.py +19 -5
- flwr/common/logger.py +1 -1
- flwr/common/message.py +6 -1
- flwr/common/record/configsrecord.py +6 -0
- flwr/common/recordset_compat.py +10 -0
- flwr/common/retry_invoker.py +15 -0
- flwr/server/app.py +3 -2
- flwr/server/client_manager.py +2 -0
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +1 -1
- flwr/server/driver/inmemory_driver.py +2 -2
- flwr/server/run_serverapp.py +11 -13
- flwr/server/server_app.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +2 -2
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -3
- flwr/server/superlink/fleet/vce/vce_api.py +9 -6
- flwr/server/superlink/state/in_memory_state.py +26 -8
- flwr/server/superlink/state/sqlite_state.py +46 -11
- flwr/server/superlink/state/state.py +1 -7
- flwr/server/superlink/state/utils.py +0 -10
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +4 -4
- {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/METADATA +1 -1
- {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/RECORD +52 -52
- {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20241008.dist-info → flwr_nightly-1.13.0.dev20241015.dist-info}/entry_points.txt +0 -0
flwr/common/config.py
CHANGED
|
@@ -22,7 +22,12 @@ from typing import Any, Optional, Union, cast, get_args
|
|
|
22
22
|
import tomli
|
|
23
23
|
|
|
24
24
|
from flwr.cli.config_utils import get_fab_config, validate_fields
|
|
25
|
-
from flwr.common.constant import
|
|
25
|
+
from flwr.common.constant import (
|
|
26
|
+
APP_DIR,
|
|
27
|
+
FAB_CONFIG_FILE,
|
|
28
|
+
FAB_HASH_TRUNCATION,
|
|
29
|
+
FLWR_HOME,
|
|
30
|
+
)
|
|
26
31
|
from flwr.common.typing import Run, UserConfig, UserConfigValue
|
|
27
32
|
|
|
28
33
|
|
|
@@ -39,7 +44,10 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
|
|
39
44
|
|
|
40
45
|
|
|
41
46
|
def get_project_dir(
|
|
42
|
-
fab_id: str,
|
|
47
|
+
fab_id: str,
|
|
48
|
+
fab_version: str,
|
|
49
|
+
fab_hash: str,
|
|
50
|
+
flwr_dir: Optional[Union[str, Path]] = None,
|
|
43
51
|
) -> Path:
|
|
44
52
|
"""Return the project directory based on the given fab_id and fab_version."""
|
|
45
53
|
# Check the fab_id
|
|
@@ -50,7 +58,11 @@ def get_project_dir(
|
|
|
50
58
|
publisher, project_name = fab_id.split("/")
|
|
51
59
|
if flwr_dir is None:
|
|
52
60
|
flwr_dir = get_flwr_dir()
|
|
53
|
-
return
|
|
61
|
+
return (
|
|
62
|
+
Path(flwr_dir)
|
|
63
|
+
/ APP_DIR
|
|
64
|
+
/ f"{publisher}.{project_name}.{fab_version}.{fab_hash[:FAB_HASH_TRUNCATION]}"
|
|
65
|
+
)
|
|
54
66
|
|
|
55
67
|
|
|
56
68
|
def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]:
|
|
@@ -127,7 +139,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
|
|
|
127
139
|
if not run.fab_id or not run.fab_version:
|
|
128
140
|
return {}
|
|
129
141
|
|
|
130
|
-
project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
|
|
142
|
+
project_dir = get_project_dir(run.fab_id, run.fab_version, run.fab_hash, flwr_dir)
|
|
131
143
|
|
|
132
144
|
# Return empty dict if project directory does not exist
|
|
133
145
|
if not project_dir.is_dir():
|
|
@@ -194,6 +206,7 @@ def parse_config_args(
|
|
|
194
206
|
# Regular expression to capture key-value pairs with possible quoted values
|
|
195
207
|
pattern = re.compile(r"(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)")
|
|
196
208
|
|
|
209
|
+
flat_overrides = {}
|
|
197
210
|
for config_line in config:
|
|
198
211
|
if config_line:
|
|
199
212
|
# .toml files aren't allowed alongside other configs
|
|
@@ -205,8 +218,9 @@ def parse_config_args(
|
|
|
205
218
|
matches = pattern.findall(config_line)
|
|
206
219
|
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
|
|
207
220
|
overrides.update(tomli.loads(toml_str))
|
|
221
|
+
flat_overrides = flatten_dict(overrides)
|
|
208
222
|
|
|
209
|
-
return
|
|
223
|
+
return flat_overrides
|
|
210
224
|
|
|
211
225
|
|
|
212
226
|
def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
|
flwr/common/logger.py
CHANGED
|
@@ -111,7 +111,7 @@ FLOWER_LOGGER.addHandler(console_handler)
|
|
|
111
111
|
class CustomHTTPHandler(HTTPHandler):
|
|
112
112
|
"""Custom HTTPHandler which overrides the mapLogRecords method."""
|
|
113
113
|
|
|
114
|
-
# pylint: disable=too-many-arguments,bad-option-value,R1725
|
|
114
|
+
# pylint: disable=too-many-arguments,bad-option-value,R1725,R0917
|
|
115
115
|
def __init__(
|
|
116
116
|
self,
|
|
117
117
|
identifier: str,
|
flwr/common/message.py
CHANGED
|
@@ -52,7 +52,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
52
52
|
the receiving end.
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
55
|
+
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
56
56
|
self,
|
|
57
57
|
run_id: int,
|
|
58
58
|
message_id: str,
|
|
@@ -290,6 +290,11 @@ class Message:
|
|
|
290
290
|
follows the equation:
|
|
291
291
|
|
|
292
292
|
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
|
|
293
|
+
|
|
294
|
+
Returns
|
|
295
|
+
-------
|
|
296
|
+
message : Message
|
|
297
|
+
A Message containing only the relevant error and metadata.
|
|
293
298
|
"""
|
|
294
299
|
# If no TTL passed, use default for message creation (will update after
|
|
295
300
|
# message creation)
|
|
@@ -128,6 +128,7 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
|
128
128
|
|
|
129
129
|
def get_var_bytes(value: ConfigsScalar) -> int:
|
|
130
130
|
"""Return Bytes of value passed."""
|
|
131
|
+
var_bytes = 0
|
|
131
132
|
if isinstance(value, bool):
|
|
132
133
|
var_bytes = 1
|
|
133
134
|
elif isinstance(value, (int, float)):
|
|
@@ -136,6 +137,11 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
|
136
137
|
)
|
|
137
138
|
if isinstance(value, (str, bytes)):
|
|
138
139
|
var_bytes = len(value)
|
|
140
|
+
if var_bytes == 0:
|
|
141
|
+
raise ValueError(
|
|
142
|
+
"Config values must be either `bool`, `int`, `float`, "
|
|
143
|
+
"`str`, or `bytes`"
|
|
144
|
+
)
|
|
139
145
|
return var_bytes
|
|
140
146
|
|
|
141
147
|
num_bytes = 0
|
flwr/common/recordset_compat.py
CHANGED
|
@@ -59,6 +59,11 @@ def parametersrecord_to_parameters(
|
|
|
59
59
|
keep_input : bool
|
|
60
60
|
A boolean indicating whether entries in the record should be deleted from the
|
|
61
61
|
input dictionary immediately after adding them to the record.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
parameters : Parameters
|
|
66
|
+
The parameters in the legacy format Parameters.
|
|
62
67
|
"""
|
|
63
68
|
parameters = Parameters(tensors=[], tensor_type="")
|
|
64
69
|
|
|
@@ -94,6 +99,11 @@ def parameters_to_parametersrecord(
|
|
|
94
99
|
A boolean indicating whether parameters should be deleted from the input
|
|
95
100
|
Parameters object (i.e. a list of serialized NumPy arrays) immediately after
|
|
96
101
|
adding them to the record.
|
|
102
|
+
|
|
103
|
+
Returns
|
|
104
|
+
-------
|
|
105
|
+
ParametersRecord
|
|
106
|
+
The ParametersRecord containing the provided parameters.
|
|
97
107
|
"""
|
|
98
108
|
tensor_type = parameters.tensor_type
|
|
99
109
|
|
flwr/common/retry_invoker.py
CHANGED
|
@@ -38,6 +38,11 @@ def exponential(
|
|
|
38
38
|
Factor by which the delay is multiplied after each retry.
|
|
39
39
|
max_delay: Optional[float] (default: None)
|
|
40
40
|
The maximum delay duration between two consecutive retries.
|
|
41
|
+
|
|
42
|
+
Returns
|
|
43
|
+
-------
|
|
44
|
+
Generator[float, None, None]
|
|
45
|
+
A generator for the delay between 2 retries.
|
|
41
46
|
"""
|
|
42
47
|
delay = base_delay if max_delay is None else min(base_delay, max_delay)
|
|
43
48
|
while True:
|
|
@@ -56,6 +61,11 @@ def constant(
|
|
|
56
61
|
----------
|
|
57
62
|
interval: Union[float, Iterable[float]] (default: 1)
|
|
58
63
|
A constant value to yield or an iterable of such values.
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
Generator[float, None, None]
|
|
68
|
+
A generator for the delay between 2 retries.
|
|
59
69
|
"""
|
|
60
70
|
if not isinstance(interval, Iterable):
|
|
61
71
|
interval = itertools.repeat(interval)
|
|
@@ -73,6 +83,11 @@ def full_jitter(max_value: float) -> float:
|
|
|
73
83
|
----------
|
|
74
84
|
max_value : float
|
|
75
85
|
The upper limit for the randomized value.
|
|
86
|
+
|
|
87
|
+
Returns
|
|
88
|
+
-------
|
|
89
|
+
float
|
|
90
|
+
A random float that is less than max_value.
|
|
76
91
|
"""
|
|
77
92
|
return random.uniform(0, max_value)
|
|
78
93
|
|
flwr/server/app.py
CHANGED
|
@@ -199,12 +199,12 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
199
199
|
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
|
|
200
200
|
def run_superlink() -> None:
|
|
201
201
|
"""Run Flower SuperLink (Driver API and Fleet API)."""
|
|
202
|
+
args = _parse_args_run_superlink().parse_args()
|
|
203
|
+
|
|
202
204
|
log(INFO, "Starting Flower SuperLink")
|
|
203
205
|
|
|
204
206
|
event(EventType.RUN_SUPERLINK_ENTER)
|
|
205
207
|
|
|
206
|
-
args = _parse_args_run_superlink().parse_args()
|
|
207
|
-
|
|
208
208
|
# Parse IP address
|
|
209
209
|
driver_address, _, _ = _format_address(args.driver_api_address)
|
|
210
210
|
|
|
@@ -542,6 +542,7 @@ def _run_fleet_api_grpc_adapter(
|
|
|
542
542
|
|
|
543
543
|
|
|
544
544
|
# pylint: disable=import-outside-toplevel,too-many-arguments
|
|
545
|
+
# pylint: disable=too-many-positional-arguments
|
|
545
546
|
def _run_fleet_api_rest(
|
|
546
547
|
host: str,
|
|
547
548
|
port: int,
|
flwr/server/client_manager.py
CHANGED
|
@@ -47,6 +47,7 @@ class ClientManager(ABC):
|
|
|
47
47
|
Parameters
|
|
48
48
|
----------
|
|
49
49
|
client : flwr.server.client_proxy.ClientProxy
|
|
50
|
+
The ClientProxy of the Client to register.
|
|
50
51
|
|
|
51
52
|
Returns
|
|
52
53
|
-------
|
|
@@ -64,6 +65,7 @@ class ClientManager(ABC):
|
|
|
64
65
|
Parameters
|
|
65
66
|
----------
|
|
66
67
|
client : flwr.server.client_proxy.ClientProxy
|
|
68
|
+
The ClientProxy of the Client to unregister.
|
|
67
69
|
"""
|
|
68
70
|
|
|
69
71
|
@abstractmethod
|
flwr/server/driver/driver.py
CHANGED
|
@@ -158,7 +158,7 @@ class GrpcDriver(Driver):
|
|
|
158
158
|
):
|
|
159
159
|
raise ValueError(f"Invalid message: {message}")
|
|
160
160
|
|
|
161
|
-
def create_message( # pylint: disable=too-many-arguments
|
|
161
|
+
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
162
162
|
self,
|
|
163
163
|
content: RecordSet,
|
|
164
164
|
message_type: str,
|
|
@@ -82,7 +82,7 @@ class InMemoryDriver(Driver):
|
|
|
82
82
|
self._init_run()
|
|
83
83
|
return Run(**vars(cast(Run, self._run)))
|
|
84
84
|
|
|
85
|
-
def create_message( # pylint: disable=too-many-arguments
|
|
85
|
+
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
86
86
|
self,
|
|
87
87
|
content: RecordSet,
|
|
88
88
|
message_type: str,
|
|
@@ -150,7 +150,7 @@ class InMemoryDriver(Driver):
|
|
|
150
150
|
"""
|
|
151
151
|
msg_ids = {UUID(msg_id) for msg_id in message_ids}
|
|
152
152
|
# Pull TaskRes
|
|
153
|
-
task_res_list = self.state.get_task_res(task_ids=msg_ids
|
|
153
|
+
task_res_list = self.state.get_task_res(task_ids=msg_ids)
|
|
154
154
|
# Delete tasks in state
|
|
155
155
|
self.state.delete_tasks(msg_ids)
|
|
156
156
|
# Convert TaskRes to Message
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -181,19 +181,17 @@ def run_server_app() -> None:
|
|
|
181
181
|
)
|
|
182
182
|
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
183
183
|
run_ = driver.run
|
|
184
|
-
if run_.fab_hash:
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
app_path = str(get_project_dir(fab_id, fab_version, flwr_dir))
|
|
184
|
+
if not run_.fab_hash:
|
|
185
|
+
raise ValueError("FAB hash not provided.")
|
|
186
|
+
fab_req = GetFabRequest(hash_str=run_.fab_hash)
|
|
187
|
+
# pylint: disable-next=W0212
|
|
188
|
+
fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
|
|
189
|
+
if fab_res.fab.hash_str != run_.fab_hash:
|
|
190
|
+
raise ValueError("FAB hashes don't match.")
|
|
191
|
+
install_from_fab(fab_res.fab.content, flwr_dir, True)
|
|
192
|
+
fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
|
|
193
|
+
|
|
194
|
+
app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
|
|
197
195
|
config = get_project_config(app_path)
|
|
198
196
|
else:
|
|
199
197
|
# User provided `app_dir`, but not `--run-id`
|
flwr/server/server_app.py
CHANGED
|
@@ -88,7 +88,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
88
88
|
>>> )
|
|
89
89
|
"""
|
|
90
90
|
|
|
91
|
-
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
|
91
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
|
|
92
92
|
def __init__(
|
|
93
93
|
self,
|
|
94
94
|
strategy: Strategy,
|
|
@@ -307,7 +307,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
307
307
|
>>> )
|
|
308
308
|
"""
|
|
309
309
|
|
|
310
|
-
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
|
310
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
|
|
311
311
|
def __init__(
|
|
312
312
|
self,
|
|
313
313
|
strategy: Strategy,
|
|
@@ -39,7 +39,7 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
|
|
|
39
39
|
This class is deprecated and will be removed in a future release.
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
|
-
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
|
42
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
|
|
43
43
|
def __init__(
|
|
44
44
|
self,
|
|
45
45
|
strategy: Strategy,
|
|
@@ -36,7 +36,7 @@ class DPFedAvgFixed(Strategy):
|
|
|
36
36
|
This class is deprecated and will be removed in a future release.
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
|
-
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
|
39
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
|
|
40
40
|
def __init__(
|
|
41
41
|
self,
|
|
42
42
|
strategy: Strategy,
|
|
@@ -155,7 +155,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
155
155
|
context.add_callback(on_rpc_done)
|
|
156
156
|
|
|
157
157
|
# Read from state
|
|
158
|
-
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids
|
|
158
|
+
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
|
|
159
159
|
|
|
160
160
|
context.set_code(grpc.StatusCode.OK)
|
|
161
161
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
@@ -60,7 +60,7 @@ def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool:
|
|
|
60
60
|
return is_valid
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def start_grpc_server( # pylint: disable=too-many-arguments
|
|
63
|
+
def start_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
64
64
|
client_manager: ClientManager,
|
|
65
65
|
server_address: str,
|
|
66
66
|
max_concurrent_workers: int = 1000,
|
|
@@ -156,7 +156,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments
|
|
|
156
156
|
return server
|
|
157
157
|
|
|
158
158
|
|
|
159
|
-
def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
159
|
+
def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
160
160
|
servicer_and_add_fn: Union[
|
|
161
161
|
tuple[FleetServicer, AddServicerToServerFn],
|
|
162
162
|
tuple[GrpcAdapterServicer, AddServicerToServerFn],
|
|
@@ -174,7 +174,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
174
174
|
|
|
175
175
|
Parameters
|
|
176
176
|
----------
|
|
177
|
-
servicer_and_add_fn :
|
|
177
|
+
servicer_and_add_fn : tuple
|
|
178
178
|
A tuple holding a servicer implementation and a matching
|
|
179
179
|
add_Servicer_to_server function.
|
|
180
180
|
server_address : str
|
|
@@ -214,6 +214,8 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
214
214
|
* CA certificate.
|
|
215
215
|
* server certificate.
|
|
216
216
|
* server private key.
|
|
217
|
+
interceptors : Optional[Sequence[grpc.ServerInterceptor]] (default: None)
|
|
218
|
+
A list of gRPC interceptors.
|
|
217
219
|
|
|
218
220
|
Returns
|
|
219
221
|
-------
|
|
@@ -172,6 +172,7 @@ def put_taskres_into_state(
|
|
|
172
172
|
pass
|
|
173
173
|
|
|
174
174
|
|
|
175
|
+
# pylint: disable=too-many-positional-arguments
|
|
175
176
|
def run_api(
|
|
176
177
|
app_fn: Callable[[], ClientApp],
|
|
177
178
|
backend_fn: Callable[[], Backend],
|
|
@@ -251,7 +252,7 @@ def run_api(
|
|
|
251
252
|
|
|
252
253
|
|
|
253
254
|
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
|
|
254
|
-
# pylint: disable=too-many-statements
|
|
255
|
+
# pylint: disable=too-many-statements,too-many-positional-arguments
|
|
255
256
|
def start_vce(
|
|
256
257
|
backend_name: str,
|
|
257
258
|
backend_config_json_stream: str,
|
|
@@ -267,6 +268,8 @@ def start_vce(
|
|
|
267
268
|
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
|
|
268
269
|
) -> None:
|
|
269
270
|
"""Start Fleet API with the Simulation Engine."""
|
|
271
|
+
nodes_mapping = {}
|
|
272
|
+
|
|
270
273
|
if client_app_attr is not None and client_app is not None:
|
|
271
274
|
raise ValueError(
|
|
272
275
|
"Both `client_app_attr` and `client_app` are provided, "
|
|
@@ -340,17 +343,17 @@ def start_vce(
|
|
|
340
343
|
# Load ClientApp if needed
|
|
341
344
|
def _load() -> ClientApp:
|
|
342
345
|
|
|
346
|
+
if client_app:
|
|
347
|
+
return client_app
|
|
343
348
|
if client_app_attr:
|
|
344
|
-
|
|
349
|
+
return get_load_client_app_fn(
|
|
345
350
|
default_app_ref=client_app_attr,
|
|
346
351
|
app_path=app_dir,
|
|
347
352
|
flwr_dir=flwr_dir,
|
|
348
353
|
multi_app=False,
|
|
349
|
-
)(run.fab_id, run.fab_version)
|
|
354
|
+
)(run.fab_id, run.fab_version, run.fab_hash)
|
|
350
355
|
|
|
351
|
-
|
|
352
|
-
app = client_app
|
|
353
|
-
return app
|
|
356
|
+
raise ValueError("Either `client_app_attr` or `client_app` must be provided")
|
|
354
357
|
|
|
355
358
|
app_fn = _load
|
|
356
359
|
|
|
@@ -116,6 +116,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
116
116
|
# Return TaskIns
|
|
117
117
|
return task_ins_list
|
|
118
118
|
|
|
119
|
+
# pylint: disable=R0911
|
|
119
120
|
def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
|
|
120
121
|
"""Store one TaskRes."""
|
|
121
122
|
# Validate task
|
|
@@ -129,6 +130,17 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
129
130
|
task_ins_id = task_res.task.ancestry[0]
|
|
130
131
|
task_ins = self.task_ins_store.get(UUID(task_ins_id))
|
|
131
132
|
|
|
133
|
+
# Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
|
|
134
|
+
if (
|
|
135
|
+
task_ins
|
|
136
|
+
and task_res
|
|
137
|
+
and not (
|
|
138
|
+
task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
|
|
139
|
+
)
|
|
140
|
+
and task_ins.task.consumer.node_id != task_res.task.producer.node_id
|
|
141
|
+
):
|
|
142
|
+
return None
|
|
143
|
+
|
|
132
144
|
if task_ins is None:
|
|
133
145
|
log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id)
|
|
134
146
|
return None
|
|
@@ -178,27 +190,33 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
178
190
|
# Return the new task_id
|
|
179
191
|
return task_id
|
|
180
192
|
|
|
181
|
-
def get_task_res(self, task_ids: set[UUID]
|
|
193
|
+
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
182
194
|
"""Get all TaskRes that have not been delivered yet."""
|
|
183
|
-
if limit is not None and limit < 1:
|
|
184
|
-
raise AssertionError("`limit` must be >= 1")
|
|
185
|
-
|
|
186
195
|
with self.lock:
|
|
187
196
|
# Find TaskRes that were not delivered yet
|
|
188
197
|
task_res_list: list[TaskRes] = []
|
|
189
198
|
replied_task_ids: set[UUID] = set()
|
|
190
199
|
for _, task_res in self.task_res_store.items():
|
|
191
200
|
reply_to = UUID(task_res.task.ancestry[0])
|
|
201
|
+
|
|
202
|
+
# Check if corresponding TaskIns exists and is not expired
|
|
203
|
+
task_ins = self.task_ins_store.get(reply_to)
|
|
204
|
+
if task_ins is None:
|
|
205
|
+
log(WARNING, "TaskIns with task_id %s does not exist.", reply_to)
|
|
206
|
+
task_ids.remove(reply_to)
|
|
207
|
+
continue
|
|
208
|
+
|
|
209
|
+
if task_ins.task.created_at + task_ins.task.ttl <= time.time():
|
|
210
|
+
log(WARNING, "TaskIns with task_id %s is expired.", reply_to)
|
|
211
|
+
task_ids.remove(reply_to)
|
|
212
|
+
continue
|
|
213
|
+
|
|
192
214
|
if reply_to in task_ids and task_res.task.delivered_at == "":
|
|
193
215
|
task_res_list.append(task_res)
|
|
194
216
|
replied_task_ids.add(reply_to)
|
|
195
|
-
if limit and len(task_res_list) == limit:
|
|
196
|
-
break
|
|
197
217
|
|
|
198
218
|
# Check if the node is offline
|
|
199
219
|
for task_id in task_ids - replied_task_ids:
|
|
200
|
-
if limit and len(task_res_list) == limit:
|
|
201
|
-
break
|
|
202
220
|
task_ins = self.task_ins_store.get(task_id)
|
|
203
221
|
if task_ins is None:
|
|
204
222
|
continue
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SQLite based implemenation of server state."""
|
|
16
16
|
|
|
17
|
+
# pylint: disable=too-many-lines
|
|
17
18
|
|
|
18
19
|
import json
|
|
19
20
|
import re
|
|
@@ -150,6 +151,11 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
150
151
|
----------
|
|
151
152
|
log_queries : bool
|
|
152
153
|
Log each query which is executed.
|
|
154
|
+
|
|
155
|
+
Returns
|
|
156
|
+
-------
|
|
157
|
+
list[tuple[str]]
|
|
158
|
+
The list of all tables in the DB.
|
|
153
159
|
"""
|
|
154
160
|
self.conn = sqlite3.connect(self.database_path)
|
|
155
161
|
self.conn.execute("PRAGMA foreign_keys = ON;")
|
|
@@ -389,6 +395,16 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
389
395
|
)
|
|
390
396
|
return None
|
|
391
397
|
|
|
398
|
+
# Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
|
|
399
|
+
if (
|
|
400
|
+
task_ins
|
|
401
|
+
and task_res
|
|
402
|
+
and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
|
|
403
|
+
and convert_sint64_to_uint64(task_ins["consumer_node_id"])
|
|
404
|
+
!= task_res.task.producer.node_id
|
|
405
|
+
):
|
|
406
|
+
return None
|
|
407
|
+
|
|
392
408
|
# Fail if the TaskRes TTL exceeds the
|
|
393
409
|
# expiration time of the TaskIns it replies to.
|
|
394
410
|
# Condition: TaskIns.created_at + TaskIns.ttl ≥
|
|
@@ -432,8 +448,8 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
432
448
|
|
|
433
449
|
return task_id
|
|
434
450
|
|
|
435
|
-
# pylint: disable-next=R0914
|
|
436
|
-
def get_task_res(self, task_ids: set[UUID]
|
|
451
|
+
# pylint: disable-next=R0912,R0915,R0914
|
|
452
|
+
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
437
453
|
"""Get TaskRes for task_ids.
|
|
438
454
|
|
|
439
455
|
Usually, the Driver API calls this method to get results for instructions it has
|
|
@@ -448,8 +464,34 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
448
464
|
will only take effect if enough task_ids are in the set AND are currently
|
|
449
465
|
available. If `limit` is set, it has to be greater than zero.
|
|
450
466
|
"""
|
|
451
|
-
if
|
|
452
|
-
|
|
467
|
+
# Check if corresponding TaskIns exists and is not expired
|
|
468
|
+
task_ids_placeholders = ",".join([f":id_{i}" for i in range(len(task_ids))])
|
|
469
|
+
query = f"""
|
|
470
|
+
SELECT *
|
|
471
|
+
FROM task_ins
|
|
472
|
+
WHERE task_id IN ({task_ids_placeholders})
|
|
473
|
+
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
474
|
+
"""
|
|
475
|
+
query += ";"
|
|
476
|
+
|
|
477
|
+
task_ins_data = {}
|
|
478
|
+
for index, task_id in enumerate(task_ids):
|
|
479
|
+
task_ins_data[f"id_{index}"] = str(task_id)
|
|
480
|
+
|
|
481
|
+
task_ins_rows = self.query(query, task_ins_data)
|
|
482
|
+
|
|
483
|
+
if not task_ins_rows:
|
|
484
|
+
return []
|
|
485
|
+
|
|
486
|
+
for row in task_ins_rows:
|
|
487
|
+
# Convert values from sint64 to uint64
|
|
488
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
489
|
+
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
490
|
+
)
|
|
491
|
+
task_ins = dict_to_task_ins(row)
|
|
492
|
+
if task_ins.task.created_at + task_ins.task.ttl <= time.time():
|
|
493
|
+
log(WARNING, "TaskIns with task_id %s is expired.", task_ins.task_id)
|
|
494
|
+
task_ids.remove(UUID(task_ins.task_id))
|
|
453
495
|
|
|
454
496
|
# Retrieve all anonymous Tasks
|
|
455
497
|
if len(task_ids) == 0:
|
|
@@ -465,10 +507,6 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
465
507
|
|
|
466
508
|
data: dict[str, Union[str, float, int]] = {}
|
|
467
509
|
|
|
468
|
-
if limit is not None:
|
|
469
|
-
query += " LIMIT :limit"
|
|
470
|
-
data["limit"] = limit
|
|
471
|
-
|
|
472
510
|
query += ";"
|
|
473
511
|
|
|
474
512
|
for index, task_id in enumerate(task_ids):
|
|
@@ -543,9 +581,6 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
543
581
|
|
|
544
582
|
# Make TaskRes containing node unavailabe error
|
|
545
583
|
for row in task_ins_rows:
|
|
546
|
-
if limit and len(result) == limit:
|
|
547
|
-
break
|
|
548
|
-
|
|
549
584
|
for row in rows:
|
|
550
585
|
# Convert values from sint64 to uint64
|
|
551
586
|
convert_sint64_values_in_dict_to_uint64(
|
|
@@ -98,7 +98,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
98
98
|
"""
|
|
99
99
|
|
|
100
100
|
@abc.abstractmethod
|
|
101
|
-
def get_task_res(self, task_ids: set[UUID]
|
|
101
|
+
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
102
102
|
"""Get TaskRes for task_ids.
|
|
103
103
|
|
|
104
104
|
Usually, the Driver API calls this method to get results for instructions it has
|
|
@@ -106,12 +106,6 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
106
106
|
|
|
107
107
|
Retrieves all TaskRes for the given `task_ids` and returns and empty list of
|
|
108
108
|
none could be found.
|
|
109
|
-
|
|
110
|
-
Constraints
|
|
111
|
-
-----------
|
|
112
|
-
If `limit` is not `None`, return, at most, `limit` number of TaskRes. The limit
|
|
113
|
-
will only take effect if enough task_ids are in the set AND are currently
|
|
114
|
-
available. If `limit` is set, it has to be greater zero.
|
|
115
109
|
"""
|
|
116
110
|
|
|
117
111
|
@abc.abstractmethod
|
|
@@ -100,11 +100,6 @@ def convert_uint64_values_in_dict_to_sint64(
|
|
|
100
100
|
A dictionary where the values are integers to be converted.
|
|
101
101
|
keys : list[str]
|
|
102
102
|
A list of keys in the dictionary whose values need to be converted.
|
|
103
|
-
|
|
104
|
-
Returns
|
|
105
|
-
-------
|
|
106
|
-
None
|
|
107
|
-
This function does not return a value. It modifies `data_dict` in place.
|
|
108
103
|
"""
|
|
109
104
|
for key in keys:
|
|
110
105
|
if key in data_dict:
|
|
@@ -122,11 +117,6 @@ def convert_sint64_values_in_dict_to_uint64(
|
|
|
122
117
|
A dictionary where the values are integers to be converted.
|
|
123
118
|
keys : list[str]
|
|
124
119
|
A list of keys in the dictionary whose values need to be converted.
|
|
125
|
-
|
|
126
|
-
Returns
|
|
127
|
-
-------
|
|
128
|
-
None
|
|
129
|
-
This function does not return a value. It modifies `data_dict` in place.
|
|
130
120
|
"""
|
|
131
121
|
for key in keys:
|
|
132
122
|
if key in data_dict:
|
|
@@ -48,7 +48,7 @@ from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
|
|
|
48
48
|
class RayActorClientProxy(ClientProxy):
|
|
49
49
|
"""Flower client proxy which delegates work using Ray."""
|
|
50
50
|
|
|
51
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
51
|
+
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
52
52
|
self,
|
|
53
53
|
client_fn: ClientFnExt,
|
|
54
54
|
node_id: int,
|