flwr-nightly 1.9.0.dev20240509__py3-none-any.whl → 1.9.0.dev20240531__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (40)
  1. flwr/cli/build.py +2 -4
  2. flwr/cli/config_utils.py +1 -23
  3. flwr/cli/new/new.py +2 -0
  4. flwr/cli/new/templates/app/code/client.jax.py.tpl +55 -0
  5. flwr/cli/new/templates/app/code/server.jax.py.tpl +12 -0
  6. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  7. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +6 -0
  8. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +28 -0
  9. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +6 -0
  10. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +6 -0
  11. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +6 -0
  12. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +6 -0
  13. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +6 -0
  14. flwr/cli/run/run.py +20 -4
  15. flwr/client/grpc_rere_client/connection.py +5 -2
  16. flwr/client/mod/comms_mods.py +4 -4
  17. flwr/client/mod/localdp_mod.py +1 -2
  18. flwr/client/supernode/app.py +41 -23
  19. flwr/common/recordset_compat.py +8 -1
  20. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +0 -15
  21. flwr/proto/grpcadapter_pb2.py +32 -0
  22. flwr/proto/grpcadapter_pb2.pyi +43 -0
  23. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  24. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  25. flwr/server/app.py +134 -182
  26. flwr/server/driver/__init__.py +3 -2
  27. flwr/server/driver/inmemory_driver.py +181 -0
  28. flwr/server/server.py +9 -2
  29. flwr/server/strategy/dp_adaptive_clipping.py +2 -4
  30. flwr/server/strategy/dp_fixed_clipping.py +2 -4
  31. flwr/server/superlink/driver/driver_servicer.py +2 -2
  32. flwr/server/superlink/fleet/vce/backend/raybackend.py +8 -3
  33. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  34. flwr/server/workflow/default_workflows.py +67 -22
  35. flwr/simulation/run_simulation.py +2 -31
  36. {flwr_nightly-1.9.0.dev20240509.dist-info → flwr_nightly-1.9.0.dev20240531.dist-info}/METADATA +5 -3
  37. {flwr_nightly-1.9.0.dev20240509.dist-info → flwr_nightly-1.9.0.dev20240531.dist-info}/RECORD +40 -31
  38. {flwr_nightly-1.9.0.dev20240509.dist-info → flwr_nightly-1.9.0.dev20240531.dist-info}/LICENSE +0 -0
  39. {flwr_nightly-1.9.0.dev20240509.dist-info → flwr_nightly-1.9.0.dev20240531.dist-info}/WHEEL +0 -0
  40. {flwr_nightly-1.9.0.dev20240509.dist-info → flwr_nightly-1.9.0.dev20240531.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py CHANGED
@@ -24,7 +24,7 @@ import pathspec
 import typer
 from typing_extensions import Annotated
 
-from .config_utils import load_and_validate_with_defaults
+from .config_utils import load_and_validate
 from .utils import is_valid_project_name
 
 
@@ -67,9 +67,7 @@ def build(
         )
         raise typer.Exit(code=1)
 
-    conf, errors, warnings = load_and_validate_with_defaults(
-        directory / "pyproject.toml"
-    )
+    conf, errors, warnings = load_and_validate(directory / "pyproject.toml")
     if conf is None:
         typer.secho(
             "Project configuration could not be loaded.\npyproject.toml is invalid:\n"
flwr/cli/config_utils.py CHANGED
@@ -22,7 +22,7 @@ import tomli
 from flwr.common import object_ref
 
 
-def load_and_validate_with_defaults(
+def load_and_validate(
     path: Optional[Path] = None,
 ) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]:
     """Load and validate pyproject.toml as dict.
@@ -47,14 +47,6 @@ def load_and_validate_with_defaults(
     if not is_valid:
         return (None, errors, warnings)
 
-    # Apply defaults
-    defaults = {
-        "flower": {
-            "engine": {"name": "simulation", "simulation": {"supernode": {"num": 2}}}
-        }
-    }
-    config = apply_defaults(config, defaults)
-
     return (config, errors, warnings)
 
 
@@ -129,17 +121,3 @@ def validate(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
         return False, [reason], []
 
     return True, [], []
-
-
-def apply_defaults(
-    config: Dict[str, Any],
-    defaults: Dict[str, Any],
-) -> Dict[str, Any]:
-    """Apply defaults to config."""
-    for key in defaults:
-        if key in config:
-            if isinstance(config[key], dict) and isinstance(defaults[key], dict):
-                apply_defaults(config[key], defaults[key])
-            else:
-                config[key] = defaults[key]
-    return config
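For context, a minimal sketch (not part of the diff) of how the renamed helper is consumed, mirroring the callers updated elsewhere in this release; the project path below is an arbitrary example:

from pathlib import Path

from flwr.cli.config_utils import load_and_validate

config, errors, warnings = load_and_validate(Path("./my-app/pyproject.toml"))
if config is None:
    raise SystemExit("Invalid pyproject.toml:\n" + "\n".join(errors))

# Note: defaults such as [flower.engine] are no longer injected here; the app
# templates further down now write that table into pyproject.toml explicitly.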
flwr/cli/new/new.py CHANGED
@@ -37,6 +37,7 @@ class MlFramework(str, Enum):
     NUMPY = "NumPy"
     PYTORCH = "PyTorch"
     TENSORFLOW = "TensorFlow"
+    JAX = "JAX"
     HUGGINGFACE = "HF"
     MLX = "MLX"
     SKLEARN = "sklearn"
@@ -155,6 +156,7 @@ def new(
     # Depending on the framework, generate task.py file
     frameworks_with_tasks = [
         MlFramework.PYTORCH.value.lower(),
+        MlFramework.JAX.value.lower(),
         MlFramework.HUGGINGFACE.value.lower(),
         MlFramework.MLX.value.lower(),
         MlFramework.TENSORFLOW.value.lower(),
flwr/cli/new/templates/app/code/client.jax.py.tpl ADDED
@@ -0,0 +1,55 @@
+"""$project_name: A Flower / JAX app."""
+
+import jax
+from flwr.client import NumPyClient, ClientApp
+
+from $import_name.task import (
+    evaluation,
+    get_params,
+    load_data,
+    load_model,
+    loss_fn,
+    set_params,
+    train,
+)
+
+
+# Define Flower Client and client_fn
+class FlowerClient(NumPyClient):
+    def __init__(self):
+        self.train_x, self.train_y, self.test_x, self.test_y = load_data()
+        self.grad_fn = jax.grad(loss_fn)
+        model_shape = self.train_x.shape[1:]
+
+        self.params = load_model(model_shape)
+
+    def get_parameters(self, config):
+        return get_params(self.params)
+
+    def set_parameters(self, parameters):
+        set_params(self.params, parameters)
+
+    def fit(self, parameters, config):
+        self.set_parameters(parameters)
+        self.params, loss, num_examples = train(
+            self.params, self.grad_fn, self.train_x, self.train_y
+        )
+        parameters = self.get_parameters(config={})
+        return parameters, num_examples, {"loss": float(loss)}
+
+    def evaluate(self, parameters, config):
+        self.set_parameters(parameters)
+        loss, num_examples = evaluation(
+            self.params, self.grad_fn, self.test_x, self.test_y
+        )
+        return float(loss), num_examples, {"loss": float(loss)}
+
+def client_fn(cid):
+    # Return Client instance
+    return FlowerClient().to_client()
+
+
+# Flower ClientApp
+app = ClientApp(
+    client_fn,
+)
flwr/cli/new/templates/app/code/server.jax.py.tpl ADDED
@@ -0,0 +1,12 @@
+"""$project_name: A Flower / JAX app."""
+
+import flwr as fl
+
+# Configure the strategy
+strategy = fl.server.strategy.FedAvg()
+
+# Flower ServerApp
+app = fl.server.ServerApp(
+    config=fl.server.ServerConfig(num_rounds=3),
+    strategy=strategy,
+)
flwr/cli/new/templates/app/code/task.jax.py.tpl ADDED
@@ -0,0 +1,57 @@
+"""$project_name: A Flower / JAX app."""
+
+import jax
+import jax.numpy as jnp
+from sklearn.datasets import make_regression
+from sklearn.model_selection import train_test_split
+import numpy as np
+
+key = jax.random.PRNGKey(0)
+
+
+def load_data():
+    # Load dataset
+    X, y = make_regression(n_features=3, random_state=0)
+    X, X_test, y, y_test = train_test_split(X, y)
+    return X, y, X_test, y_test
+
+
+def load_model(model_shape):
+    # Extract model parameters
+    params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)}
+    return params
+
+
+def loss_fn(params, X, y):
+    # Return MSE as loss
+    err = jnp.dot(X, params["w"]) + params["b"] - y
+    return jnp.mean(jnp.square(err))
+
+
+def train(params, grad_fn, X, y):
+    loss = 1_000_000
+    num_examples = X.shape[0]
+    for epochs in range(50):
+        grads = grad_fn(params, X, y)
+        params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
+        loss = loss_fn(params, X, y)
+    return params, loss, num_examples
+
+
+def evaluation(params, grad_fn, X_test, y_test):
+    num_examples = X_test.shape[0]
+    err_test = loss_fn(params, X_test, y_test)
+    loss_test = jnp.mean(jnp.square(err_test))
+    return loss_test, num_examples
+
+
+def get_params(params):
+    parameters = []
+    for _, val in params.items():
+        parameters.append(np.array(val))
+    return parameters
+
+
+def set_params(local_params, global_params):
+    for key, value in list(zip(local_params.keys(), global_params)):
+        local_params[key] = value
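As a quick sanity check of how the helpers in the generated task module compose, here is an illustrative sketch only; `myapp.task` stands in for whatever package name `flwr new` produces and is not part of the diff:

import jax

from myapp.task import evaluation, load_data, load_model, loss_fn, train

train_x, train_y, test_x, test_y = load_data()
params = load_model(train_x.shape[1:])
grad_fn = jax.grad(loss_fn)

# 50 steps of plain gradient descent on the synthetic regression problem
params, train_loss, _ = train(params, grad_fn, train_x, train_y)
test_loss, num_test = evaluation(params, grad_fn, test_x, test_y)
print(f"train MSE: {float(train_loss):.4f}, test MSE: {float(test_loss):.4f} ({num_test} examples)")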
flwr/cli/new/templates/app/pyproject.hf.toml.tpl CHANGED
@@ -29,3 +29,9 @@ publisher = "$username"
 [flower.components]
 serverapp = "$import_name.server:app"
 clientapp = "$import_name.client:app"
+
+[flower.engine]
+name = "simulation"
+
+[flower.engine.simulation.supernode]
+num = 2
flwr/cli/new/templates/app/pyproject.jax.toml.tpl ADDED
@@ -0,0 +1,28 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "$package_name"
+version = "1.0.0"
+description = ""
+authors = [
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = {text = "Apache License (2.0)"}
+dependencies = [
+    "flwr[simulation]>=1.8.0,<2.0",
+    "jax==0.4.26",
+    "jaxlib==0.4.26",
+    "scikit-learn==1.4.2",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[flower]
+publisher = "$username"
+
+[flower.components]
+serverapp = "$import_name.server:app"
+clientapp = "$import_name.client:app"
flwr/cli/new/templates/app/pyproject.mlx.toml.tpl CHANGED
@@ -26,3 +26,9 @@ publisher = "$username"
 [flower.components]
 serverapp = "$import_name.server:app"
 clientapp = "$import_name.client:app"
+
+[flower.engine]
+name = "simulation"
+
+[flower.engine.simulation.supernode]
+num = 2
flwr/cli/new/templates/app/pyproject.numpy.toml.tpl CHANGED
@@ -24,3 +24,9 @@ publisher = "$username"
 [flower.components]
 serverapp = "$import_name.server:app"
 clientapp = "$import_name.client:app"
+
+[flower.engine]
+name = "simulation"
+
+[flower.engine.simulation.supernode]
+num = 2
flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl CHANGED
@@ -26,3 +26,9 @@ publisher = "$username"
 [flower.components]
 serverapp = "$import_name.server:app"
 clientapp = "$import_name.client:app"
+
+[flower.engine]
+name = "simulation"
+
+[flower.engine.simulation.supernode]
+num = 2
flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl CHANGED
@@ -25,3 +25,9 @@ publisher = "$username"
 [flower.components]
 serverapp = "$import_name.server:app"
 clientapp = "$import_name.client:app"
+
+[flower.engine]
+name = "simulation"
+
+[flower.engine.simulation.supernode]
+num = 2
flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl CHANGED
@@ -25,3 +25,9 @@ publisher = "$username"
 [flower.components]
 serverapp = "$import_name.server:app"
 clientapp = "$import_name.client:app"
+
+[flower.engine]
+name = "simulation"
+
+[flower.engine.simulation.supernode]
+num = 2
flwr/cli/run/run.py CHANGED
@@ -15,18 +15,32 @@
 """Flower command line interface `run` command."""
 
 import sys
+from enum import Enum
+from typing import Optional
 
 import typer
+from typing_extensions import Annotated
 
 from flwr.cli import config_utils
 from flwr.simulation.run_simulation import _run_simulation
 
 
-def run() -> None:
+class Engine(str, Enum):
+    """Enum defining the engine to run on."""
+
+    SIMULATION = "simulation"
+
+
+def run(
+    engine: Annotated[
+        Optional[Engine],
+        typer.Option(case_sensitive=False, help="The ML framework to use"),
+    ] = None,
+) -> None:
     """Run Flower project."""
     typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
 
-    config, errors, warnings = config_utils.load_and_validate_with_defaults()
+    config, errors, warnings = config_utils.load_and_validate()
 
     if config is None:
         typer.secho(
@@ -49,9 +63,11 @@ def run() -> None:
 
     server_app_ref = config["flower"]["components"]["serverapp"]
     client_app_ref = config["flower"]["components"]["clientapp"]
-    engine = config["flower"]["engine"]["name"]
 
-    if engine == "simulation":
+    if engine is None:
+        engine = config["flower"]["engine"]["name"]
+
+    if engine == Engine.SIMULATION:
        num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
 
        typer.secho("Starting run... ", fg=typer.colors.BLUE)
flwr/client/grpc_rere_client/connection.py CHANGED
@@ -21,7 +21,7 @@ from contextlib import contextmanager
 from copy import copy
 from logging import DEBUG, ERROR
 from pathlib import Path
-from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
+from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast
 
 import grpc
 from cryptography.hazmat.primitives.asymmetric import ec
@@ -73,6 +73,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
     authentication_keys: Optional[
         Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
     ] = None,
+    adapter_cls: Optional[Type[FleetStub]] = None,
 ) -> Iterator[
     Tuple[
         Callable[[], Optional[Message]],
@@ -133,7 +134,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
     channel.subscribe(on_channel_state_change)
 
     # Shared variables for inner functions
-    stub = FleetStub(channel)
+    if adapter_cls is None:
+        adapter_cls = FleetStub
+    stub = adapter_cls(channel)
     metadata: Optional[Metadata] = None
     node: Optional[Node] = None
     ping_thread: Optional[threading.Thread] = None
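The new `adapter_cls` parameter makes the transport stub injectable. A hedged sketch of what a caller could pass; the `LoggingFleetStub` wrapper below is hypothetical and not part of Flower:

from logging import INFO

from flwr.common.logger import log
from flwr.proto.fleet_pb2_grpc import FleetStub


class LoggingFleetStub:
    """Wrap FleetStub and log every RPC attribute that is looked up."""

    def __init__(self, channel):
        self._stub = FleetStub(channel)

    def __getattr__(self, name):
        log(INFO, "Fleet RPC requested: %s", name)
        return getattr(self._stub, name)


# Passed via the new keyword argument shown in the diff:
#     grpc_request_response(..., adapter_cls=LoggingFleetStub)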
flwr/client/mod/comms_mods.py CHANGED
@@ -29,7 +29,7 @@ def message_size_mod(
 ) -> Message:
     """Message size mod.
 
-    This mod logs the size in Bytes of the message being transmited.
+    This mod logs the size in bytes of the message being transmited.
     """
     message_size_in_bytes = 0
 
@@ -42,7 +42,7 @@ def message_size_mod(
     for m_record in msg.content.metrics_records.values():
         message_size_in_bytes += m_record.count_bytes()
 
-    log(INFO, "Message size: %i Bytes", message_size_in_bytes)
+    log(INFO, "Message size: %i bytes", message_size_in_bytes)
 
     return call_next(msg, ctxt)
 
@@ -53,7 +53,7 @@ def parameters_size_mod(
     """Parameters size mod.
 
     This mod logs the number of parameters transmitted in the message as well as their
-    size in Bytes.
+    size in bytes.
     """
     model_size_stats = {}
     parameters_size_in_bytes = 0
@@ -74,6 +74,6 @@
     if model_size_stats:
         log(INFO, model_size_stats)
 
-    log(INFO, "Total parameters transmited: %i Bytes", parameters_size_in_bytes)
+    log(INFO, "Total parameters transmitted: %i bytes", parameters_size_in_bytes)
 
     return call_next(msg, ctxt)
flwr/client/mod/localdp_mod.py CHANGED
@@ -145,8 +145,7 @@ class LocalDpMod:
         )
         log(
             INFO,
-            "LocalDpMod: local DP noise with "
-            "standard deviation: %.4f added to parameters.",
+            "LocalDpMod: local DP noise with %.4f stedv added to parameters",
             noise_value_sd,
         )
 
flwr/client/supernode/app.py CHANGED
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO, WARN
 from pathlib import Path
 from typing import Callable, Optional, Tuple
 
+from cryptography.exceptions import UnsupportedAlgorithm
 from cryptography.hazmat.primitives.asymmetric import ec
 from cryptography.hazmat.primitives.serialization import (
     load_ssh_private_key,
@@ -31,9 +32,6 @@ from flwr.common import EventType, event
 from flwr.common.exit_handlers import register_exit_handlers
 from flwr.common.logger import log
 from flwr.common.object_ref import load_app, validate
-from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
-    ssh_types_to_elliptic_curve,
-)
 
 from ..app import _start_client_internal
 
@@ -242,40 +240,60 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
         " Default: current working directory.",
     )
     parser.add_argument(
-        "--authentication-keys",
-        nargs=2,
-        metavar=("CLIENT_PRIVATE_KEY", "CLIENT_PUBLIC_KEY"),
+        "--auth-supernode-private-key",
+        type=str,
+        help="The SuperNode's private key (as a path str) to enable authentication.",
+    )
+    parser.add_argument(
+        "--auth-supernode-public-key",
         type=str,
-        help="Provide two file paths: (1) the client's private "
-        "key file, and (2) the client's public key file.",
+        help="The SuperNode's public key (as a path str) to enable authentication.",
     )
 
 
 def _try_setup_client_authentication(
     args: argparse.Namespace,
 ) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
-    if not args.authentication_keys:
+    if not args.auth_supernode_private_key and not args.auth_supernode_public_key:
         return None
 
-    ssh_private_key = load_ssh_private_key(
-        Path(args.authentication_keys[0]).read_bytes(),
-        None,
-    )
-    ssh_public_key = load_ssh_public_key(Path(args.authentication_keys[1]).read_bytes())
+    if not args.auth_supernode_private_key or not args.auth_supernode_public_key:
+        sys.exit(
+            "Authentication requires file paths to both "
+            "'--auth-supernode-private-key' and '--auth-supernode-public-key'"
+            "to be provided (providing only one of them is not sufficient)."
+        )
+
+    try:
+        ssh_private_key = load_ssh_private_key(
+            Path(args.auth_supernode_private_key).read_bytes(),
+            None,
+        )
+        if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
+            raise ValueError()
+    except (ValueError, UnsupportedAlgorithm):
+        sys.exit(
+            "Error: Unable to parse the private key file in "
+            "'--auth-supernode-private-key'. Authentication requires elliptic "
+            "curve private and public key pair. Please ensure that the file "
+            "path points to a valid private key file and try again."
+        )
 
     try:
-        client_private_key, client_public_key = ssh_types_to_elliptic_curve(
-            ssh_private_key, ssh_public_key
+        ssh_public_key = load_ssh_public_key(
+            Path(args.auth_supernode_public_key).read_bytes()
         )
-    except TypeError:
+        if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
+            raise ValueError()
+    except (ValueError, UnsupportedAlgorithm):
         sys.exit(
-            "The file paths provided could not be read as a private and public "
-            "key pair. Client authentication requires an elliptic curve public and "
-            "private key pair. Please provide the file paths containing elliptic "
-            "curve private and public keys to '--authentication-keys'."
+            "Error: Unable to parse the public key file in "
+            "'--auth-supernode-public-key'. Authentication requires elliptic "
+            "curve private and public key pair. Please ensure that the file "
+            "path points to a valid public key file and try again."
        )
 
     return (
-        client_private_key,
-        client_public_key,
+        ssh_private_key,
+        ssh_public_key,
     )
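The replacement flags expect OpenSSH-formatted elliptic-curve key files, which is what the new isinstance checks enforce. A hedged sketch of generating a compatible key pair with the same `cryptography` library the code above uses; the file names are arbitrary:

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

# Generate an EC private key and write it in OpenSSH format
private_key = ec.generate_private_key(ec.SECP384R1())
with open("supernode_key", "wb") as f:
    f.write(
        private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.OpenSSH,
            encryption_algorithm=serialization.NoEncryption(),
        )
    )

# Write the matching public key, also in OpenSSH format
with open("supernode_key.pub", "wb") as f:
    f.write(
        private_key.public_key().public_bytes(
            encoding=serialization.Encoding.OpenSSH,
            format=serialization.PublicFormat.OpenSSH,
        )
    )

# These two files can then be passed to --auth-supernode-private-key and
# --auth-supernode-public-key respectively.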
flwr/common/recordset_compat.py CHANGED
@@ -35,6 +35,8 @@ from .typing import (
     Status,
 )
 
+EMPTY_TENSOR_KEY = "_empty"
+
 
 def parametersrecord_to_parameters(
     record: ParametersRecord, keep_input: bool
@@ -59,7 +61,8 @@
     parameters = Parameters(tensors=[], tensor_type="")
 
     for key in list(record.keys()):
-        parameters.tensors.append(record[key].data)
+        if key != EMPTY_TENSOR_KEY:
+            parameters.tensors.append(record[key].data)
 
     if not parameters.tensor_type:
         # Setting from first array in record. Recall the warning in the docstrings
@@ -103,6 +106,10 @@
             data=tensor, dtype="", stype=tensor_type, shape=[]
         )
 
+    if num_arrays == 0:
+        ordered_dict[EMPTY_TENSOR_KEY] = Array(
+            data=b"", dtype="", stype=tensor_type, shape=[]
+        )
     return ParametersRecord(ordered_dict, keep_input=keep_input)
 
 
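A short sketch of the empty-parameters round trip that the `_empty` placeholder enables; function names are taken from the diff above, and the exact import path is assumed:

from flwr.common import Parameters
from flwr.common.recordset_compat import (
    parameters_to_parametersrecord,
    parametersrecord_to_parameters,
)

empty = Parameters(tensors=[], tensor_type="numpy.ndarray")

# The record now carries a single "_empty" placeholder Array ...
record = parameters_to_parametersrecord(empty, keep_input=True)

# ... which is skipped on the way back, so no spurious tensor appears.
restored = parametersrecord_to_parameters(record, keep_input=True)
assert restored.tensors == []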
flwr/common/secure_aggregation/crypto/symmetric_encryption.py CHANGED
@@ -117,18 +117,3 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
         return True
     except InvalidSignature:
         return False
-
-
-def ssh_types_to_elliptic_curve(
-    private_key: serialization.SSHPrivateKeyTypes,
-    public_key: serialization.SSHPublicKeyTypes,
-) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
-    """Cast SSH key types to elliptic curve."""
-    if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
-        public_key, ec.EllipticCurvePublicKey
-    ):
-        return (private_key, public_key)
-
-    raise TypeError(
-        "The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
-    )
flwr/proto/grpcadapter_pb2.py ADDED
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: flwr/proto/grpcadapter.proto
+# Protobuf Python Version: 4.25.0
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/grpcadapter.proto\x12\nflwr.proto\"\xba\x01\n\x10MessageContainer\x12<\n\x08metadata\x18\x01 \x03(\x0b\x32*.flwr.proto.MessageContainer.MetadataEntry\x12\x19\n\x11grpc_message_name\x18\x02 \x01(\t\x12\x1c\n\x14grpc_message_content\x18\x03 \x01(\x0c\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x32Z\n\x0bGrpcAdapter\x12K\n\x0bSendReceive\x12\x1c.flwr.proto.MessageContainer\x1a\x1c.flwr.proto.MessageContainer\"\x00\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.grpcadapter_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+  DESCRIPTOR._options = None
+  _globals['_MESSAGECONTAINER_METADATAENTRY']._options = None
+  _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_options = b'8\001'
+  _globals['_MESSAGECONTAINER']._serialized_start=45
+  _globals['_MESSAGECONTAINER']._serialized_end=231
+  _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_start=184
+  _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_end=231
+  _globals['_GRPCADAPTER']._serialized_start=233
+  _globals['_GRPCADAPTER']._serialized_end=323
+# @@protoc_insertion_point(module_scope)
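For orientation, the generated module defines a single envelope message. A hedged sketch of constructing it; the metadata key and wrapped message name below are arbitrary examples, not values required by Flower:

from flwr.proto.grpcadapter_pb2 import MessageContainer

container = MessageContainer(
    metadata={"flower-version": "1.9.0"},        # map<string, string>
    grpc_message_name="SomeWrappedRequest",      # name of the wrapped message type
    grpc_message_content=b"<serialized bytes>",  # its serialized payload
)

# The matching GrpcAdapter service (see grpcadapter_pb2_grpc.py in the file
# list) exposes a single SendReceive RPC that takes and returns a MessageContainer.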
flwr/proto/grpcadapter_pb2.pyi ADDED
@@ -0,0 +1,43 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+import builtins
+import google.protobuf.descriptor
+import google.protobuf.internal.containers
+import google.protobuf.message
+import typing
+import typing_extensions
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
+
+class MessageContainer(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+    class MetadataEntry(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+        KEY_FIELD_NUMBER: builtins.int
+        VALUE_FIELD_NUMBER: builtins.int
+        key: typing.Text
+        value: typing.Text
+        def __init__(self,
+            *,
+            key: typing.Text = ...,
+            value: typing.Text = ...,
+            ) -> None: ...
+        def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
+
+    METADATA_FIELD_NUMBER: builtins.int
+    GRPC_MESSAGE_NAME_FIELD_NUMBER: builtins.int
+    GRPC_MESSAGE_CONTENT_FIELD_NUMBER: builtins.int
+    @property
+    def metadata(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ...
+    grpc_message_name: typing.Text
+    grpc_message_content: builtins.bytes
+    def __init__(self,
+        *,
+        metadata: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ...,
+        grpc_message_name: typing.Text = ...,
+        grpc_message_content: builtins.bytes = ...,
+        ) -> None: ...
+    def ClearField(self, field_name: typing_extensions.Literal["grpc_message_content",b"grpc_message_content","grpc_message_name",b"grpc_message_name","metadata",b"metadata"]) -> None: ...
+global___MessageContainer = MessageContainer