flwr-nightly 1.12.0.dev20241009__py3-none-any.whl → 1.12.0.dev20241011__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +60 -29
- flwr/cli/config_utils.py +10 -0
- flwr/cli/install.py +60 -20
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -5
- flwr/client/app.py +13 -3
- flwr/client/clientapp/app.py +3 -1
- flwr/client/clientapp/utils.py +11 -5
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +4 -1
- flwr/client/node_state.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/common/config.py +19 -5
- flwr/common/logger.py +1 -1
- flwr/common/message.py +6 -1
- flwr/common/record/configsrecord.py +6 -0
- flwr/common/recordset_compat.py +10 -0
- flwr/common/retry_invoker.py +15 -0
- flwr/server/app.py +1 -0
- flwr/server/client_manager.py +2 -0
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +1 -1
- flwr/server/driver/inmemory_driver.py +2 -2
- flwr/server/run_serverapp.py +11 -13
- flwr/server/server_app.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +2 -2
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -3
- flwr/server/superlink/fleet/vce/vce_api.py +9 -6
- flwr/server/superlink/state/in_memory_state.py +1 -8
- flwr/server/superlink/state/sqlite_state.py +6 -11
- flwr/server/superlink/state/state.py +1 -7
- flwr/server/superlink/state/utils.py +0 -10
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +4 -4
- {flwr_nightly-1.12.0.dev20241009.dist-info → flwr_nightly-1.12.0.dev20241011.dist-info}/METADATA +1 -1
- {flwr_nightly-1.12.0.dev20241009.dist-info → flwr_nightly-1.12.0.dev20241011.dist-info}/RECORD +51 -51
- {flwr_nightly-1.12.0.dev20241009.dist-info → flwr_nightly-1.12.0.dev20241011.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20241009.dist-info → flwr_nightly-1.12.0.dev20241011.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20241009.dist-info → flwr_nightly-1.12.0.dev20241011.dist-info}/entry_points.txt +0 -0
flwr/common/logger.py
CHANGED
|
@@ -111,7 +111,7 @@ FLOWER_LOGGER.addHandler(console_handler)
|
|
|
111
111
|
class CustomHTTPHandler(HTTPHandler):
|
|
112
112
|
"""Custom HTTPHandler which overrides the mapLogRecords method."""
|
|
113
113
|
|
|
114
|
-
# pylint: disable=too-many-arguments,bad-option-value,R1725
|
|
114
|
+
# pylint: disable=too-many-arguments,bad-option-value,R1725,R0917
|
|
115
115
|
def __init__(
|
|
116
116
|
self,
|
|
117
117
|
identifier: str,
|
flwr/common/message.py
CHANGED
|
@@ -52,7 +52,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
52
52
|
the receiving end.
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
55
|
+
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
56
56
|
self,
|
|
57
57
|
run_id: int,
|
|
58
58
|
message_id: str,
|
|
@@ -290,6 +290,11 @@ class Message:
|
|
|
290
290
|
follows the equation:
|
|
291
291
|
|
|
292
292
|
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
|
|
293
|
+
|
|
294
|
+
Returns
|
|
295
|
+
-------
|
|
296
|
+
message : Message
|
|
297
|
+
A Message containing only the relevant error and metadata.
|
|
293
298
|
"""
|
|
294
299
|
# If no TTL passed, use default for message creation (will update after
|
|
295
300
|
# message creation)
|
|
@@ -128,6 +128,7 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
|
128
128
|
|
|
129
129
|
def get_var_bytes(value: ConfigsScalar) -> int:
|
|
130
130
|
"""Return Bytes of value passed."""
|
|
131
|
+
var_bytes = 0
|
|
131
132
|
if isinstance(value, bool):
|
|
132
133
|
var_bytes = 1
|
|
133
134
|
elif isinstance(value, (int, float)):
|
|
@@ -136,6 +137,11 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
|
136
137
|
)
|
|
137
138
|
if isinstance(value, (str, bytes)):
|
|
138
139
|
var_bytes = len(value)
|
|
140
|
+
if var_bytes == 0:
|
|
141
|
+
raise ValueError(
|
|
142
|
+
"Config values must be either `bool`, `int`, `float`, "
|
|
143
|
+
"`str`, or `bytes`"
|
|
144
|
+
)
|
|
139
145
|
return var_bytes
|
|
140
146
|
|
|
141
147
|
num_bytes = 0
|
flwr/common/recordset_compat.py
CHANGED
|
@@ -59,6 +59,11 @@ def parametersrecord_to_parameters(
|
|
|
59
59
|
keep_input : bool
|
|
60
60
|
A boolean indicating whether entries in the record should be deleted from the
|
|
61
61
|
input dictionary immediately after adding them to the record.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
parameters : Parameters
|
|
66
|
+
The parameters in the legacy format Parameters.
|
|
62
67
|
"""
|
|
63
68
|
parameters = Parameters(tensors=[], tensor_type="")
|
|
64
69
|
|
|
@@ -94,6 +99,11 @@ def parameters_to_parametersrecord(
|
|
|
94
99
|
A boolean indicating whether parameters should be deleted from the input
|
|
95
100
|
Parameters object (i.e. a list of serialized NumPy arrays) immediately after
|
|
96
101
|
adding them to the record.
|
|
102
|
+
|
|
103
|
+
Returns
|
|
104
|
+
-------
|
|
105
|
+
ParametersRecord
|
|
106
|
+
The ParametersRecord containing the provided parameters.
|
|
97
107
|
"""
|
|
98
108
|
tensor_type = parameters.tensor_type
|
|
99
109
|
|
flwr/common/retry_invoker.py
CHANGED
|
@@ -38,6 +38,11 @@ def exponential(
|
|
|
38
38
|
Factor by which the delay is multiplied after each retry.
|
|
39
39
|
max_delay: Optional[float] (default: None)
|
|
40
40
|
The maximum delay duration between two consecutive retries.
|
|
41
|
+
|
|
42
|
+
Returns
|
|
43
|
+
-------
|
|
44
|
+
Generator[float, None, None]
|
|
45
|
+
A generator for the delay between 2 retries.
|
|
41
46
|
"""
|
|
42
47
|
delay = base_delay if max_delay is None else min(base_delay, max_delay)
|
|
43
48
|
while True:
|
|
@@ -56,6 +61,11 @@ def constant(
|
|
|
56
61
|
----------
|
|
57
62
|
interval: Union[float, Iterable[float]] (default: 1)
|
|
58
63
|
A constant value to yield or an iterable of such values.
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
Generator[float, None, None]
|
|
68
|
+
A generator for the delay between 2 retries.
|
|
59
69
|
"""
|
|
60
70
|
if not isinstance(interval, Iterable):
|
|
61
71
|
interval = itertools.repeat(interval)
|
|
@@ -73,6 +83,11 @@ def full_jitter(max_value: float) -> float:
|
|
|
73
83
|
----------
|
|
74
84
|
max_value : float
|
|
75
85
|
The upper limit for the randomized value.
|
|
86
|
+
|
|
87
|
+
Returns
|
|
88
|
+
-------
|
|
89
|
+
float
|
|
90
|
+
A random float that is less than max_value.
|
|
76
91
|
"""
|
|
77
92
|
return random.uniform(0, max_value)
|
|
78
93
|
|
flwr/server/app.py
CHANGED
flwr/server/client_manager.py
CHANGED
|
@@ -47,6 +47,7 @@ class ClientManager(ABC):
|
|
|
47
47
|
Parameters
|
|
48
48
|
----------
|
|
49
49
|
client : flwr.server.client_proxy.ClientProxy
|
|
50
|
+
The ClientProxy of the Client to register.
|
|
50
51
|
|
|
51
52
|
Returns
|
|
52
53
|
-------
|
|
@@ -64,6 +65,7 @@ class ClientManager(ABC):
|
|
|
64
65
|
Parameters
|
|
65
66
|
----------
|
|
66
67
|
client : flwr.server.client_proxy.ClientProxy
|
|
68
|
+
The ClientProxy of the Client to unregister.
|
|
67
69
|
"""
|
|
68
70
|
|
|
69
71
|
@abstractmethod
|
flwr/server/driver/driver.py
CHANGED
|
@@ -158,7 +158,7 @@ class GrpcDriver(Driver):
|
|
|
158
158
|
):
|
|
159
159
|
raise ValueError(f"Invalid message: {message}")
|
|
160
160
|
|
|
161
|
-
def create_message( # pylint: disable=too-many-arguments
|
|
161
|
+
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
162
162
|
self,
|
|
163
163
|
content: RecordSet,
|
|
164
164
|
message_type: str,
|
|
@@ -82,7 +82,7 @@ class InMemoryDriver(Driver):
|
|
|
82
82
|
self._init_run()
|
|
83
83
|
return Run(**vars(cast(Run, self._run)))
|
|
84
84
|
|
|
85
|
-
def create_message( # pylint: disable=too-many-arguments
|
|
85
|
+
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
86
86
|
self,
|
|
87
87
|
content: RecordSet,
|
|
88
88
|
message_type: str,
|
|
@@ -150,7 +150,7 @@ class InMemoryDriver(Driver):
|
|
|
150
150
|
"""
|
|
151
151
|
msg_ids = {UUID(msg_id) for msg_id in message_ids}
|
|
152
152
|
# Pull TaskRes
|
|
153
|
-
task_res_list = self.state.get_task_res(task_ids=msg_ids
|
|
153
|
+
task_res_list = self.state.get_task_res(task_ids=msg_ids)
|
|
154
154
|
# Delete tasks in state
|
|
155
155
|
self.state.delete_tasks(msg_ids)
|
|
156
156
|
# Convert TaskRes to Message
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -181,19 +181,17 @@ def run_server_app() -> None:
|
|
|
181
181
|
)
|
|
182
182
|
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
183
183
|
run_ = driver.run
|
|
184
|
-
if run_.fab_hash:
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
app_path = str(get_project_dir(fab_id, fab_version, flwr_dir))
|
|
184
|
+
if not run_.fab_hash:
|
|
185
|
+
raise ValueError("FAB hash not provided.")
|
|
186
|
+
fab_req = GetFabRequest(hash_str=run_.fab_hash)
|
|
187
|
+
# pylint: disable-next=W0212
|
|
188
|
+
fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
|
|
189
|
+
if fab_res.fab.hash_str != run_.fab_hash:
|
|
190
|
+
raise ValueError("FAB hashes don't match.")
|
|
191
|
+
install_from_fab(fab_res.fab.content, flwr_dir, True)
|
|
192
|
+
fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
|
|
193
|
+
|
|
194
|
+
app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
|
|
197
195
|
config = get_project_config(app_path)
|
|
198
196
|
else:
|
|
199
197
|
# User provided `app_dir`, but not `--run-id`
|
flwr/server/server_app.py
CHANGED
|
@@ -88,7 +88,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
88
88
|
>>> )
|
|
89
89
|
"""
|
|
90
90
|
|
|
91
|
-
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
|
91
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
|
|
92
92
|
def __init__(
|
|
93
93
|
self,
|
|
94
94
|
strategy: Strategy,
|
|
@@ -307,7 +307,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
307
307
|
>>> )
|
|
308
308
|
"""
|
|
309
309
|
|
|
310
|
-
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
|
310
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
|
|
311
311
|
def __init__(
|
|
312
312
|
self,
|
|
313
313
|
strategy: Strategy,
|
|
@@ -39,7 +39,7 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
|
|
|
39
39
|
This class is deprecated and will be removed in a future release.
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
|
-
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
|
42
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
|
|
43
43
|
def __init__(
|
|
44
44
|
self,
|
|
45
45
|
strategy: Strategy,
|
|
@@ -36,7 +36,7 @@ class DPFedAvgFixed(Strategy):
|
|
|
36
36
|
This class is deprecated and will be removed in a future release.
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
|
-
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
|
39
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
|
|
40
40
|
def __init__(
|
|
41
41
|
self,
|
|
42
42
|
strategy: Strategy,
|
|
@@ -155,7 +155,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
155
155
|
context.add_callback(on_rpc_done)
|
|
156
156
|
|
|
157
157
|
# Read from state
|
|
158
|
-
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids
|
|
158
|
+
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
|
|
159
159
|
|
|
160
160
|
context.set_code(grpc.StatusCode.OK)
|
|
161
161
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
@@ -60,7 +60,7 @@ def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool:
|
|
|
60
60
|
return is_valid
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def start_grpc_server( # pylint: disable=too-many-arguments
|
|
63
|
+
def start_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
64
64
|
client_manager: ClientManager,
|
|
65
65
|
server_address: str,
|
|
66
66
|
max_concurrent_workers: int = 1000,
|
|
@@ -156,7 +156,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments
|
|
|
156
156
|
return server
|
|
157
157
|
|
|
158
158
|
|
|
159
|
-
def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
159
|
+
def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
160
160
|
servicer_and_add_fn: Union[
|
|
161
161
|
tuple[FleetServicer, AddServicerToServerFn],
|
|
162
162
|
tuple[GrpcAdapterServicer, AddServicerToServerFn],
|
|
@@ -174,7 +174,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
174
174
|
|
|
175
175
|
Parameters
|
|
176
176
|
----------
|
|
177
|
-
servicer_and_add_fn :
|
|
177
|
+
servicer_and_add_fn : tuple
|
|
178
178
|
A tuple holding a servicer implementation and a matching
|
|
179
179
|
add_Servicer_to_server function.
|
|
180
180
|
server_address : str
|
|
@@ -214,6 +214,8 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
214
214
|
* CA certificate.
|
|
215
215
|
* server certificate.
|
|
216
216
|
* server private key.
|
|
217
|
+
interceptors : Optional[Sequence[grpc.ServerInterceptor]] (default: None)
|
|
218
|
+
A list of gRPC interceptors.
|
|
217
219
|
|
|
218
220
|
Returns
|
|
219
221
|
-------
|
|
@@ -172,6 +172,7 @@ def put_taskres_into_state(
|
|
|
172
172
|
pass
|
|
173
173
|
|
|
174
174
|
|
|
175
|
+
# pylint: disable=too-many-positional-arguments
|
|
175
176
|
def run_api(
|
|
176
177
|
app_fn: Callable[[], ClientApp],
|
|
177
178
|
backend_fn: Callable[[], Backend],
|
|
@@ -251,7 +252,7 @@ def run_api(
|
|
|
251
252
|
|
|
252
253
|
|
|
253
254
|
# pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
|
|
254
|
-
# pylint: disable=too-many-statements
|
|
255
|
+
# pylint: disable=too-many-statements,too-many-positional-arguments
|
|
255
256
|
def start_vce(
|
|
256
257
|
backend_name: str,
|
|
257
258
|
backend_config_json_stream: str,
|
|
@@ -267,6 +268,8 @@ def start_vce(
|
|
|
267
268
|
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
|
|
268
269
|
) -> None:
|
|
269
270
|
"""Start Fleet API with the Simulation Engine."""
|
|
271
|
+
nodes_mapping = {}
|
|
272
|
+
|
|
270
273
|
if client_app_attr is not None and client_app is not None:
|
|
271
274
|
raise ValueError(
|
|
272
275
|
"Both `client_app_attr` and `client_app` are provided, "
|
|
@@ -340,17 +343,17 @@ def start_vce(
|
|
|
340
343
|
# Load ClientApp if needed
|
|
341
344
|
def _load() -> ClientApp:
|
|
342
345
|
|
|
346
|
+
if client_app:
|
|
347
|
+
return client_app
|
|
343
348
|
if client_app_attr:
|
|
344
|
-
|
|
349
|
+
return get_load_client_app_fn(
|
|
345
350
|
default_app_ref=client_app_attr,
|
|
346
351
|
app_path=app_dir,
|
|
347
352
|
flwr_dir=flwr_dir,
|
|
348
353
|
multi_app=False,
|
|
349
|
-
)(run.fab_id, run.fab_version)
|
|
354
|
+
)(run.fab_id, run.fab_version, run.fab_hash)
|
|
350
355
|
|
|
351
|
-
|
|
352
|
-
app = client_app
|
|
353
|
-
return app
|
|
356
|
+
raise ValueError("Either `client_app_attr` or `client_app` must be provided")
|
|
354
357
|
|
|
355
358
|
app_fn = _load
|
|
356
359
|
|
|
@@ -190,11 +190,8 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
190
190
|
# Return the new task_id
|
|
191
191
|
return task_id
|
|
192
192
|
|
|
193
|
-
def get_task_res(self, task_ids: set[UUID]
|
|
193
|
+
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
194
194
|
"""Get all TaskRes that have not been delivered yet."""
|
|
195
|
-
if limit is not None and limit < 1:
|
|
196
|
-
raise AssertionError("`limit` must be >= 1")
|
|
197
|
-
|
|
198
195
|
with self.lock:
|
|
199
196
|
# Find TaskRes that were not delivered yet
|
|
200
197
|
task_res_list: list[TaskRes] = []
|
|
@@ -217,13 +214,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
217
214
|
if reply_to in task_ids and task_res.task.delivered_at == "":
|
|
218
215
|
task_res_list.append(task_res)
|
|
219
216
|
replied_task_ids.add(reply_to)
|
|
220
|
-
if limit and len(task_res_list) == limit:
|
|
221
|
-
break
|
|
222
217
|
|
|
223
218
|
# Check if the node is offline
|
|
224
219
|
for task_id in task_ids - replied_task_ids:
|
|
225
|
-
if limit and len(task_res_list) == limit:
|
|
226
|
-
break
|
|
227
220
|
task_ins = self.task_ins_store.get(task_id)
|
|
228
221
|
if task_ins is None:
|
|
229
222
|
continue
|
|
@@ -151,6 +151,11 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
151
151
|
----------
|
|
152
152
|
log_queries : bool
|
|
153
153
|
Log each query which is executed.
|
|
154
|
+
|
|
155
|
+
Returns
|
|
156
|
+
-------
|
|
157
|
+
list[tuple[str]]
|
|
158
|
+
The list of all tables in the DB.
|
|
154
159
|
"""
|
|
155
160
|
self.conn = sqlite3.connect(self.database_path)
|
|
156
161
|
self.conn.execute("PRAGMA foreign_keys = ON;")
|
|
@@ -444,7 +449,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
444
449
|
return task_id
|
|
445
450
|
|
|
446
451
|
# pylint: disable-next=R0912,R0915,R0914
|
|
447
|
-
def get_task_res(self, task_ids: set[UUID]
|
|
452
|
+
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
448
453
|
"""Get TaskRes for task_ids.
|
|
449
454
|
|
|
450
455
|
Usually, the Driver API calls this method to get results for instructions it has
|
|
@@ -459,9 +464,6 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
459
464
|
will only take effect if enough task_ids are in the set AND are currently
|
|
460
465
|
available. If `limit` is set, it has to be greater than zero.
|
|
461
466
|
"""
|
|
462
|
-
if limit is not None and limit < 1:
|
|
463
|
-
raise AssertionError("`limit` must be >= 1")
|
|
464
|
-
|
|
465
467
|
# Check if corresponding TaskIns exists and is not expired
|
|
466
468
|
task_ids_placeholders = ",".join([f":id_{i}" for i in range(len(task_ids))])
|
|
467
469
|
query = f"""
|
|
@@ -505,10 +507,6 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
505
507
|
|
|
506
508
|
data: dict[str, Union[str, float, int]] = {}
|
|
507
509
|
|
|
508
|
-
if limit is not None:
|
|
509
|
-
query += " LIMIT :limit"
|
|
510
|
-
data["limit"] = limit
|
|
511
|
-
|
|
512
510
|
query += ";"
|
|
513
511
|
|
|
514
512
|
for index, task_id in enumerate(task_ids):
|
|
@@ -583,9 +581,6 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
583
581
|
|
|
584
582
|
# Make TaskRes containing node unavailabe error
|
|
585
583
|
for row in task_ins_rows:
|
|
586
|
-
if limit and len(result) == limit:
|
|
587
|
-
break
|
|
588
|
-
|
|
589
584
|
for row in rows:
|
|
590
585
|
# Convert values from sint64 to uint64
|
|
591
586
|
convert_sint64_values_in_dict_to_uint64(
|
|
@@ -98,7 +98,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
98
98
|
"""
|
|
99
99
|
|
|
100
100
|
@abc.abstractmethod
|
|
101
|
-
def get_task_res(self, task_ids: set[UUID]
|
|
101
|
+
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
102
102
|
"""Get TaskRes for task_ids.
|
|
103
103
|
|
|
104
104
|
Usually, the Driver API calls this method to get results for instructions it has
|
|
@@ -106,12 +106,6 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
106
106
|
|
|
107
107
|
Retrieves all TaskRes for the given `task_ids` and returns and empty list of
|
|
108
108
|
none could be found.
|
|
109
|
-
|
|
110
|
-
Constraints
|
|
111
|
-
-----------
|
|
112
|
-
If `limit` is not `None`, return, at most, `limit` number of TaskRes. The limit
|
|
113
|
-
will only take effect if enough task_ids are in the set AND are currently
|
|
114
|
-
available. If `limit` is set, it has to be greater zero.
|
|
115
109
|
"""
|
|
116
110
|
|
|
117
111
|
@abc.abstractmethod
|
|
@@ -100,11 +100,6 @@ def convert_uint64_values_in_dict_to_sint64(
|
|
|
100
100
|
A dictionary where the values are integers to be converted.
|
|
101
101
|
keys : list[str]
|
|
102
102
|
A list of keys in the dictionary whose values need to be converted.
|
|
103
|
-
|
|
104
|
-
Returns
|
|
105
|
-
-------
|
|
106
|
-
None
|
|
107
|
-
This function does not return a value. It modifies `data_dict` in place.
|
|
108
103
|
"""
|
|
109
104
|
for key in keys:
|
|
110
105
|
if key in data_dict:
|
|
@@ -122,11 +117,6 @@ def convert_sint64_values_in_dict_to_uint64(
|
|
|
122
117
|
A dictionary where the values are integers to be converted.
|
|
123
118
|
keys : list[str]
|
|
124
119
|
A list of keys in the dictionary whose values need to be converted.
|
|
125
|
-
|
|
126
|
-
Returns
|
|
127
|
-
-------
|
|
128
|
-
None
|
|
129
|
-
This function does not return a value. It modifies `data_dict` in place.
|
|
130
120
|
"""
|
|
131
121
|
for key in keys:
|
|
132
122
|
if key in data_dict:
|
|
@@ -48,7 +48,7 @@ from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
|
|
|
48
48
|
class RayActorClientProxy(ClientProxy):
|
|
49
49
|
"""Flower client proxy which delegates work using Ray."""
|
|
50
50
|
|
|
51
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
51
|
+
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
52
52
|
self,
|
|
53
53
|
client_fn: ClientFnExt,
|
|
54
54
|
node_id: int,
|
|
@@ -225,7 +225,7 @@ def run_simulation_from_cli() -> None:
|
|
|
225
225
|
|
|
226
226
|
|
|
227
227
|
# Entry point from Python session (script or notebook)
|
|
228
|
-
# pylint: disable=too-many-arguments
|
|
228
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
229
229
|
def run_simulation(
|
|
230
230
|
server_app: ServerApp,
|
|
231
231
|
client_app: ClientApp,
|
|
@@ -300,7 +300,7 @@ def run_simulation(
|
|
|
300
300
|
)
|
|
301
301
|
|
|
302
302
|
|
|
303
|
-
# pylint: disable=too-many-arguments
|
|
303
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
304
304
|
def run_serverapp_th(
|
|
305
305
|
server_app_attr: Optional[str],
|
|
306
306
|
server_app: Optional[ServerApp],
|
|
@@ -369,7 +369,7 @@ def run_serverapp_th(
|
|
|
369
369
|
return serverapp_th
|
|
370
370
|
|
|
371
371
|
|
|
372
|
-
# pylint: disable=too-many-locals
|
|
372
|
+
# pylint: disable=too-many-locals,too-many-positional-arguments
|
|
373
373
|
def _main_loop(
|
|
374
374
|
num_supernodes: int,
|
|
375
375
|
backend_name: str,
|
|
@@ -455,7 +455,7 @@ def _main_loop(
|
|
|
455
455
|
log(DEBUG, "Stopping Simulation Engine now.")
|
|
456
456
|
|
|
457
457
|
|
|
458
|
-
# pylint: disable=too-many-arguments,too-many-locals
|
|
458
|
+
# pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
|
|
459
459
|
def _run_simulation(
|
|
460
460
|
num_supernodes: int,
|
|
461
461
|
exit_event: EventType,
|