flwr-nightly 1.9.0.dev20240422__py3-none-any.whl → 1.9.0.dev20240424__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's registry page for more details.

flwr/cli/new/new.py CHANGED
@@ -36,6 +36,7 @@ class MlFramework(str, Enum):
36
36
  NUMPY = "NumPy"
37
37
  PYTORCH = "PyTorch"
38
38
  TENSORFLOW = "TensorFlow"
39
+ SKLEARN = "sklearn"
39
40
 
40
41
 
41
42
  class TemplateNotFound(Exception):
@@ -0,0 +1,94 @@
1
+ """$project_name: A Flower / Scikit-Learn app."""
2
+
3
+ import warnings
4
+
5
+ import numpy as np
6
+ from flwr.client import NumPyClient, ClientApp
7
+ from flwr_datasets import FederatedDataset
8
+ from sklearn.linear_model import LogisticRegression
9
+ from sklearn.metrics import log_loss
10
+
11
+
12
def get_model_parameters(model):
    """Return the model's parameters as a list of arrays.

    The list is ``[coef_, intercept_]`` when ``model.fit_intercept`` is set
    and ``[coef_]`` otherwise.
    """
    if not model.fit_intercept:
        return [model.coef_]
    return [model.coef_, model.intercept_]
21
+
22
+
23
def set_model_params(model, params):
    """Write ``params`` into ``model`` in place and return the model.

    ``params[0]`` becomes ``coef_``; ``params[1]`` becomes ``intercept_``,
    but only when ``model.fit_intercept`` is set. This is the order produced
    by ``get_model_parameters``.
    """
    weights = params[0]
    model.coef_ = weights
    if not model.fit_intercept:
        return model
    model.intercept_ = params[1]
    return model
28
+
29
+
30
def set_initial_params(model, n_classes=10, n_features=784):
    """Give an unfitted model zero-valued parameters of the right shape.

    This lets the model be used (e.g. its parameters sent to the server)
    before its first ``fit`` call.

    Parameters
    ----------
    model :
        A scikit-learn linear model exposing ``fit_intercept``.
    n_classes : int
        Number of target classes (default 10, the MNIST digits).
    n_features : int
        Number of input features (default 784, matching the flattened MNIST
        images used by this template).
    """
    # Derive the labels from n_classes so the two can never disagree.
    model.classes_ = np.arange(n_classes)

    model.coef_ = np.zeros((n_classes, n_features))
    if model.fit_intercept:
        model.intercept_ = np.zeros((n_classes,))
38
+
39
+
40
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains a scikit-learn model on local data.

    Parameters are exchanged as the list produced by ``get_model_parameters``
    (coefficients, plus intercepts when the model fits them).
    """

    def __init__(self, model, X_train, X_test, y_train, y_test):
        self.model = model
        self.X_train = X_train
        self.X_test = X_test
        self.y_train = y_train
        self.y_test = y_test

    def get_parameters(self, config):
        """Return the current local model parameters."""
        return get_model_parameters(self.model)

    def fit(self, parameters, config):
        """Load ``parameters``, train on the local split, and return the result.

        Returns the updated parameters, the number of training examples, and
        an empty metrics dict.
        """
        set_model_params(self.model, parameters)

        # Ignore convergence failure due to low local epochs
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.model.fit(self.X_train, self.y_train)

        return get_model_parameters(self.model), len(self.X_train), {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local test split.

        Returns the log loss, the number of test examples, and a metrics dict
        with the accuracy.
        """
        set_model_params(self.model, parameters)

        # predict_proba has one column per entry of ``classes_``. The local
        # test split may not contain every class, so pass the full label set
        # explicitly; otherwise log_loss infers labels from y_test and rejects
        # the probability matrix.
        loss = log_loss(
            self.y_test,
            self.model.predict_proba(self.X_test),
            labels=self.model.classes_,
        )
        accuracy = self.model.score(self.X_test, self.y_test)

        return loss, len(self.X_test), {"accuracy": accuracy}
68
+
69
# Federated MNIST whose "train" split is divided into two partitions.
fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})


def client_fn(cid: str):
    """Create the Flower client that owns partition ``int(cid)``.

    Each image is flattened to one feature row and the partition is split
    80/20 into local train and test sets. The model is a warm-started
    logistic regression limited to one iteration per ``fit`` call.
    """
    partition = fds.load_partition(int(cid), "train").with_format("numpy")

    features = partition["image"].reshape((len(partition), -1))
    labels = partition["label"]

    # First 80% of the on-device data for training, remaining 20% for testing.
    split_at = int(0.8 * len(features))
    X_train, y_train = features[:split_at], labels[:split_at]
    X_test, y_test = features[split_at:], labels[split_at:]

    model = LogisticRegression(
        penalty="l2",
        max_iter=1,  # local epoch
        warm_start=True,  # keep the current weights between fit calls
    )
    # Give the model parameters before its first fit (akin to compiling a
    # Keras model).
    set_initial_params(model)

    return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()


# Flower ClientApp
app = ClientApp(client_fn=client_fn)
@@ -0,0 +1,17 @@
1
+ """$project_name: A Flower / Scikit-Learn app."""
2
+
3
+ from flwr.server import ServerApp, ServerConfig
4
+ from flwr.server.strategy import FedAvg
5
+
6
+
7
# FedAvg sampling every available client (fraction 1.0) for both training
# and evaluation, and requiring at least two available clients.
strategy = FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_available_clients=2,
)

# Create ServerApp running three federated rounds
app = ServerApp(
    strategy=strategy,
    config=ServerConfig(num_rounds=3),
)
@@ -0,0 +1,24 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "$project_name"
7
+ version = "1.0.0"
8
+ description = ""
9
+ authors = [
10
+ { name = "The Flower Authors", email = "hello@flower.ai" },
11
+ ]
12
+ license = {text = "Apache License (2.0)"}
13
+ dependencies = [
14
+ "flwr[simulation]>=1.8.0,<2.0",
15
+ "flwr-datasets[vision]>=0.0.2,<1.0.0",
16
+ "scikit-learn>=1.1.1",
17
+ ]
18
+
19
+ [tool.hatch.build.targets.wheel]
20
+ packages = ["."]
21
+
22
+ [flower.components]
23
+ serverapp = "$project_name.server:app"
24
+ clientapp = "$project_name.client:app"
@@ -22,6 +22,8 @@ from pathlib import Path
22
22
  from queue import Queue
23
23
  from typing import Callable, Iterator, Optional, Tuple, Union, cast
24
24
 
25
+ from cryptography.hazmat.primitives.asymmetric import ec
26
+
25
27
  from flwr.common import (
26
28
  DEFAULT_TTL,
27
29
  GRPC_MAX_MESSAGE_LENGTH,
@@ -56,12 +58,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
56
58
 
57
59
 
58
60
  @contextmanager
59
- def grpc_connection( # pylint: disable=R0915
61
+ def grpc_connection( # pylint: disable=R0913, R0915
60
62
  server_address: str,
61
63
  insecure: bool,
62
64
  retry_invoker: RetryInvoker, # pylint: disable=unused-argument
63
65
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
64
66
  root_certificates: Optional[Union[bytes, str]] = None,
67
+ authentication_keys: Optional[ # pylint: disable=unused-argument
68
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
69
+ ] = None,
65
70
  ) -> Iterator[
66
71
  Tuple[
67
72
  Callable[[], Optional[Message]],
@@ -0,0 +1,150 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower client interceptor."""
16
+
17
+
18
+ import base64
19
+ import collections
20
+ from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
+
22
+ import grpc
23
+ from cryptography.hazmat.primitives.asymmetric import ec
24
+
25
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
26
+ bytes_to_public_key,
27
+ compute_hmac,
28
+ generate_shared_key,
29
+ public_key_to_bytes,
30
+ )
31
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
+ CreateNodeRequest,
33
+ DeleteNodeRequest,
34
+ GetRunRequest,
35
+ PullTaskInsRequest,
36
+ PushTaskResRequest,
37
+ )
38
+
39
# gRPC metadata keys used by AuthenticateClientInterceptor: the client's
# public key goes out on every call (and the server's key is read back under
# the same name); the HMAC token goes out on authenticated calls.
_PUBLIC_KEY_HEADER = "public-key"
_AUTH_TOKEN_HEADER = "auth-token"

# Request messages that AuthenticateClientInterceptor distinguishes.
Request = Union[
    CreateNodeRequest,
    DeleteNodeRequest,
    PullTaskInsRequest,
    PushTaskResRequest,
    GetRunRequest,
]
49
+
50
+
51
def _get_value_from_tuples(
    key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
) -> bytes:
    """Return the value of the first pair whose key is ``key_string``.

    ``str`` values are UTF-8 encoded; ``b""`` is returned when no pair
    matches.
    """
    for key, value in tuples:
        if key == key_string:
            return value.encode() if isinstance(value, str) else value
    return b""
59
+
60
+
61
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails", "method timeout metadata credentials"
    ),
    grpc.ClientCallDetails,  # type: ignore
):
    """Immutable call details handed to an interceptor's ``continuation``.

    Holds the method, timeout, metadata and credentials of an outgoing call.
    An interceptor builds one of these to re-issue a call with modified
    metadata; ``AuthenticateClientInterceptor`` uses it to attach its
    authentication headers.
    """
72
+
73
+
74
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor):  # type: ignore
    """Client interceptor for client authentication.

    Every unary-unary call carries the client's URL-safe base64 public key in
    the ``public-key`` metadata entry. When a ``CreateNodeRequest`` returns,
    the server's public key is read from the response's initial metadata and
    combined with the client's private key into a shared secret. From then
    on, ``DeleteNodeRequest``, ``PullTaskInsRequest``, ``PushTaskResRequest``
    and ``GetRunRequest`` calls also carry an ``auth-token`` entry. That
    entry is the URL-safe base64 HMAC of the deterministically serialized
    request, keyed with the shared secret.

    Parameters
    ----------
    private_key : ec.EllipticCurvePrivateKey
        The client's private key, used to derive the shared secret.
    public_key : ec.EllipticCurvePublicKey
        The client's public key, sent with every call.
    """

    def __init__(
        self,
        private_key: ec.EllipticCurvePrivateKey,
        public_key: ec.EllipticCurvePublicKey,
    ):
        self.private_key = private_key
        self.public_key = public_key
        # Populated once a CreateNode response carrying the server key arrives.
        self.shared_secret: Optional[bytes] = None
        self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
        # Encoded once here because it is attached to every call.
        self.encoded_public_key = base64.urlsafe_b64encode(
            public_key_to_bytes(self.public_key)
        )

    def intercept_unary_unary(
        self,
        continuation: Callable[[Any, Any], Any],
        client_call_details: grpc.ClientCallDetails,
        request: Request,
    ) -> grpc.Call:
        """Attach authentication metadata to a unary call and forward it.

        Raises
        ------
        RuntimeError
            If a request that needs an HMAC is sent before a shared secret
            has been derived from a ``CreateNodeRequest`` response.
        """
        metadata = []
        if client_call_details.metadata is not None:
            metadata = list(client_call_details.metadata)

        # Always add the public key header
        metadata.append((_PUBLIC_KEY_HEADER, self.encoded_public_key))

        postprocess = False
        if isinstance(request, CreateNodeRequest):
            postprocess = True
        elif isinstance(
            request,
            (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest),
        ):
            if self.shared_secret is None:
                raise RuntimeError("Failure to compute hmac")

            # Pass ``deterministic`` by keyword: the pure-Python protobuf
            # implementation only accepts keyword arguments here.
            serialized_request = request.SerializeToString(deterministic=True)
            metadata.append(
                (
                    _AUTH_TOKEN_HEADER,
                    base64.urlsafe_b64encode(
                        compute_hmac(self.shared_secret, serialized_request)
                    ),
                )
            )

        client_call_details = _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
        )

        response = continuation(client_call_details, request)
        if postprocess:
            # initial_metadata() may be None or lack the server key, e.g. when
            # the RPC failed. Only derive the shared secret when the key is
            # present, so that a failed CreateNode call reaches the caller as
            # its original gRPC error instead of a TypeError or key-parsing
            # error raised here.
            initial_metadata = response.initial_metadata() or ()
            server_public_key_bytes = base64.urlsafe_b64decode(
                _get_value_from_tuples(_PUBLIC_KEY_HEADER, initial_metadata)
            )
            if server_public_key_bytes:
                self.server_public_key = bytes_to_public_key(
                    server_public_key_bytes
                )
                self.shared_secret = generate_shared_key(
                    self.private_key, self.server_public_key
                )
        return response
@@ -21,7 +21,10 @@ from contextlib import contextmanager
21
21
  from copy import copy
22
22
  from logging import DEBUG, ERROR
23
23
  from pathlib import Path
24
- from typing import Callable, Iterator, Optional, Tuple, Union, cast
24
+ from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
25
+
26
+ import grpc
27
+ from cryptography.hazmat.primitives.asymmetric import ec
25
28
 
26
29
  from flwr.client.heartbeat import start_ping_loop
27
30
  from flwr.client.message_handler.message_handler import validate_out_message
@@ -52,6 +55,8 @@ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
52
55
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
53
56
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
54
57
 
58
+ from .client_interceptor import AuthenticateClientInterceptor
59
+
55
60
 
56
61
  def on_channel_state_change(channel_connectivity: str) -> None:
57
62
  """Log channel connectivity."""
@@ -59,12 +64,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
59
64
 
60
65
 
61
66
  @contextmanager
62
- def grpc_request_response( # pylint: disable=R0914, R0915
67
+ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
63
68
  server_address: str,
64
69
  insecure: bool,
65
70
  retry_invoker: RetryInvoker,
66
71
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
67
72
  root_certificates: Optional[Union[bytes, str]] = None,
73
+ authentication_keys: Optional[
74
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
75
+ ] = None,
68
76
  ) -> Iterator[
69
77
  Tuple[
70
78
  Callable[[], Optional[Message]],
@@ -109,11 +117,18 @@ def grpc_request_response( # pylint: disable=R0914, R0915
109
117
  if isinstance(root_certificates, str):
110
118
  root_certificates = Path(root_certificates).read_bytes()
111
119
 
120
+ interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
121
+ if authentication_keys is not None:
122
+ interceptors = AuthenticateClientInterceptor(
123
+ authentication_keys[0], authentication_keys[1]
124
+ )
125
+
112
126
  channel = create_channel(
113
127
  server_address=server_address,
114
128
  insecure=insecure,
115
129
  root_certificates=root_certificates,
116
130
  max_message_length=max_message_length,
131
+ interceptors=interceptors,
117
132
  )
118
133
  channel.subscribe(on_channel_state_change)
119
134
 
@@ -23,6 +23,7 @@ from copy import copy
23
23
  from logging import ERROR, INFO, WARN
24
24
  from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union
25
25
 
26
+ from cryptography.hazmat.primitives.asymmetric import ec
26
27
  from google.protobuf.message import Message as GrpcMessage
27
28
 
28
29
  from flwr.client.heartbeat import start_ping_loop
@@ -74,7 +75,7 @@ T = TypeVar("T", bound=GrpcMessage)
74
75
 
75
76
 
76
77
  @contextmanager
77
- def http_request_response( # pylint: disable=R0914, R0915
78
+ def http_request_response( # pylint: disable=,R0913, R0914, R0915
78
79
  server_address: str,
79
80
  insecure: bool, # pylint: disable=unused-argument
80
81
  retry_invoker: RetryInvoker,
@@ -82,6 +83,9 @@ def http_request_response( # pylint: disable=R0914, R0915
82
83
  root_certificates: Optional[
83
84
  Union[bytes, str]
84
85
  ] = None, # pylint: disable=unused-argument
86
+ authentication_keys: Optional[ # pylint: disable=unused-argument
87
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
88
+ ] = None,
85
89
  ) -> Iterator[
86
90
  Tuple[
87
91
  Callable[[], Optional[Message]],
flwr/common/grpc.py CHANGED
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  from logging import DEBUG
19
- from typing import Optional
19
+ from typing import Optional, Sequence
20
20
 
21
21
  import grpc
22
22
 
@@ -30,6 +30,7 @@ def create_channel(
30
30
  insecure: bool,
31
31
  root_certificates: Optional[bytes] = None,
32
32
  max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
33
+ interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None,
33
34
  ) -> grpc.Channel:
34
35
  """Create a gRPC channel, either secure or insecure."""
35
36
  # Check for conflicting parameters
@@ -57,4 +58,7 @@ def create_channel(
57
58
  )
58
59
  log(DEBUG, "Opened secure gRPC connection using certificates")
59
60
 
61
+ if interceptors is not None:
62
+ channel = grpc.intercept_channel(channel, interceptors)
63
+
60
64
  return channel
@@ -18,8 +18,9 @@
18
18
  import base64
19
19
  from typing import Tuple, cast
20
20
 
21
+ from cryptography.exceptions import InvalidSignature
21
22
  from cryptography.fernet import Fernet
22
- from cryptography.hazmat.primitives import hashes, serialization
23
+ from cryptography.hazmat.primitives import hashes, hmac, serialization
23
24
  from cryptography.hazmat.primitives.asymmetric import ec
24
25
  from cryptography.hazmat.primitives.kdf.hkdf import HKDF
25
26
 
@@ -98,3 +99,21 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
98
99
  # The input key must be url safe
99
100
  fernet = Fernet(key)
100
101
  return fernet.decrypt(ciphertext)
102
+
103
+
104
def compute_hmac(key: bytes, message: bytes) -> bytes:
    """Return the HMAC-SHA256 tag of ``message``.

    ``key`` is the secret HMAC key.
    """
    mac = hmac.HMAC(key, hashes.SHA256())
    mac.update(message)
    tag = mac.finalize()
    return tag
109
+
110
+
111
def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
    """Check ``hmac_value`` against the HMAC-SHA256 of ``message``.

    ``key`` is the secret HMAC key. Returns True when the tag matches and
    False when ``HMAC.verify`` raises ``InvalidSignature``.
    """
    verifier = hmac.HMAC(key, hashes.SHA256())
    verifier.update(message)
    try:
        verifier.verify(hmac_value)
    except InvalidSignature:
        return False
    return True
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
64
64
  """Create run ID."""
65
65
  log(INFO, "DriverServicer.CreateRun")
66
66
  state: State = self.state_factory.state()
67
- run_id = state.create_run("None/None", "None")
67
+ run_id = state.create_run(request.fab_id, request.fab_version)
68
68
  return CreateRunResponse(run_id=run_id)
69
69
 
70
70
  def PushTaskIns(
@@ -30,7 +30,7 @@ from flwr.server.utils import validate_task_ins_or_res
30
30
  from .utils import make_node_unavailable_taskres
31
31
 
32
32
 
33
- class InMemoryState(State):
33
+ class InMemoryState(State): # pylint: disable=R0902
34
34
  """In-memory State implementation."""
35
35
 
36
36
  def __init__(self) -> None:
@@ -40,6 +40,9 @@ class InMemoryState(State):
40
40
  self.run_ids: Dict[int, Tuple[str, str]] = {}
41
41
  self.task_ins_store: Dict[UUID, TaskIns] = {}
42
42
  self.task_res_store: Dict[UUID, TaskRes] = {}
43
+ self.client_public_keys: Set[bytes] = set()
44
+ self.server_public_key: Optional[bytes] = None
45
+ self.server_private_key: Optional[bytes] = None
43
46
  self.lock = threading.Lock()
44
47
 
45
48
  def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
@@ -251,6 +254,39 @@ class InMemoryState(State):
251
254
  log(ERROR, "Unexpected run creation failure.")
252
255
  return 0
253
256
 
257
def store_server_public_private_key(
    self, public_key: bytes, private_key: bytes
) -> None:
    """Record the server key pair; it can only be set once.

    Raises
    ------
    RuntimeError
        If a public or private key has already been stored.
    """
    with self.lock:
        already_set = (
            self.server_private_key is not None
            or self.server_public_key is not None
        )
        if already_set:
            raise RuntimeError("Server public and private key already set")
        self.server_private_key = private_key
        self.server_public_key = public_key
267
+
268
def get_server_private_key(self) -> Optional[bytes]:
    """Return the stored server private key (urlsafe bytes), or None if unset."""
    private_key: Optional[bytes] = self.server_private_key
    return private_key

def get_server_public_key(self) -> Optional[bytes]:
    """Return the stored server public key (urlsafe bytes), or None if unset."""
    public_key: Optional[bytes] = self.server_public_key
    return public_key
275
+
276
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
    """Replace the stored client public keys with ``public_keys``.

    A copy is stored, so later changes to the caller's set do not leak into
    the state, and ``store_client_public_key`` never mutates the caller's set.
    """
    with self.lock:
        self.client_public_keys = set(public_keys)

def store_client_public_key(self, public_key: bytes) -> None:
    """Add a single client public key to the stored set."""
    with self.lock:
        self.client_public_keys.add(public_key)

def get_client_public_keys(self) -> Set[bytes]:
    """Return a copy of all currently stored client public keys.

    Returning a copy keeps callers from mutating the state outside the lock.
    """
    return set(self.client_public_keys)
289
+
254
290
  def get_run(self, run_id: int) -> Tuple[int, str, str]:
255
291
  """Retrieve information about the run with the specified `run_id`."""
256
292
  with self.lock:
@@ -20,7 +20,7 @@ import re
20
20
  import sqlite3
21
21
  import time
22
22
  from logging import DEBUG, ERROR
23
- from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
23
+ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
24
24
  from uuid import UUID, uuid4
25
25
 
26
26
  from flwr.common import log, now
@@ -40,6 +40,19 @@ CREATE TABLE IF NOT EXISTS node(
40
40
  );
41
41
  """
42
42
 
43
# Table holding the server's public/private key pair.
SQL_CREATE_TABLE_CREDENTIAL = """
CREATE TABLE IF NOT EXISTS credential(
    public_key BLOB PRIMARY KEY,
    private_key BLOB
);
"""

# Table of known client public keys; UNIQUE rejects duplicate inserts.
SQL_CREATE_TABLE_PUBLIC_KEY = """
CREATE TABLE IF NOT EXISTS public_key(
    public_key BLOB UNIQUE
);
"""
55
+
43
56
  SQL_CREATE_INDEX_ONLINE_UNTIL = """
44
57
  CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
45
58
  """
@@ -72,7 +85,6 @@ CREATE TABLE IF NOT EXISTS task_ins(
72
85
  );
73
86
  """
74
87
 
75
-
76
88
  SQL_CREATE_TABLE_TASK_RES = """
77
89
  CREATE TABLE IF NOT EXISTS task_res(
78
90
  task_id TEXT UNIQUE,
@@ -96,7 +108,7 @@ CREATE TABLE IF NOT EXISTS task_res(
96
108
  DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
97
109
 
98
110
 
99
- class SqliteState(State):
111
+ class SqliteState(State): # pylint: disable=R0904
100
112
  """SQLite-based state implementation."""
101
113
 
102
114
  def __init__(
@@ -134,6 +146,8 @@ class SqliteState(State):
134
146
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
135
147
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
136
148
  cur.execute(SQL_CREATE_TABLE_NODE)
149
+ cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
150
+ cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
137
151
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
138
152
  res = cur.execute("SELECT name FROM sqlite_schema;")
139
153
 
@@ -142,7 +156,7 @@ class SqliteState(State):
142
156
  def query(
143
157
  self,
144
158
  query: str,
145
- data: Optional[Union[List[DictOrTuple], DictOrTuple]] = None,
159
+ data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
146
160
  ) -> List[Dict[str, Any]]:
147
161
  """Execute a SQL query."""
148
162
  if self.conn is None:
@@ -575,6 +589,59 @@ class SqliteState(State):
575
589
  log(ERROR, "Unexpected run creation failure.")
576
590
  return 0
577
591
 
592
def store_server_public_private_key(
    self, public_key: bytes, private_key: bytes
) -> None:
    """Insert the server key pair into the ``credential`` table.

    Raises
    ------
    RuntimeError
        If the ``credential`` table already holds a row.
    """
    rows = self.query("SELECT COUNT(*) FROM credential")
    if rows[0]["COUNT(*)"] >= 1:
        raise RuntimeError("Server public and private key already set")
    self.query(
        "INSERT OR REPLACE INTO credential (public_key, private_key) "
        "VALUES (:public_key, :private_key)",
        {"public_key": public_key, "private_key": private_key},
    )
606
+
607
def get_server_private_key(self) -> Optional[bytes]:
    """Return the first stored server private key (urlsafe bytes), or None."""
    rows = self.query("SELECT private_key FROM credential")
    private_key: Optional[bytes] = rows[0]["private_key"] if rows else None
    return private_key

def get_server_public_key(self) -> Optional[bytes]:
    """Return the first stored server public key (urlsafe bytes), or None."""
    rows = self.query("SELECT public_key FROM credential")
    public_key: Optional[bytes] = rows[0]["public_key"] if rows else None
    return public_key
626
+
627
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
    """Insert every key in ``public_keys`` into the ``public_key`` table.

    An empty set is a no-op. Inserting a key that is already stored violates
    the table's UNIQUE constraint; how that error surfaces depends on
    ``query``.
    """
    if not public_keys:
        # Nothing to insert. Don't hand ``query`` an empty parameter batch
        # for a "?" statement; depending on how ``query`` dispatches it, that
        # could be a binding error rather than a no-op.
        return
    query = "INSERT INTO public_key (public_key) VALUES (?)"
    data = [(key,) for key in public_keys]
    self.query(query, data)

def store_client_public_key(self, public_key: bytes) -> None:
    """Insert a single client public key into the ``public_key`` table."""
    query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
    self.query(query, {"public_key": public_key})

def get_client_public_keys(self) -> Set[bytes]:
    """Return all stored client public keys as a new set."""
    query = "SELECT public_key FROM public_key"
    rows = self.query(query)
    result: Set[bytes] = {row["public_key"] for row in rows}
    return result
644
+
578
645
  def get_run(self, run_id: int) -> Tuple[int, str, str]:
579
646
  """Retrieve information about the run with the specified `run_id`."""
580
647
  query = "SELECT * FROM run WHERE run_id = ?;"
@@ -171,6 +171,32 @@ class State(abc.ABC):
171
171
  - `fab_version`: The version of the FAB used in the specified run.
172
172
  """
173
173
 
174
@abc.abstractmethod
def store_server_public_private_key(
    self, public_key: bytes, private_key: bytes
) -> None:
    """Persist the server's public and private key.

    The in-memory and SQLite implementations accept a key pair only once and
    raise ``RuntimeError`` when one is already stored.
    """

@abc.abstractmethod
def get_server_private_key(self) -> Optional[bytes]:
    """Return the stored server private key in urlsafe bytes, or None."""

@abc.abstractmethod
def get_server_public_key(self) -> Optional[bytes]:
    """Return the stored server public key in urlsafe bytes, or None."""

@abc.abstractmethod
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
    """Persist a whole set of client public keys."""

@abc.abstractmethod
def store_client_public_key(self, public_key: bytes) -> None:
    """Persist a single client public key."""

@abc.abstractmethod
def get_client_public_keys(self) -> Set[bytes]:
    """Return every stored client public key as a set."""
+
174
200
  @abc.abstractmethod
175
201
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
176
202
  """Acknowledge a ping received from a node, serving as a heartbeat.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.9.0.dev20240422
3
+ Version: 1.9.0.dev20240424
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -4,7 +4,7 @@ flwr/cli/app.py,sha256=38thPnMydBmNAxNE9mz4By-KdRUhJfoUgeDuAxMYF_U,1095
4
4
  flwr/cli/config_utils.py,sha256=1wTPQqOU2fKeU4FP5KyG0xMa0F-qy8x1m2WvztPORb4,5597
5
5
  flwr/cli/example.py,sha256=1bGDYll3BXQY2kRqSN-oICqS5n1b9m0g0RvXTopXHl4,2215
6
6
  flwr/cli/new/__init__.py,sha256=cQzK1WH4JP2awef1t2UQ2xjl1agVEz9rwutV18SWV1k,789
7
- flwr/cli/new/new.py,sha256=OHTOpuHRqmafsoV_Hv1V1544mZz54Z0qDRRtMT3dR-M,5380
7
+ flwr/cli/new/new.py,sha256=hqcHjun3keeREegDrdLJMPHKkVBYIN4HUUeCl3hzVgI,5404
8
8
  flwr/cli/new/templates/__init__.py,sha256=4luU8RL-CK8JJCstQ_ON809W9bNTkY1l9zSaPKBkgwY,725
9
9
  flwr/cli/new/templates/app/.gitignore.tpl,sha256=XixnHdyeMB2vwkGtGnwHqoWpH-9WChdyG0GXe57duhc,3078
10
10
  flwr/cli/new/templates/app/README.md.tpl,sha256=_qGtgpKYKoCJVjQnvlBMKvFs_1gzTcL908I3KJg0oAM,668
@@ -13,13 +13,16 @@ flwr/cli/new/templates/app/code/__init__.py,sha256=EM6vfvgAILKPaPn7H1wMV1Wi01WyZ
13
13
  flwr/cli/new/templates/app/code/__init__.py.tpl,sha256=olwrBeJemHNBWvjc6gJURloFRqW40dAy7FRQA5pDqHU,21
14
14
  flwr/cli/new/templates/app/code/client.numpy.py.tpl,sha256=mTh7Y_jOJrPUvDYHVJy4wJCnjXZV_q-jlDkB07U5GSk,521
15
15
  flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=671daPcdZaC4Z5k-dqmCovfb2_FShGmqfjwaR8y6EC8,1173
16
+ flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=S71SZiHaRXtKqUk3m5Elc_c6HhKAIKLalrKOQ3p20No,2801
16
17
  flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=N9SbnI65r2K9FHV_wn4JSpmVeyYpD0qEMehbHcGm4t0,1911
17
18
  flwr/cli/new/templates/app/code/server.numpy.py.tpl,sha256=fRxrDXV7pB1aDhQUXMBmrCsC1zp0uKwsBxZBx1JzbHA,248
18
19
  flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=xtKvUivNMzgOcLSOtnjWouJzIFbXdUQVYMm27uwyJpI,594
20
+ flwr/cli/new/templates/app/code/server.sklearn.py.tpl,sha256=cLzOpQzGIUzEazuFsjBpXAQUNPy6in6zR33SCqhix6o,341
19
21
  flwr/cli/new/templates/app/code/server.tensorflow.py.tpl,sha256=GUGH8c_6cxgUB9obVJPaA4thxI7OVXsItyfQDsn9E5k,371
20
22
  flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=NvajdZN-eTyfdqKK0v2MrvWITXw9BjJ3Ri5c1haPJDs,3684
21
23
  flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=0oTH0lY7q-PpRV4HA5woxJ1eWIgZRFcFsHa7-1lULIQ,489
22
24
  flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=GYbMAFD90JBRvy8fJbLU7nDITD3sxHv1TncQrg6mjEE,558
25
+ flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=7p6s2jJpC8ZO-TfiJ0cE3fzkIhc4ndj9SY1hiYvSM5Q,538
23
26
  flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=7I8BYtE28cnc7ZiOlOp6_zeLsjLRlwa0Y4sjoP7r9VU,537
24
27
  flwr/cli/run/__init__.py,sha256=oCd6HmQDx-sqver1gecgx-uMA38BLTSiiKpl7RGNceg,789
25
28
  flwr/cli/run/run.py,sha256=qxXgShEXHONx-Gjpl515HF60QzRA-Ygpj2sbl0bZUAA,2331
@@ -30,9 +33,10 @@ flwr/client/client.py,sha256=Vp9UkOkoHdNfn6iMYZsj_5m_GICiFfUlKEVaLad-YhM,8183
30
33
  flwr/client/client_app.py,sha256=-Cs0084tLQUoBCeYZdG2KgU7cjp95_ZJ4MfjoaN4Fzk,8636
31
34
  flwr/client/dpfedavg_numpy_client.py,sha256=9Tnig4iml2J88HBKNahegjXjbfvIQyBtaIQaqjbeqsA,7435
32
35
  flwr/client/grpc_client/__init__.py,sha256=LsnbqXiJhgQcB0XzAlUQgPx011Uf7Y7yabIC1HxivJ8,735
33
- flwr/client/grpc_client/connection.py,sha256=7MfyR6hEq3u46wK3s0vP3eubFq19pKZJCG3EFw_i4T4,8775
36
+ flwr/client/grpc_client/connection.py,sha256=KWbBwuvn1-2wjrAKteydGCZC_7A2zmEjk3DycQWafrA,8993
34
37
  flwr/client/grpc_rere_client/__init__.py,sha256=avn6W_vHEM_yZEB1S7hCZgnTbXb6ZujqRP_vAzyXu-0,752
35
- flwr/client/grpc_rere_client/connection.py,sha256=IEGkM0MymZ1tyL6yAL4ic5ZpGy_zg9bJBVf5KCSL2iY,9052
38
+ flwr/client/grpc_rere_client/client_interceptor.py,sha256=cZZHd_lVlVuyPrhXf3mB4_Zpmhpmrv6-18E9XJisImE,4761
39
+ flwr/client/grpc_rere_client/connection.py,sha256=gSSJJ9pSe5SgUb1Ey-xcrVK6xArUkwq0yGdav0h2kww,9597
36
40
  flwr/client/heartbeat.py,sha256=cx37mJBH8LyoIN4Lks85wtqT1mnU5GulQnr4pGCvAq0,2404
37
41
  flwr/client/message_handler/__init__.py,sha256=abHvBRJJiiaAMNgeILQbMOa6h8WqMK2BcnvxwQZFpic,719
38
42
  flwr/client/message_handler/message_handler.py,sha256=ml_FlduAJ5pxO31n1tKRrWfQRSxkMgKLbwXXcRsNSos,6553
@@ -49,7 +53,7 @@ flwr/client/node_state.py,sha256=KTTs_l4I0jBM7IsSsbAGjhfL_yZC3QANbzyvyfZBRDM,177
49
53
  flwr/client/node_state_tests.py,sha256=gPwz0zf2iuDSa11jedkur_u3Xm7lokIDG5ALD2MCvSw,2195
50
54
  flwr/client/numpy_client.py,sha256=u76GWAdHmJM88Agm2EgLQSvO8Jnk225mJTk-_TmPjFE,10283
51
55
  flwr/client/rest_client/__init__.py,sha256=ThwOnkMdzxo_UuyTI47Q7y9oSpuTgNT2OuFvJCfuDiw,735
52
- flwr/client/rest_client/connection.py,sha256=ZxTFVDXlONqKTX6uYgxshoEWqzqVcQ8QQ2hKS93oLM8,11302
56
+ flwr/client/rest_client/connection.py,sha256=MspqM5RjrQe09_2BUEEVGstA5x9Qz_RWdXXraOic3i8,11520
53
57
  flwr/client/supernode/__init__.py,sha256=D5swXxemuRbA2rB_T9B8LwJW-_PucXwmlFQQerwIUv0,793
54
58
  flwr/client/supernode/app.py,sha256=gauvN8elkIy0vuT0GxT7MmkuBRY74ckZfpxejE7dduM,3861
55
59
  flwr/client/typing.py,sha256=c9EvjlEjasxn1Wqx6bGl6Xg6vM1gMFfmXht-E2i5J-k,1006
@@ -62,7 +66,7 @@ flwr/common/differential_privacy.py,sha256=WZWrL7C9XaB9l9NDkLDI5PvM7jwcoTTFu08ZV
62
66
  flwr/common/differential_privacy_constants.py,sha256=c7b7tqgvT7yMK0XN9ndiTBs4mQf6d3qk6K7KBZGlV4Q,1074
63
67
  flwr/common/dp.py,sha256=Hc3lLHihjexbJaD_ft31gdv9XRcwOTgDBwJzICuok3A,2004
64
68
  flwr/common/exit_handlers.py,sha256=2Nt0wLhc17KQQsLPFSRAjjhUiEFfJK6tNozdGiIY4Fs,2812
65
- flwr/common/grpc.py,sha256=HimjpTtIY3Vfqtlq3u-CYWjqAl9rSn0uo3A8JjhUmwQ,2273
69
+ flwr/common/grpc.py,sha256=Yx_YFK24cU4U81RpXrdVwEVY_jTy4RE19cHtBxE2XOE,2460
66
70
  flwr/common/logger.py,sha256=3hfKun9YISWj4i_QhxgZdnaHJc4x-QvFJQJTKHZ2KHs,6096
67
71
  flwr/common/message.py,sha256=NvxiWT9YI8GmIt2r3EPVPFFAFQo3xhP09mvnAxjHivQ,12385
68
72
  flwr/common/object_ref.py,sha256=ELoUCAFO-vbjJC41CGpa-WBG2SLYe3ErW-d9YCG3zqA,4961
@@ -80,7 +84,7 @@ flwr/common/retry_invoker.py,sha256=dQY5fPIKhy9OiFswZhLxA9fB455u-DYCvDVcFJmrPDk,
80
84
  flwr/common/secure_aggregation/__init__.py,sha256=29nHIUO2L8-KhNHQ2KmIgRo_4CPkq4LgLCUN0on5FgI,731
81
85
  flwr/common/secure_aggregation/crypto/__init__.py,sha256=dz7pVx2aPrHxr_AwgO5mIiTzu4PcvUxRq9NLBbFcsf8,738
82
86
  flwr/common/secure_aggregation/crypto/shamir.py,sha256=yY35ZgHlB4YyGW_buG-1X-0M-ejXuQzISgYLgC_Z9TY,2792
83
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py,sha256=-zDyQoTsHHQjR7o-92FNIikg1zM_Ke9yynaD5u2BXbQ,3546
87
+ flwr/common/secure_aggregation/crypto/symmetric_encryption.py,sha256=BP2qzZLz1FAdJt-E_7D--4yfyXsmfNqtvQG1XJfS4J0,4173
84
88
  flwr/common/secure_aggregation/ndarrays_arithmetic.py,sha256=66mNQCz64r7qzvXwFrXP6zz7YMi8EkTOABN7KulkKc4,3026
85
89
  flwr/common/secure_aggregation/quantization.py,sha256=appui7GGrkRPsupF59TkapeV4Na_CyPi73JtJ1pimdI,2310
86
90
  flwr/common/secure_aggregation/secaggplus_constants.py,sha256=Fh7-n6pgL4TUnHpNYXo8iW-n5cOGQgQa-c7RcU80tqQ,2183
@@ -164,7 +168,7 @@ flwr/server/strategy/strategy.py,sha256=g6VoIFogEviRub6G4QsKdIp6M_Ek6GhBhqcdNx5u
164
168
  flwr/server/superlink/__init__.py,sha256=8tHYCfodUlRD8PCP9fHgvu8cz5N31A2QoRVL0jDJ15E,707
165
169
  flwr/server/superlink/driver/__init__.py,sha256=STB1_DASVEg7Cu6L7VYxTzV7UMkgtBkFim09Z82Dh8I,712
166
170
  flwr/server/superlink/driver/driver_grpc.py,sha256=1qSGDs1k_OVPWxp2ofxvQgtYXExrMeC3N_rNPVWH65M,1932
167
- flwr/server/superlink/driver/driver_servicer.py,sha256=IKx3rC8s2193iCJxLEc_njndTtidkVM7Vk-RWjGngl0,4780
171
+ flwr/server/superlink/driver/driver_servicer.py,sha256=y0w8p3D9RQlMdgizfknHZnCEKf0O0IpLsKhHPxmp2pQ,4796
168
172
  flwr/server/superlink/fleet/__init__.py,sha256=C6GCSD5eP5Of6_dIeSe1jx9HnV0icsvWyQ5EKAUHJRU,711
169
173
  flwr/server/superlink/fleet/grpc_bidi/__init__.py,sha256=mgGJGjwT6VU7ovC1gdnnqttjyBPlNIcZnYRqx4K3IBQ,735
170
174
  flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py,sha256=57b3UL5-baGdLwgCtB0dCUTTSbmmfMAXcXV5bjPZNWQ,5993
@@ -183,9 +187,9 @@ flwr/server/superlink/fleet/vce/backend/backend.py,sha256=LJsKl7oixVvptcG98Rd9ej
183
187
  flwr/server/superlink/fleet/vce/backend/raybackend.py,sha256=TaT2EpbVEsIY0EDzF8obadyZaSXjD38TFGdDPI-ytD0,6375
184
188
  flwr/server/superlink/fleet/vce/vce_api.py,sha256=c2J2m6v1jDyuAhiBArdZNIk4cbiZNFJkpKlBJFEQq-c,12454
185
189
  flwr/server/superlink/state/__init__.py,sha256=ij-7Ms-hyordQdRmGQxY1-nVa4OhixJ0jr7_YDkys0s,1003
186
- flwr/server/superlink/state/in_memory_state.py,sha256=OXpTb7ER7fnI55cFmcux2cLN6U_ACYjmRHkhYVHW2Ww,10083
187
- flwr/server/superlink/state/sqlite_state.py,sha256=xDyvtuInAsLq65czbqLrLOv4ec61XxH_FhW_Q2NXrgM,24580
188
- flwr/server/superlink/state/state.py,sha256=AsORTtR5Y5sRpxKPG0iueWOvnY0uISXgpAsyPSMgZXY,6762
190
+ flwr/server/superlink/state/in_memory_state.py,sha256=d6T6NXGyvo53LnFJSRRsiCnFOXikYMmCcCUmiarOuD0,11651
191
+ flwr/server/superlink/state/sqlite_state.py,sha256=eU5Ll6V0yQp9fnZbo5L-i0BM41SkmjK7kHzc1EHNr_M,27167
192
+ flwr/server/superlink/state/state.py,sha256=bdEqkjfBbVtbT_YudaMFFLSR-q_R1q-tsvjdni3YJKg,7709
189
193
  flwr/server/superlink/state/state_factory.py,sha256=91cSB-KOAFM37z7T098WxTkVeKNaAZ_mTI75snn2_tk,1654
190
194
  flwr/server/superlink/state/utils.py,sha256=qhIjBu5_rqm9GLMB6QS5TIRrMDVs85lmY17BqZ1ccLk,2207
191
195
  flwr/server/typing.py,sha256=2zSG-KuDAgwFPuzgVjTLDaEqJ8gXXGqFR2RD-qIk730,913
@@ -205,8 +209,8 @@ flwr/simulation/ray_transport/ray_actor.py,sha256=_wv2eP7qxkCZ-6rMyYWnjLrGPBZRxj
205
209
  flwr/simulation/ray_transport/ray_client_proxy.py,sha256=oDu4sEPIOu39vrNi-fqDAe10xtNUXMO49bM2RWfRcyw,6738
206
210
  flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
207
211
  flwr/simulation/run_simulation.py,sha256=nxXNv3r8ODImd5o6f0sa_w5L0I08LD2Udw2OTXStRnQ,15694
208
- flwr_nightly-1.9.0.dev20240422.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
209
- flwr_nightly-1.9.0.dev20240422.dist-info/METADATA,sha256=2g_AiXLNJzV4x9RNTWo1h1LjzMpUdhUQ8uNAPPxqlv8,15260
210
- flwr_nightly-1.9.0.dev20240422.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
211
- flwr_nightly-1.9.0.dev20240422.dist-info/entry_points.txt,sha256=DBrrf685V2W9NbbchQwvuqBEpj5ik8tMZNoZg_W2bZY,363
212
- flwr_nightly-1.9.0.dev20240422.dist-info/RECORD,,
212
+ flwr_nightly-1.9.0.dev20240424.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
213
+ flwr_nightly-1.9.0.dev20240424.dist-info/METADATA,sha256=mLtdOKDk566pd7SuCbhp37Rxrd8AdueX412yMXc20ic,15260
214
+ flwr_nightly-1.9.0.dev20240424.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
215
+ flwr_nightly-1.9.0.dev20240424.dist-info/entry_points.txt,sha256=DBrrf685V2W9NbbchQwvuqBEpj5ik8tMZNoZg_W2bZY,363
216
+ flwr_nightly-1.9.0.dev20240424.dist-info/RECORD,,