flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/server/superlink/state/sqlite_state.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,31 +15,57 @@
 """SQLite based implemenation of server state."""


-import
+import json
 import re
 import sqlite3
-
+import time
 from logging import DEBUG, ERROR
-from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
 from uuid import UUID, uuid4

 from flwr.common import log, now
+from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
 from flwr.proto.recordset_pb2 import RecordSet  # pylint: disable=E0611
 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.utils.validator import validate_task_ins_or_res

 from .state import State
+from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres

 SQL_CREATE_TABLE_NODE = """
 CREATE TABLE IF NOT EXISTS node(
-    node_id
+    node_id INTEGER UNIQUE,
+    online_until REAL,
+    ping_interval REAL,
+    public_key BLOB
 );
 """

+SQL_CREATE_TABLE_CREDENTIAL = """
+CREATE TABLE IF NOT EXISTS credential(
+    private_key BLOB PRIMARY KEY,
+    public_key BLOB
+);
+"""
+
+SQL_CREATE_TABLE_PUBLIC_KEY = """
+CREATE TABLE IF NOT EXISTS public_key(
+    public_key BLOB UNIQUE
+);
+"""
+
+SQL_CREATE_INDEX_ONLINE_UNTIL = """
+CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
+"""
+
 SQL_CREATE_TABLE_RUN = """
 CREATE TABLE IF NOT EXISTS run(
-    run_id
+    run_id INTEGER UNIQUE,
+    fab_id TEXT,
+    fab_version TEXT,
+    override_config TEXT
 );
 """

@@ -52,9 +78,10 @@ CREATE TABLE IF NOT EXISTS task_ins(
     producer_node_id INTEGER,
     consumer_anonymous BOOLEAN,
     consumer_node_id INTEGER,
-    created_at
+    created_at REAL,
     delivered_at TEXT,
-
+    pushed_at REAL,
+    ttl REAL,
     ancestry TEXT,
     task_type TEXT,
     recordset BLOB,
@@ -62,7 +89,6 @@ CREATE TABLE IF NOT EXISTS task_ins(
 );
 """

-
 SQL_CREATE_TABLE_TASK_RES = """
 CREATE TABLE IF NOT EXISTS task_res(
     task_id TEXT UNIQUE,
@@ -72,9 +98,10 @@ CREATE TABLE IF NOT EXISTS task_res(
     producer_node_id INTEGER,
     consumer_anonymous BOOLEAN,
     consumer_node_id INTEGER,
-    created_at
+    created_at REAL,
     delivered_at TEXT,
-
+    pushed_at REAL,
+    ttl REAL,
     ancestry TEXT,
     task_type TEXT,
     recordset BLOB,
@@ -82,10 +109,10 @@ CREATE TABLE IF NOT EXISTS task_res(
 );
 """

-DictOrTuple = Union[Tuple[Any], Dict[str, Any]]
+DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]


-class SqliteState(State):
+class SqliteState(State):  # pylint: disable=R0904
     """SQLite-based state implementation."""

     def __init__(
@@ -123,6 +150,9 @@ class SqliteState(State):
         cur.execute(SQL_CREATE_TABLE_TASK_INS)
         cur.execute(SQL_CREATE_TABLE_TASK_RES)
         cur.execute(SQL_CREATE_TABLE_NODE)
+        cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
+        cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
+        cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
         res = cur.execute("SELECT name FROM sqlite_schema;")

         return res.fetchall()
@@ -130,7 +160,7 @@ class SqliteState(State):
     def query(
         self,
         query: str,
-        data: Optional[Union[
+        data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
     ) -> List[Dict[str, Any]]:
         """Execute a SQL query."""
         if self.conn is None:
@@ -185,15 +215,11 @@ class SqliteState(State):
             log(ERROR, errors)
             return None

-        # Create task_id
+        # Create task_id
         task_id = uuid4()
-        created_at: datetime = now()
-        ttl: datetime = created_at + timedelta(hours=24)

         # Store TaskIns
         task_ins.task_id = str(task_id)
-        task_ins.task.created_at = created_at.isoformat()
-        task_ins.task.ttl = ttl.isoformat()
         data = (task_ins_to_dict(task_ins),)
         columns = ", ".join([f":{key}" for key in data[0]])
         query = f"INSERT INTO task_ins VALUES({columns});"
@@ -320,15 +346,11 @@ class SqliteState(State):
             log(ERROR, errors)
             return None

-        # Create task_id
+        # Create task_id
         task_id = uuid4()
-        created_at: datetime = now()
-        ttl: datetime = created_at + timedelta(hours=24)

         # Store TaskIns
         task_res.task_id = str(task_id)
-        task_res.task.created_at = created_at.isoformat()
-        task_res.task.ttl = ttl.isoformat()
         data = (task_res_to_dict(task_res),)
         columns = ", ".join([f":{key}" for key in data[0]])
         query = f"INSERT INTO task_res VALUES({columns});"
@@ -343,6 +365,7 @@ class SqliteState(State):

         return task_id

+    # pylint: disable-next=R0914
     def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
         """Get TaskRes for task_ids.

@@ -373,7 +396,7 @@ class SqliteState(State):
             AND delivered_at = ""
         """

-        data: Dict[str, Union[str, int]] = {}
+        data: Dict[str, Union[str, float, int]] = {}

         if limit is not None:
             query += " LIMIT :limit"
@@ -407,6 +430,54 @@ class SqliteState(State):
         rows = self.query(query, data)

         result = [dict_to_task_res(row) for row in rows]
+
+        # 1. Query: Fetch consumer_node_id of remaining task_ids
+        # Assume the ancestry field only contains one element
+        data.clear()
+        replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
+        remaining_task_ids = task_ids - replied_task_ids
+        placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
+        query = f"""
+            SELECT consumer_node_id
+            FROM task_ins
+            WHERE task_id IN ({placeholders});
+        """
+        for index, task_id in enumerate(remaining_task_ids):
+            data[f"id_{index}"] = str(task_id)
+        node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]
+
+        # 2. Query: Select offline nodes
+        placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
+        query = f"""
+            SELECT node_id
+            FROM node
+            WHERE node_id IN ({placeholders})
+            AND online_until < :time;
+        """
+        data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
+        data["time"] = time.time()
+        offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
+
+        # 3. Query: Select TaskIns for offline nodes
+        placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
+        query = f"""
+            SELECT *
+            FROM task_ins
+            WHERE consumer_node_id IN ({placeholders});
+        """
+        data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
+        task_ins_rows = self.query(query, data)
+
+        # Make TaskRes containing node unavailabe error
+        for row in task_ins_rows:
+            if limit and len(result) == limit:
+                break
+            task_ins = dict_to_task_ins(row)
+            err_taskres = make_node_unavailable_taskres(
+                ref_taskins=task_ins,
+            )
+            result.append(err_taskres)
+
         return result

     def num_task_ins(self) -> int:
@@ -467,23 +538,54 @@ class SqliteState(State):

         return None

-    def create_node(
+    def create_node(
+        self, ping_interval: float, public_key: Optional[bytes] = None
+    ) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
-        node_id
+        node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
+
+        query = "SELECT node_id FROM node WHERE public_key = :public_key;"
+        row = self.query(query, {"public_key": public_key})
+
+        if len(row) > 0:
+            log(ERROR, "Unexpected node registration failure.")
+            return 0
+
+        query = (
+            "INSERT INTO node "
+            "(node_id, online_until, ping_interval, public_key) "
+            "VALUES (?, ?, ?, ?)"
+        )

-        query = "INSERT INTO node VALUES(:node_id);"
         try:
-            self.query(
+            self.query(
+                query, (node_id, time.time() + ping_interval, ping_interval, public_key)
+            )
         except sqlite3.IntegrityError:
             log(ERROR, "Unexpected node registration failure.")
             return 0
         return node_id

-    def delete_node(self, node_id: int) -> None:
+    def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
         """Delete a client node."""
-        query = "DELETE FROM node WHERE node_id =
-
+        query = "DELETE FROM node WHERE node_id = ?"
+        params = (node_id,)
+
+        if public_key is not None:
+            query += " AND public_key = ?"
+            params += (public_key,)  # type: ignore
+
+        if self.conn is None:
+            raise AttributeError("State is not initialized.")
+
+        try:
+            with self.conn:
+                rows = self.conn.execute(query, params)
+                if rows.rowcount < 1:
+                    raise ValueError("Public key or node_id not found")
+        except KeyError as exc:
+            log(ERROR, {"query": query, "data": params, "exception": exc})

     def get_nodes(self, run_id: int) -> Set[int]:
         """Retrieve all currently stored node IDs as a set.
@@ -499,26 +601,124 @@ class SqliteState(State):
             return set()

         # Get nodes
-        query = "SELECT
-        rows = self.query(query)
+        query = "SELECT node_id FROM node WHERE online_until > ?;"
+        rows = self.query(query, (time.time(),))
         result: Set[int] = {row["node_id"] for row in rows}
         return result

-    def
-        """
+    def get_node_id(self, client_public_key: bytes) -> Optional[int]:
+        """Retrieve stored `node_id` filtered by `client_public_keys`."""
+        query = "SELECT node_id FROM node WHERE public_key = :public_key;"
+        row = self.query(query, {"public_key": client_public_key})
+        if len(row) > 0:
+            node_id: int = row[0]["node_id"]
+            return node_id
+        return None
+
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
+        """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
-        run_id
+        run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

         # Check conflicts
         query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
         # If run_id does not exist
         if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
-            query =
-
+            query = (
+                "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
+                "VALUES (?, ?, ?, ?);"
+            )
+            self.query(
+                query, (run_id, fab_id, fab_version, json.dumps(override_config))
+            )
             return run_id
         log(ERROR, "Unexpected run creation failure.")
         return 0

+    def store_server_private_public_key(
+        self, private_key: bytes, public_key: bytes
+    ) -> None:
+        """Store `server_private_key` and `server_public_key` in state."""
+        query = "SELECT COUNT(*) FROM credential"
+        count = self.query(query)[0]["COUNT(*)"]
+        if count < 1:
+            query = (
+                "INSERT OR REPLACE INTO credential (private_key, public_key) "
+                "VALUES (:private_key, :public_key)"
+            )
+            self.query(query, {"private_key": private_key, "public_key": public_key})
+        else:
+            raise RuntimeError("Server private and public key already set")
+
+    def get_server_private_key(self) -> Optional[bytes]:
+        """Retrieve `server_private_key` in urlsafe bytes."""
+        query = "SELECT private_key FROM credential"
+        rows = self.query(query)
+        try:
+            private_key: Optional[bytes] = rows[0]["private_key"]
+        except IndexError:
+            private_key = None
+        return private_key
+
+    def get_server_public_key(self) -> Optional[bytes]:
+        """Retrieve `server_public_key` in urlsafe bytes."""
+        query = "SELECT public_key FROM credential"
+        rows = self.query(query)
+        try:
+            public_key: Optional[bytes] = rows[0]["public_key"]
+        except IndexError:
+            public_key = None
+        return public_key
+
+    def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
+        """Store a set of `client_public_keys` in state."""
+        query = "INSERT INTO public_key (public_key) VALUES (?)"
+        data = [(key,) for key in public_keys]
+        self.query(query, data)
+
+    def store_client_public_key(self, public_key: bytes) -> None:
+        """Store a `client_public_key` in state."""
+        query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
+        self.query(query, {"public_key": public_key})
+
+    def get_client_public_keys(self) -> Set[bytes]:
+        """Retrieve all currently stored `client_public_keys` as a set."""
+        query = "SELECT public_key FROM public_key"
+        rows = self.query(query)
+        result: Set[bytes] = {row["public_key"] for row in rows}
+        return result
+
+    def get_run(self, run_id: int) -> Optional[Run]:
+        """Retrieve information about the run with the specified `run_id`."""
+        query = "SELECT * FROM run WHERE run_id = ?;"
+        try:
+            row = self.query(query, (run_id,))[0]
+            return Run(
+                run_id=run_id,
+                fab_id=row["fab_id"],
+                fab_version=row["fab_version"],
+                override_config=json.loads(row["override_config"]),
+            )
+        except sqlite3.IntegrityError:
+            log(ERROR, "`run_id` does not exist.")
+            return None
+
+    def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
+        """Acknowledge a ping received from a node, serving as a heartbeat."""
+        # Update `online_until` and `ping_interval` for the given `node_id`
+        query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
+        try:
+            self.query(query, (time.time() + ping_interval, ping_interval, node_id))
+            return True
+        except sqlite3.IntegrityError:
+            log(ERROR, "`node_id` does not exist.")
+            return False
+

 def dict_factory(
     cursor: sqlite3.Cursor,
@@ -544,6 +744,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
         "consumer_node_id": task_msg.task.consumer.node_id,
         "created_at": task_msg.task.created_at,
         "delivered_at": task_msg.task.delivered_at,
+        "pushed_at": task_msg.task.pushed_at,
         "ttl": task_msg.task.ttl,
         "ancestry": ",".join(task_msg.task.ancestry),
         "task_type": task_msg.task.task_type,
@@ -564,6 +765,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
         "consumer_node_id": task_msg.task.consumer.node_id,
         "created_at": task_msg.task.created_at,
         "delivered_at": task_msg.task.delivered_at,
+        "pushed_at": task_msg.task.pushed_at,
         "ttl": task_msg.task.ttl,
         "ancestry": ",".join(task_msg.task.ancestry),
         "task_type": task_msg.task.task_type,
@@ -592,6 +794,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
             ),
             created_at=task_dict["created_at"],
             delivered_at=task_dict["delivered_at"],
+            pushed_at=task_dict["pushed_at"],
             ttl=task_dict["ttl"],
             ancestry=task_dict["ancestry"].split(","),
             task_type=task_dict["task_type"],
@@ -621,6 +824,7 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes:
             ),
             created_at=task_dict["created_at"],
             delivered_at=task_dict["delivered_at"],
+            pushed_at=task_dict["pushed_at"],
             ttl=task_dict["ttl"],
             ancestry=task_dict["ancestry"].split(","),
             task_type=task_dict["task_type"],
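The node table and the methods above replace plain node registration with liveness tracking: `create_node` stores `online_until = now + ping_interval`, `acknowledge_ping` pushes that deadline forward on every heartbeat, and `get_nodes` only returns nodes whose deadline has not yet passed. Below is a minimal usage sketch of that flow; it assumes `state` is an already constructed and initialized `SqliteState` (construction and initialization are not shown in this diff), and the FAB coordinates are placeholder values.

from flwr.server.superlink.state.sqlite_state import SqliteState


def heartbeat_roundtrip(state: SqliteState) -> None:
    """Sketch of the liveness flow; `state` is assumed to be initialized."""
    # Placeholder FAB coordinates; any UserConfig-style dict works here.
    run_id = state.create_run(
        fab_id="example/app", fab_version="0.0.1", override_config={}
    )

    # Register a node that promises to ping at least every 30 seconds.
    node_id = state.create_node(ping_interval=30.0)

    # Every acknowledged ping moves `online_until` to now + ping_interval.
    assert state.acknowledge_ping(node_id, ping_interval=30.0)

    # get_nodes() only reports nodes whose `online_until` is still in the
    # future, so a node that stops pinging silently drops out of this set.
    assert node_id in state.get_nodes(run_id)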
flwr/server/superlink/state/state.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,10 +19,11 @@ import abc
 from typing import List, Optional, Set
 from uuid import UUID

+from flwr.common.typing import Run, UserConfig
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611


-class State(abc.ABC):
+class State(abc.ABC):  # pylint: disable=R0904
     """Abstract State."""

     @abc.abstractmethod
@@ -132,11 +133,13 @@ class State(abc.ABC):
         """Delete all delivered TaskIns/TaskRes pairs."""

     @abc.abstractmethod
-    def create_node(
+    def create_node(
+        self, ping_interval: float, public_key: Optional[bytes] = None
+    ) -> int:
         """Create, store in state, and return `node_id`."""

     @abc.abstractmethod
-    def delete_node(self, node_id: int) -> None:
+    def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
         """Remove `node_id` from state."""

     @abc.abstractmethod
@@ -150,5 +153,77 @@ class State(abc.ABC):
         """

     @abc.abstractmethod
-    def
-        """
+    def get_node_id(self, client_public_key: bytes) -> Optional[int]:
+        """Retrieve stored `node_id` filtered by `client_public_keys`."""
+
+    @abc.abstractmethod
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
+        """Create a new run for the specified `fab_id` and `fab_version`."""
+
+    @abc.abstractmethod
+    def get_run(self, run_id: int) -> Optional[Run]:
+        """Retrieve information about the run with the specified `run_id`.
+
+        Parameters
+        ----------
+        run_id : int
+            The identifier of the run.
+
+        Returns
+        -------
+        Optional[Run]
+            A dataclass instance containing three elements if `run_id` is valid:
+            - `run_id`: The identifier of the run, same as the specified `run_id`.
+            - `fab_id`: The identifier of the FAB used in the specified run.
+            - `fab_version`: The version of the FAB used in the specified run.
+        """
+
+    @abc.abstractmethod
+    def store_server_private_public_key(
+        self, private_key: bytes, public_key: bytes
+    ) -> None:
+        """Store `server_private_key` and `server_public_key` in state."""
+
+    @abc.abstractmethod
+    def get_server_private_key(self) -> Optional[bytes]:
+        """Retrieve `server_private_key` in urlsafe bytes."""
+
+    @abc.abstractmethod
+    def get_server_public_key(self) -> Optional[bytes]:
+        """Retrieve `server_public_key` in urlsafe bytes."""
+
+    @abc.abstractmethod
+    def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
+        """Store a set of `client_public_keys` in state."""
+
+    @abc.abstractmethod
+    def store_client_public_key(self, public_key: bytes) -> None:
+        """Store a `client_public_key` in state."""
+
+    @abc.abstractmethod
+    def get_client_public_keys(self) -> Set[bytes]:
+        """Retrieve all currently stored `client_public_keys` as a set."""
+
+    @abc.abstractmethod
+    def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
+        """Acknowledge a ping received from a node, serving as a heartbeat.
+
+        Parameters
+        ----------
+        node_id : int
+            The `node_id` from which the ping was received.
+        ping_interval : float
+            The interval (in seconds) from the current timestamp within which the next
+            ping from this node must be received. This acts as a hard deadline to ensure
+            an accurate assessment of the node's availability.
+
+        Returns
+        -------
+        is_acknowledged : bool
+            True if the ping is successfully acknowledged; otherwise, False.
+        """
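The abstract interface above now also covers run bookkeeping and node authentication keys. The following sketch shows how the new key-related methods are intended to compose against any concrete `State` implementation; the key bytes and the way `state` is obtained are illustrative assumptions, not part of the diff.

from typing import Optional

from flwr.server.superlink.state.state import State


def key_bookkeeping(state: State) -> None:
    """Sketch of the public-key flow against any concrete State."""
    client_pk = b"illustrative-client-public-key"  # placeholder key material

    # Make the key known to the server, then bind a fresh node to it.
    state.store_client_public_key(client_pk)
    node_id = state.create_node(ping_interval=30.0, public_key=client_pk)

    # The node id can later be recovered from the key alone ...
    looked_up: Optional[int] = state.get_node_id(client_pk)
    assert looked_up == node_id

    # ... and deletion can be made conditional on presenting that key.
    state.delete_node(node_id, public_key=client_pk)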
flwr/server/superlink/state/state_factory.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,7 +26,16 @@ from .state import State


 class StateFactory:
-    """Factory class that creates State instances.
+    """Factory class that creates State instances.
+
+    Parameters
+    ----------
+    database : str
+        A string representing the path to the database file that will be opened.
+        Note that passing ':memory:' will open a connection to a database that is
+        in RAM, instead of on disk. For more information on special in-memory
+        databases, please refer to https://sqlite.org/inmemorydb.html.
+    """

     def __init__(self, database: str) -> None:
         self.database = database
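The new docstring documents SQLite's special ':memory:' database. A minimal usage sketch follows; the accessor used to obtain a `State` from the factory is an assumption here and is not part of this hunk.

from flwr.server.superlink.state.state_factory import StateFactory

# Keep all state in RAM; nothing is persisted to disk (handy for tests).
factory = StateFactory(":memory:")
state = factory.state()  # assumed accessor returning a State; not shown in this diff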
flwr/server/superlink/state/utils.py

@@ -0,0 +1,62 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for State."""
+
+
+import time
+from logging import ERROR
+from os import urandom
+from uuid import uuid4
+
+from flwr.common import log
+from flwr.common.constant import ErrorCode
+from flwr.proto.error_pb2 import Error  # pylint: disable=E0611
+from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
+from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
+
+NODE_UNAVAILABLE_ERROR_REASON = (
+    "Error: Node Unavailable - The destination node is currently unavailable. "
+    "It exceeds the time limit specified in its last ping."
+)
+
+
+def generate_rand_int_from_bytes(num_bytes: int) -> int:
+    """Generate a random `num_bytes` integer."""
+    return int.from_bytes(urandom(num_bytes), "little", signed=True)
+
+
+def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
+    """Generate a TaskRes with a node unavailable error from a TaskIns."""
+    current_time = time.time()
+    ttl = ref_taskins.task.ttl - (current_time - ref_taskins.task.created_at)
+    if ttl < 0:
+        log(ERROR, "Creating TaskRes for TaskIns that exceeds its TTL.")
+        ttl = 0
+    return TaskRes(
+        task_id=str(uuid4()),
+        group_id=ref_taskins.group_id,
+        run_id=ref_taskins.run_id,
+        task=Task(
+            producer=Node(node_id=ref_taskins.task.consumer.node_id, anonymous=False),
+            consumer=Node(node_id=ref_taskins.task.producer.node_id, anonymous=False),
+            created_at=current_time,
+            ttl=ttl,
+            ancestry=[ref_taskins.task_id],
+            task_type=ref_taskins.task.task_type,
+            error=Error(
+                code=ErrorCode.NODE_UNAVAILABLE, reason=NODE_UNAVAILABLE_ERROR_REASON
+            ),
+        ),
+    )
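Since the new module above is shown in full, its behaviour can be sketched directly. In the example below, the `TaskIns` is hand-built purely for illustration (all field values are placeholders), and 8 is used as an arbitrary byte count for the random-ID helper; the actual `NODE_ID_NUM_BYTES`/`RUN_ID_NUM_BYTES` constants are not shown in this diff.

import time

from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import Task, TaskIns
from flwr.server.superlink.state.utils import (
    generate_rand_int_from_bytes,
    make_node_unavailable_taskres,
)

# Signed little-endian integer built from `num_bytes` random bytes.
consumer_id = generate_rand_int_from_bytes(8)

ref = TaskIns(
    task_id="11111111-1111-1111-1111-111111111111",  # placeholder UUID
    group_id="",
    run_id=1,
    task=Task(
        producer=Node(node_id=0, anonymous=True),
        consumer=Node(node_id=consumer_id, anonymous=False),
        created_at=time.time(),
        ttl=3600.0,
        task_type="example",
    ),
)

# Producer/consumer are swapped, the remaining TTL is carried over, ancestry
# points back at the TaskIns, and a NODE_UNAVAILABLE error is attached.
err_res = make_node_unavailable_taskres(ref_taskins=ref)
print(err_res.task.error.code)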
flwr/server/typing.py CHANGED

@@ -20,6 +20,8 @@ from typing import Callable
 from flwr.common import Context

 from .driver import Driver
+from .serverapp_components import ServerAppComponents

 ServerAppCallable = Callable[[Driver, Context], None]
 Workflow = Callable[[Driver, Context], None]
+ServerFn = Callable[[Context], ServerAppComponents]
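The new `ServerFn` alias describes a callable that receives a `Context` and returns `ServerAppComponents`. A hedged sketch of a function matching that signature follows; the `strategy` and `config` keyword arguments of `ServerAppComponents` are assumptions about its constructor, since `flwr/server/serverapp_components.py` itself is not shown in this section.

from flwr.common import Context
from flwr.server import ServerConfig
from flwr.server.serverapp_components import ServerAppComponents
from flwr.server.strategy import FedAvg


def server_fn(context: Context) -> ServerAppComponents:
    # Build the strategy and run configuration for this particular run.
    strategy = FedAvg()
    config = ServerConfig(num_rounds=3)
    # Assumed keyword arguments; see the note above.
    return ServerAppComponents(strategy=strategy, config=config)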
flwr/server/utils/__init__.py CHANGED
flwr/server/utils/tensorboard.py CHANGED