flwr-nightly 1.12.0.dev20240906__py3-none-any.whl → 1.12.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +1 -2
- flwr/cli/config_utils.py +10 -10
- flwr/cli/install.py +1 -2
- flwr/cli/new/new.py +26 -40
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +6 -7
- flwr/cli/utils.py +2 -2
- flwr/client/app.py +14 -14
- flwr/client/client_app.py +5 -5
- flwr/client/clientapp/app.py +2 -2
- flwr/client/dpfedavg_numpy_client.py +6 -7
- flwr/client/grpc_adapter_client/connection.py +4 -3
- flwr/client/grpc_client/connection.py +4 -3
- flwr/client/grpc_rere_client/client_interceptor.py +5 -5
- flwr/client/grpc_rere_client/connection.py +5 -4
- flwr/client/grpc_rere_client/grpc_adapter.py +2 -2
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +25 -25
- flwr/client/mod/utils.py +1 -3
- flwr/client/node_state.py +2 -2
- flwr/client/numpy_client.py +8 -8
- flwr/client/rest_client/connection.py +5 -4
- flwr/client/supernode/app.py +7 -8
- flwr/common/address.py +2 -2
- flwr/common/config.py +8 -8
- flwr/common/constant.py +12 -1
- flwr/common/differential_privacy.py +2 -2
- flwr/common/dp.py +1 -3
- flwr/common/exit_handlers.py +3 -3
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +3 -3
- flwr/common/object_ref.py +3 -3
- flwr/common/record/configsrecord.py +3 -3
- flwr/common/record/metricsrecord.py +3 -3
- flwr/common/record/parametersrecord.py +3 -2
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +23 -10
- flwr/common/recordset_compat.py +7 -5
- flwr/common/retry_invoker.py +6 -17
- flwr/common/secure_aggregation/crypto/shamir.py +10 -10
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +2 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +16 -16
- flwr/common/secure_aggregation/quantization.py +7 -7
- flwr/common/secure_aggregation/secaggplus_utils.py +3 -5
- flwr/common/serde.py +11 -9
- flwr/common/telemetry.py +5 -5
- flwr/common/typing.py +19 -19
- flwr/common/version.py +2 -3
- flwr/server/app.py +18 -18
- flwr/server/client_manager.py +6 -6
- flwr/server/compat/app_utils.py +2 -3
- flwr/server/driver/driver.py +3 -2
- flwr/server/driver/grpc_driver.py +7 -7
- flwr/server/driver/inmemory_driver.py +5 -4
- flwr/server/history.py +8 -9
- flwr/server/run_serverapp.py +5 -6
- flwr/server/server.py +36 -36
- flwr/server/strategy/aggregate.py +13 -13
- flwr/server/strategy/bulyan.py +8 -8
- flwr/server/strategy/dp_adaptive_clipping.py +20 -20
- flwr/server/strategy/dp_fixed_clipping.py +19 -19
- flwr/server/strategy/dpfedavg_adaptive.py +6 -6
- flwr/server/strategy/dpfedavg_fixed.py +10 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
- flwr/server/strategy/fedadagrad.py +8 -8
- flwr/server/strategy/fedadam.py +8 -8
- flwr/server/strategy/fedavg.py +16 -16
- flwr/server/strategy/fedavg_android.py +16 -16
- flwr/server/strategy/fedavgm.py +8 -8
- flwr/server/strategy/fedmedian.py +4 -4
- flwr/server/strategy/fedopt.py +5 -5
- flwr/server/strategy/fedprox.py +6 -6
- flwr/server/strategy/fedtrimmedavg.py +8 -8
- flwr/server/strategy/fedxgb_bagging.py +11 -11
- flwr/server/strategy/fedxgb_cyclic.py +9 -9
- flwr/server/strategy/fedxgb_nn_avg.py +5 -5
- flwr/server/strategy/fedyogi.py +8 -8
- flwr/server/strategy/krum.py +8 -8
- flwr/server/strategy/qfedavg.py +15 -15
- flwr/server/strategy/strategy.py +10 -10
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +6 -6
- flwr/server/superlink/ffs/disk_ffs.py +4 -4
- flwr/server/superlink/ffs/ffs.py +4 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -2
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +9 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +5 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +2 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +2 -3
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +26 -17
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/state/in_memory_state.py +18 -18
- flwr/server/superlink/state/sqlite_state.py +22 -21
- flwr/server/superlink/state/state.py +7 -7
- flwr/server/utils/tensorboard.py +4 -4
- flwr/server/utils/validator.py +2 -2
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +22 -22
- flwr/simulation/app.py +8 -8
- flwr/simulation/ray_transport/ray_actor.py +23 -23
- flwr/simulation/run_simulation.py +16 -4
- flwr/superexec/app.py +4 -4
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/exec_grpc.py +2 -2
- flwr/superexec/exec_servicer.py +3 -2
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/METADATA +4 -6
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/RECORD +118 -118
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20240906.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/entry_points.txt +0 -0
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
import sys
|
|
18
18
|
from logging import DEBUG, ERROR
|
|
19
|
-
from typing import Callable,
|
|
19
|
+
from typing import Callable, Optional, Union
|
|
20
20
|
|
|
21
21
|
import ray
|
|
22
22
|
|
|
@@ -31,8 +31,8 @@ from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
|
|
31
31
|
|
|
32
32
|
from .backend import Backend, BackendConfig
|
|
33
33
|
|
|
34
|
-
ClientResourcesDict =
|
|
35
|
-
ActorArgsDict =
|
|
34
|
+
ClientResourcesDict = dict[str, Union[int, float]]
|
|
35
|
+
ActorArgsDict = dict[str, Union[int, float, Callable[[], None]]]
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
class RayBackend(Backend):
|
|
@@ -52,16 +52,11 @@ class RayBackend(Backend):
|
|
|
52
52
|
|
|
53
53
|
# Validate client resources
|
|
54
54
|
self.client_resources_key = "client_resources"
|
|
55
|
-
client_resources = self._validate_client_resources(config=backend_config)
|
|
55
|
+
self.client_resources = self._validate_client_resources(config=backend_config)
|
|
56
56
|
|
|
57
|
-
#
|
|
58
|
-
actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
|
59
|
-
|
|
60
|
-
self.pool = BasicActorPool(
|
|
61
|
-
actor_type=ClientAppActor,
|
|
62
|
-
client_resources=client_resources,
|
|
63
|
-
actor_kwargs=actor_kwargs,
|
|
64
|
-
)
|
|
57
|
+
# Validate actor resources
|
|
58
|
+
self.actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
|
59
|
+
self.pool: Optional[BasicActorPool] = None
|
|
65
60
|
|
|
66
61
|
self.app_fn: Optional[Callable[[], ClientApp]] = None
|
|
67
62
|
|
|
@@ -106,7 +101,7 @@ class RayBackend(Backend):
|
|
|
106
101
|
def init_ray(self, backend_config: BackendConfig) -> None:
|
|
107
102
|
"""Initialises Ray if not already initialised."""
|
|
108
103
|
if not ray.is_initialized():
|
|
109
|
-
ray_init_args:
|
|
104
|
+
ray_init_args: dict[
|
|
110
105
|
str,
|
|
111
106
|
ConfigsRecordValues,
|
|
112
107
|
] = {}
|
|
@@ -122,14 +117,24 @@ class RayBackend(Backend):
|
|
|
122
117
|
@property
|
|
123
118
|
def num_workers(self) -> int:
|
|
124
119
|
"""Return number of actors in pool."""
|
|
125
|
-
return self.pool.num_actors
|
|
120
|
+
return self.pool.num_actors if self.pool else 0
|
|
126
121
|
|
|
127
122
|
def is_worker_idle(self) -> bool:
|
|
128
123
|
"""Report whether the pool has idle actors."""
|
|
129
|
-
return self.pool.is_actor_available()
|
|
124
|
+
return self.pool.is_actor_available() if self.pool else False
|
|
130
125
|
|
|
131
126
|
def build(self, app_fn: Callable[[], ClientApp]) -> None:
|
|
132
127
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
|
128
|
+
# Create Actor Pool
|
|
129
|
+
try:
|
|
130
|
+
self.pool = BasicActorPool(
|
|
131
|
+
actor_type=ClientAppActor,
|
|
132
|
+
client_resources=self.client_resources,
|
|
133
|
+
actor_kwargs=self.actor_kwargs,
|
|
134
|
+
)
|
|
135
|
+
except Exception as ex:
|
|
136
|
+
raise ex
|
|
137
|
+
|
|
133
138
|
self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
|
134
139
|
# Set ClientApp callable that ray actors will use
|
|
135
140
|
self.app_fn = app_fn
|
|
@@ -139,13 +144,16 @@ class RayBackend(Backend):
|
|
|
139
144
|
self,
|
|
140
145
|
message: Message,
|
|
141
146
|
context: Context,
|
|
142
|
-
) ->
|
|
147
|
+
) -> tuple[Message, Context]:
|
|
143
148
|
"""Run ClientApp that processes a given message.
|
|
144
149
|
|
|
145
150
|
Return output message and updated context.
|
|
146
151
|
"""
|
|
147
152
|
partition_id = context.node_config[PARTITION_ID_KEY]
|
|
148
153
|
|
|
154
|
+
if self.pool is None:
|
|
155
|
+
raise ValueError("The actor pool is empty, unfit to process messages.")
|
|
156
|
+
|
|
149
157
|
if self.app_fn is None:
|
|
150
158
|
raise ValueError(
|
|
151
159
|
"Unspecified function to load a `ClientApp`. "
|
|
@@ -179,6 +187,7 @@ class RayBackend(Backend):
|
|
|
179
187
|
|
|
180
188
|
def terminate(self) -> None:
|
|
181
189
|
"""Terminate all actors in actor pool."""
|
|
182
|
-
self.pool
|
|
190
|
+
if self.pool:
|
|
191
|
+
self.pool.terminate_all_actors()
|
|
183
192
|
ray.shutdown()
|
|
184
193
|
log(DEBUG, "Terminated %s", self.__class__.__name__)
|
|
@@ -24,7 +24,7 @@ from logging import DEBUG, ERROR, INFO, WARN
|
|
|
24
24
|
from pathlib import Path
|
|
25
25
|
from queue import Empty, Queue
|
|
26
26
|
from time import sleep
|
|
27
|
-
from typing import Callable,
|
|
27
|
+
from typing import Callable, Optional
|
|
28
28
|
|
|
29
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
30
30
|
from flwr.client.clientapp.utils import get_load_client_app_fn
|
|
@@ -44,7 +44,7 @@ from flwr.server.superlink.state import State, StateFactory
|
|
|
44
44
|
|
|
45
45
|
from .backend import Backend, error_messages_backends, supported_backends
|
|
46
46
|
|
|
47
|
-
NodeToPartitionMapping =
|
|
47
|
+
NodeToPartitionMapping = dict[int, int]
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
def _register_nodes(
|
|
@@ -64,9 +64,9 @@ def _register_node_states(
|
|
|
64
64
|
nodes_mapping: NodeToPartitionMapping,
|
|
65
65
|
run: Run,
|
|
66
66
|
app_dir: Optional[str] = None,
|
|
67
|
-
) ->
|
|
67
|
+
) -> dict[int, NodeState]:
|
|
68
68
|
"""Create NodeState objects and pre-register the context for the run."""
|
|
69
|
-
node_states:
|
|
69
|
+
node_states: dict[int, NodeState] = {}
|
|
70
70
|
num_partitions = len(set(nodes_mapping.values()))
|
|
71
71
|
for node_id, partition_id in nodes_mapping.items():
|
|
72
72
|
node_states[node_id] = NodeState(
|
|
@@ -89,7 +89,7 @@ def _register_node_states(
|
|
|
89
89
|
def worker(
|
|
90
90
|
taskins_queue: "Queue[TaskIns]",
|
|
91
91
|
taskres_queue: "Queue[TaskRes]",
|
|
92
|
-
node_states:
|
|
92
|
+
node_states: dict[int, NodeState],
|
|
93
93
|
backend: Backend,
|
|
94
94
|
f_stop: threading.Event,
|
|
95
95
|
) -> None:
|
|
@@ -177,7 +177,7 @@ def run_api(
|
|
|
177
177
|
backend_fn: Callable[[], Backend],
|
|
178
178
|
nodes_mapping: NodeToPartitionMapping,
|
|
179
179
|
state_factory: StateFactory,
|
|
180
|
-
node_states:
|
|
180
|
+
node_states: dict[int, NodeState],
|
|
181
181
|
f_stop: threading.Event,
|
|
182
182
|
) -> None:
|
|
183
183
|
"""Run the VCE."""
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import threading
|
|
19
19
|
import time
|
|
20
20
|
from logging import ERROR
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import Optional
|
|
22
22
|
from uuid import UUID, uuid4
|
|
23
23
|
|
|
24
24
|
from flwr.common import log, now
|
|
@@ -37,15 +37,15 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
37
37
|
def __init__(self) -> None:
|
|
38
38
|
|
|
39
39
|
# Map node_id to (online_until, ping_interval)
|
|
40
|
-
self.node_ids:
|
|
41
|
-
self.public_key_to_node_id:
|
|
40
|
+
self.node_ids: dict[int, tuple[float, float]] = {}
|
|
41
|
+
self.public_key_to_node_id: dict[bytes, int] = {}
|
|
42
42
|
|
|
43
43
|
# Map run_id to (fab_id, fab_version)
|
|
44
|
-
self.run_ids:
|
|
45
|
-
self.task_ins_store:
|
|
46
|
-
self.task_res_store:
|
|
44
|
+
self.run_ids: dict[int, Run] = {}
|
|
45
|
+
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
46
|
+
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
47
47
|
|
|
48
|
-
self.node_public_keys:
|
|
48
|
+
self.node_public_keys: set[bytes] = set()
|
|
49
49
|
self.server_public_key: Optional[bytes] = None
|
|
50
50
|
self.server_private_key: Optional[bytes] = None
|
|
51
51
|
|
|
@@ -76,13 +76,13 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
76
76
|
|
|
77
77
|
def get_task_ins(
|
|
78
78
|
self, node_id: Optional[int], limit: Optional[int]
|
|
79
|
-
) ->
|
|
79
|
+
) -> list[TaskIns]:
|
|
80
80
|
"""Get all TaskIns that have not been delivered yet."""
|
|
81
81
|
if limit is not None and limit < 1:
|
|
82
82
|
raise AssertionError("`limit` must be >= 1")
|
|
83
83
|
|
|
84
84
|
# Find TaskIns for node_id that were not delivered yet
|
|
85
|
-
task_ins_list:
|
|
85
|
+
task_ins_list: list[TaskIns] = []
|
|
86
86
|
with self.lock:
|
|
87
87
|
for _, task_ins in self.task_ins_store.items():
|
|
88
88
|
# pylint: disable=too-many-boolean-expressions
|
|
@@ -133,15 +133,15 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
133
133
|
# Return the new task_id
|
|
134
134
|
return task_id
|
|
135
135
|
|
|
136
|
-
def get_task_res(self, task_ids:
|
|
136
|
+
def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
|
|
137
137
|
"""Get all TaskRes that have not been delivered yet."""
|
|
138
138
|
if limit is not None and limit < 1:
|
|
139
139
|
raise AssertionError("`limit` must be >= 1")
|
|
140
140
|
|
|
141
141
|
with self.lock:
|
|
142
142
|
# Find TaskRes that were not delivered yet
|
|
143
|
-
task_res_list:
|
|
144
|
-
replied_task_ids:
|
|
143
|
+
task_res_list: list[TaskRes] = []
|
|
144
|
+
replied_task_ids: set[UUID] = set()
|
|
145
145
|
for _, task_res in self.task_res_store.items():
|
|
146
146
|
reply_to = UUID(task_res.task.ancestry[0])
|
|
147
147
|
if reply_to in task_ids and task_res.task.delivered_at == "":
|
|
@@ -175,10 +175,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
175
175
|
# Return TaskRes
|
|
176
176
|
return task_res_list
|
|
177
177
|
|
|
178
|
-
def delete_tasks(self, task_ids:
|
|
178
|
+
def delete_tasks(self, task_ids: set[UUID]) -> None:
|
|
179
179
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
180
|
-
task_ins_to_be_deleted:
|
|
181
|
-
task_res_to_be_deleted:
|
|
180
|
+
task_ins_to_be_deleted: set[UUID] = set()
|
|
181
|
+
task_res_to_be_deleted: set[UUID] = set()
|
|
182
182
|
|
|
183
183
|
with self.lock:
|
|
184
184
|
for task_ins_id in task_ids:
|
|
@@ -253,7 +253,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
253
253
|
|
|
254
254
|
del self.node_ids[node_id]
|
|
255
255
|
|
|
256
|
-
def get_nodes(self, run_id: int) ->
|
|
256
|
+
def get_nodes(self, run_id: int) -> set[int]:
|
|
257
257
|
"""Return all available nodes.
|
|
258
258
|
|
|
259
259
|
Constraints
|
|
@@ -318,7 +318,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
318
318
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
319
319
|
return self.server_public_key
|
|
320
320
|
|
|
321
|
-
def store_node_public_keys(self, public_keys:
|
|
321
|
+
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
322
322
|
"""Store a set of `node_public_keys` in state."""
|
|
323
323
|
with self.lock:
|
|
324
324
|
self.node_public_keys = public_keys
|
|
@@ -328,7 +328,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
328
328
|
with self.lock:
|
|
329
329
|
self.node_public_keys.add(public_key)
|
|
330
330
|
|
|
331
|
-
def get_node_public_keys(self) ->
|
|
331
|
+
def get_node_public_keys(self) -> set[bytes]:
|
|
332
332
|
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
333
333
|
return self.node_public_keys
|
|
334
334
|
|
|
@@ -19,8 +19,9 @@ import json
|
|
|
19
19
|
import re
|
|
20
20
|
import sqlite3
|
|
21
21
|
import time
|
|
22
|
+
from collections.abc import Sequence
|
|
22
23
|
from logging import DEBUG, ERROR
|
|
23
|
-
from typing import Any,
|
|
24
|
+
from typing import Any, Optional, Union, cast
|
|
24
25
|
from uuid import UUID, uuid4
|
|
25
26
|
|
|
26
27
|
from flwr.common import log, now
|
|
@@ -110,7 +111,7 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
110
111
|
);
|
|
111
112
|
"""
|
|
112
113
|
|
|
113
|
-
DictOrTuple = Union[
|
|
114
|
+
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
|
|
114
115
|
|
|
115
116
|
|
|
116
117
|
class SqliteState(State): # pylint: disable=R0904
|
|
@@ -131,7 +132,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
131
132
|
self.database_path = database_path
|
|
132
133
|
self.conn: Optional[sqlite3.Connection] = None
|
|
133
134
|
|
|
134
|
-
def initialize(self, log_queries: bool = False) ->
|
|
135
|
+
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
135
136
|
"""Create tables if they don't exist yet.
|
|
136
137
|
|
|
137
138
|
Parameters
|
|
@@ -162,7 +163,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
162
163
|
self,
|
|
163
164
|
query: str,
|
|
164
165
|
data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
|
|
165
|
-
) ->
|
|
166
|
+
) -> list[dict[str, Any]]:
|
|
166
167
|
"""Execute a SQL query."""
|
|
167
168
|
if self.conn is None:
|
|
168
169
|
raise AttributeError("State is not initialized.")
|
|
@@ -237,7 +238,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
237
238
|
|
|
238
239
|
def get_task_ins(
|
|
239
240
|
self, node_id: Optional[int], limit: Optional[int]
|
|
240
|
-
) ->
|
|
241
|
+
) -> list[TaskIns]:
|
|
241
242
|
"""Get undelivered TaskIns for one node (either anonymous or with ID).
|
|
242
243
|
|
|
243
244
|
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
|
@@ -271,7 +272,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
271
272
|
)
|
|
272
273
|
raise AssertionError(msg)
|
|
273
274
|
|
|
274
|
-
data:
|
|
275
|
+
data: dict[str, Union[str, int]] = {}
|
|
275
276
|
|
|
276
277
|
if node_id is None:
|
|
277
278
|
# Retrieve all anonymous Tasks
|
|
@@ -367,7 +368,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
367
368
|
return task_id
|
|
368
369
|
|
|
369
370
|
# pylint: disable-next=R0914
|
|
370
|
-
def get_task_res(self, task_ids:
|
|
371
|
+
def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
|
|
371
372
|
"""Get TaskRes for task_ids.
|
|
372
373
|
|
|
373
374
|
Usually, the Driver API calls this method to get results for instructions it has
|
|
@@ -397,7 +398,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
397
398
|
AND delivered_at = ""
|
|
398
399
|
"""
|
|
399
400
|
|
|
400
|
-
data:
|
|
401
|
+
data: dict[str, Union[str, float, int]] = {}
|
|
401
402
|
|
|
402
403
|
if limit is not None:
|
|
403
404
|
query += " LIMIT :limit"
|
|
@@ -435,7 +436,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
435
436
|
# 1. Query: Fetch consumer_node_id of remaining task_ids
|
|
436
437
|
# Assume the ancestry field only contains one element
|
|
437
438
|
data.clear()
|
|
438
|
-
replied_task_ids:
|
|
439
|
+
replied_task_ids: set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
|
|
439
440
|
remaining_task_ids = task_ids - replied_task_ids
|
|
440
441
|
placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
|
|
441
442
|
query = f"""
|
|
@@ -499,10 +500,10 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
499
500
|
"""
|
|
500
501
|
query = "SELECT count(*) AS num FROM task_res;"
|
|
501
502
|
rows = self.query(query)
|
|
502
|
-
result:
|
|
503
|
+
result: dict[str, int] = rows[0]
|
|
503
504
|
return result["num"]
|
|
504
505
|
|
|
505
|
-
def delete_tasks(self, task_ids:
|
|
506
|
+
def delete_tasks(self, task_ids: set[UUID]) -> None:
|
|
506
507
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
507
508
|
ids = list(task_ids)
|
|
508
509
|
if len(ids) == 0:
|
|
@@ -588,7 +589,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
588
589
|
except KeyError as exc:
|
|
589
590
|
log(ERROR, {"query": query, "data": params, "exception": exc})
|
|
590
591
|
|
|
591
|
-
def get_nodes(self, run_id: int) ->
|
|
592
|
+
def get_nodes(self, run_id: int) -> set[int]:
|
|
592
593
|
"""Retrieve all currently stored node IDs as a set.
|
|
593
594
|
|
|
594
595
|
Constraints
|
|
@@ -604,7 +605,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
604
605
|
# Get nodes
|
|
605
606
|
query = "SELECT node_id FROM node WHERE online_until > ?;"
|
|
606
607
|
rows = self.query(query, (time.time(),))
|
|
607
|
-
result:
|
|
608
|
+
result: set[int] = {row["node_id"] for row in rows}
|
|
608
609
|
return result
|
|
609
610
|
|
|
610
611
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
|
@@ -684,7 +685,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
684
685
|
public_key = None
|
|
685
686
|
return public_key
|
|
686
687
|
|
|
687
|
-
def store_node_public_keys(self, public_keys:
|
|
688
|
+
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
688
689
|
"""Store a set of `node_public_keys` in state."""
|
|
689
690
|
query = "INSERT INTO public_key (public_key) VALUES (?)"
|
|
690
691
|
data = [(key,) for key in public_keys]
|
|
@@ -695,11 +696,11 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
695
696
|
query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
|
|
696
697
|
self.query(query, {"public_key": public_key})
|
|
697
698
|
|
|
698
|
-
def get_node_public_keys(self) ->
|
|
699
|
+
def get_node_public_keys(self) -> set[bytes]:
|
|
699
700
|
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
700
701
|
query = "SELECT public_key FROM public_key"
|
|
701
702
|
rows = self.query(query)
|
|
702
|
-
result:
|
|
703
|
+
result: set[bytes] = {row["public_key"] for row in rows}
|
|
703
704
|
return result
|
|
704
705
|
|
|
705
706
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
@@ -733,7 +734,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
733
734
|
def dict_factory(
|
|
734
735
|
cursor: sqlite3.Cursor,
|
|
735
736
|
row: sqlite3.Row,
|
|
736
|
-
) ->
|
|
737
|
+
) -> dict[str, Any]:
|
|
737
738
|
"""Turn SQLite results into dicts.
|
|
738
739
|
|
|
739
740
|
Less efficient for retrieval of large amounts of data but easier to use.
|
|
@@ -742,7 +743,7 @@ def dict_factory(
|
|
|
742
743
|
return dict(zip(fields, row))
|
|
743
744
|
|
|
744
745
|
|
|
745
|
-
def task_ins_to_dict(task_msg: TaskIns) ->
|
|
746
|
+
def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
|
|
746
747
|
"""Transform TaskIns to dict."""
|
|
747
748
|
result = {
|
|
748
749
|
"task_id": task_msg.task_id,
|
|
@@ -763,7 +764,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
|
|
|
763
764
|
return result
|
|
764
765
|
|
|
765
766
|
|
|
766
|
-
def task_res_to_dict(task_msg: TaskRes) ->
|
|
767
|
+
def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
767
768
|
"""Transform TaskRes to dict."""
|
|
768
769
|
result = {
|
|
769
770
|
"task_id": task_msg.task_id,
|
|
@@ -784,7 +785,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
|
|
|
784
785
|
return result
|
|
785
786
|
|
|
786
787
|
|
|
787
|
-
def dict_to_task_ins(task_dict:
|
|
788
|
+
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
788
789
|
"""Turn task_dict into protobuf message."""
|
|
789
790
|
recordset = RecordSet()
|
|
790
791
|
recordset.ParseFromString(task_dict["recordset"])
|
|
@@ -814,7 +815,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
|
|
|
814
815
|
return result
|
|
815
816
|
|
|
816
817
|
|
|
817
|
-
def dict_to_task_res(task_dict:
|
|
818
|
+
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
818
819
|
"""Turn task_dict into protobuf message."""
|
|
819
820
|
recordset = RecordSet()
|
|
820
821
|
recordset.ParseFromString(task_dict["recordset"])
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
22
|
from flwr.common.typing import Run, UserConfig
|
|
@@ -51,7 +51,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
51
51
|
@abc.abstractmethod
|
|
52
52
|
def get_task_ins(
|
|
53
53
|
self, node_id: Optional[int], limit: Optional[int]
|
|
54
|
-
) ->
|
|
54
|
+
) -> list[TaskIns]:
|
|
55
55
|
"""Get TaskIns optionally filtered by node_id.
|
|
56
56
|
|
|
57
57
|
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
|
@@ -98,7 +98,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
98
98
|
"""
|
|
99
99
|
|
|
100
100
|
@abc.abstractmethod
|
|
101
|
-
def get_task_res(self, task_ids:
|
|
101
|
+
def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]:
|
|
102
102
|
"""Get TaskRes for task_ids.
|
|
103
103
|
|
|
104
104
|
Usually, the Driver API calls this method to get results for instructions it has
|
|
@@ -129,7 +129,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
129
129
|
"""
|
|
130
130
|
|
|
131
131
|
@abc.abstractmethod
|
|
132
|
-
def delete_tasks(self, task_ids:
|
|
132
|
+
def delete_tasks(self, task_ids: set[UUID]) -> None:
|
|
133
133
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
134
134
|
|
|
135
135
|
@abc.abstractmethod
|
|
@@ -143,7 +143,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
143
143
|
"""Remove `node_id` from state."""
|
|
144
144
|
|
|
145
145
|
@abc.abstractmethod
|
|
146
|
-
def get_nodes(self, run_id: int) ->
|
|
146
|
+
def get_nodes(self, run_id: int) -> set[int]:
|
|
147
147
|
"""Retrieve all currently stored node IDs as a set.
|
|
148
148
|
|
|
149
149
|
Constraints
|
|
@@ -199,7 +199,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
199
199
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
200
200
|
|
|
201
201
|
@abc.abstractmethod
|
|
202
|
-
def store_node_public_keys(self, public_keys:
|
|
202
|
+
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
203
203
|
"""Store a set of `node_public_keys` in state."""
|
|
204
204
|
|
|
205
205
|
@abc.abstractmethod
|
|
@@ -207,7 +207,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
207
207
|
"""Store a `node_public_key` in state."""
|
|
208
208
|
|
|
209
209
|
@abc.abstractmethod
|
|
210
|
-
def get_node_public_keys(self) ->
|
|
210
|
+
def get_node_public_keys(self) -> set[bytes]:
|
|
211
211
|
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
212
212
|
|
|
213
213
|
@abc.abstractmethod
|
flwr/server/utils/tensorboard.py
CHANGED
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import os
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from logging import WARN
|
|
21
|
-
from typing import Callable,
|
|
21
|
+
from typing import Callable, Optional, Union, cast
|
|
22
22
|
|
|
23
23
|
from flwr.common import EvaluateRes, Scalar
|
|
24
24
|
from flwr.common.logger import log
|
|
@@ -92,9 +92,9 @@ def tensorboard(logdir: str) -> Callable[[Strategy], Strategy]:
|
|
|
92
92
|
def aggregate_evaluate(
|
|
93
93
|
self,
|
|
94
94
|
server_round: int,
|
|
95
|
-
results:
|
|
96
|
-
failures:
|
|
97
|
-
) ->
|
|
95
|
+
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
96
|
+
failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]],
|
|
97
|
+
) -> tuple[Optional[float], dict[str, Scalar]]:
|
|
98
98
|
"""Hooks into aggregate_evaluate for TensorBoard logging purpose."""
|
|
99
99
|
# Execute decorated function and extract results for logging
|
|
100
100
|
# They will be returned at the end of this function but also
|
flwr/server/utils/validator.py
CHANGED
|
@@ -15,13 +15,13 @@
|
|
|
15
15
|
"""Validators."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import Union
|
|
19
19
|
|
|
20
20
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
# pylint: disable-next=too-many-branches,too-many-statements
|
|
24
|
-
def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) ->
|
|
24
|
+
def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str]:
|
|
25
25
|
"""Validate a TaskIns or TaskRes."""
|
|
26
26
|
validation_errors = []
|
|
27
27
|
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import io
|
|
19
19
|
import timeit
|
|
20
20
|
from logging import INFO, WARN
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import Optional, Union, cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recordset_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
@@ -276,8 +276,8 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
276
276
|
)
|
|
277
277
|
|
|
278
278
|
# Aggregate training results
|
|
279
|
-
results:
|
|
280
|
-
failures:
|
|
279
|
+
results: list[tuple[ClientProxy, FitRes]] = []
|
|
280
|
+
failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = []
|
|
281
281
|
for msg in messages:
|
|
282
282
|
if msg.has_content():
|
|
283
283
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
@@ -362,8 +362,8 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
|
362
362
|
)
|
|
363
363
|
|
|
364
364
|
# Aggregate the evaluation results
|
|
365
|
-
results:
|
|
366
|
-
failures:
|
|
365
|
+
results: list[tuple[ClientProxy, EvaluateRes]] = []
|
|
366
|
+
failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = []
|
|
367
367
|
for msg in messages:
|
|
368
368
|
if msg.has_content():
|
|
369
369
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import random
|
|
19
19
|
from dataclasses import dataclass, field
|
|
20
20
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import Optional, Union, cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recordset_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
@@ -65,23 +65,23 @@ from ..constant import Key as WorkflowKey
|
|
|
65
65
|
class WorkflowState: # pylint: disable=R0902
|
|
66
66
|
"""The state of the SecAgg+ protocol."""
|
|
67
67
|
|
|
68
|
-
nid_to_proxies:
|
|
69
|
-
nid_to_fitins:
|
|
70
|
-
sampled_node_ids:
|
|
71
|
-
active_node_ids:
|
|
68
|
+
nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
|
|
69
|
+
nid_to_fitins: dict[int, RecordSet] = field(default_factory=dict)
|
|
70
|
+
sampled_node_ids: set[int] = field(default_factory=set)
|
|
71
|
+
active_node_ids: set[int] = field(default_factory=set)
|
|
72
72
|
num_shares: int = 0
|
|
73
73
|
threshold: int = 0
|
|
74
74
|
clipping_range: float = 0.0
|
|
75
75
|
quantization_range: int = 0
|
|
76
76
|
mod_range: int = 0
|
|
77
77
|
max_weight: float = 0.0
|
|
78
|
-
nid_to_neighbours:
|
|
79
|
-
nid_to_publickeys:
|
|
80
|
-
forward_srcs:
|
|
81
|
-
forward_ciphertexts:
|
|
78
|
+
nid_to_neighbours: dict[int, set[int]] = field(default_factory=dict)
|
|
79
|
+
nid_to_publickeys: dict[int, list[bytes]] = field(default_factory=dict)
|
|
80
|
+
forward_srcs: dict[int, list[int]] = field(default_factory=dict)
|
|
81
|
+
forward_ciphertexts: dict[int, list[bytes]] = field(default_factory=dict)
|
|
82
82
|
aggregate_ndarrays: NDArrays = field(default_factory=list)
|
|
83
|
-
legacy_results:
|
|
84
|
-
failures:
|
|
83
|
+
legacy_results: list[tuple[ClientProxy, FitRes]] = field(default_factory=list)
|
|
84
|
+
failures: list[Exception] = field(default_factory=list)
|
|
85
85
|
|
|
86
86
|
|
|
87
87
|
class SecAggPlusWorkflow:
|
|
@@ -444,13 +444,13 @@ class SecAggPlusWorkflow:
|
|
|
444
444
|
)
|
|
445
445
|
|
|
446
446
|
# Build forward packet list dictionary
|
|
447
|
-
srcs:
|
|
448
|
-
dsts:
|
|
449
|
-
ciphertexts:
|
|
450
|
-
fwd_ciphertexts:
|
|
447
|
+
srcs: list[int] = []
|
|
448
|
+
dsts: list[int] = []
|
|
449
|
+
ciphertexts: list[bytes] = []
|
|
450
|
+
fwd_ciphertexts: dict[int, list[bytes]] = {
|
|
451
451
|
nid: [] for nid in state.active_node_ids
|
|
452
452
|
} # dest node ID -> list of ciphertexts
|
|
453
|
-
fwd_srcs:
|
|
453
|
+
fwd_srcs: dict[int, list[int]] = {
|
|
454
454
|
nid: [] for nid in state.active_node_ids
|
|
455
455
|
} # dest node ID -> list of src node IDs
|
|
456
456
|
for msg in msgs:
|
|
@@ -459,8 +459,8 @@ class SecAggPlusWorkflow:
|
|
|
459
459
|
continue
|
|
460
460
|
node_id = msg.metadata.src_node_id
|
|
461
461
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
462
|
-
dst_lst = cast(
|
|
463
|
-
ctxt_lst = cast(
|
|
462
|
+
dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
|
|
463
|
+
ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
|
|
464
464
|
srcs += [node_id] * len(dst_lst)
|
|
465
465
|
dsts += dst_lst
|
|
466
466
|
ciphertexts += ctxt_lst
|
|
@@ -525,7 +525,7 @@ class SecAggPlusWorkflow:
|
|
|
525
525
|
state.failures.append(Exception(msg.error))
|
|
526
526
|
continue
|
|
527
527
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
528
|
-
bytes_list = cast(
|
|
528
|
+
bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
|
|
529
529
|
client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
|
|
530
530
|
if masked_vector is None:
|
|
531
531
|
masked_vector = client_masked_vec
|
|
@@ -592,7 +592,7 @@ class SecAggPlusWorkflow:
|
|
|
592
592
|
)
|
|
593
593
|
|
|
594
594
|
# Build collected shares dict
|
|
595
|
-
collected_shares_dict:
|
|
595
|
+
collected_shares_dict: dict[int, list[bytes]] = {}
|
|
596
596
|
for nid in state.sampled_node_ids:
|
|
597
597
|
collected_shares_dict[nid] = []
|
|
598
598
|
for msg in msgs:
|
|
@@ -600,8 +600,8 @@ class SecAggPlusWorkflow:
|
|
|
600
600
|
state.failures.append(Exception(msg.error))
|
|
601
601
|
continue
|
|
602
602
|
res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
|
|
603
|
-
nids = cast(
|
|
604
|
-
shares = cast(
|
|
603
|
+
nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
|
|
604
|
+
shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
|
|
605
605
|
for owner_nid, share in zip(nids, shares):
|
|
606
606
|
collected_shares_dict[owner_nid].append(share)
|
|
607
607
|
|