flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240509__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the registry's release page for more details.

Files changed (71)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +18 -46
  4. flwr/cli/new/new.py +44 -18
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +55 -0
  6. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  7. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  8. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  9. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
  10. flwr/cli/new/templates/app/code/server.hf.py.tpl +17 -0
  11. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  12. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  14. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
  15. flwr/cli/new/templates/app/code/task.hf.py.tpl +87 -0
  16. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  17. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  18. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +31 -0
  19. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  20. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  21. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  22. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  23. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  24. flwr/cli/run/run.py +1 -1
  25. flwr/cli/utils.py +18 -17
  26. flwr/client/__init__.py +1 -1
  27. flwr/client/app.py +17 -93
  28. flwr/client/grpc_client/connection.py +6 -1
  29. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  30. flwr/client/grpc_rere_client/connection.py +17 -2
  31. flwr/client/mod/centraldp_mods.py +4 -2
  32. flwr/client/mod/localdp_mod.py +9 -3
  33. flwr/client/rest_client/connection.py +5 -1
  34. flwr/client/supernode/__init__.py +2 -0
  35. flwr/client/supernode/app.py +181 -7
  36. flwr/common/grpc.py +5 -1
  37. flwr/common/logger.py +37 -4
  38. flwr/common/message.py +105 -86
  39. flwr/common/record/parametersrecord.py +0 -1
  40. flwr/common/record/recordset.py +17 -5
  41. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  42. flwr/server/__init__.py +0 -2
  43. flwr/server/app.py +118 -2
  44. flwr/server/compat/app.py +5 -56
  45. flwr/server/compat/app_utils.py +1 -1
  46. flwr/server/compat/driver_client_proxy.py +27 -72
  47. flwr/server/driver/__init__.py +3 -0
  48. flwr/server/driver/driver.py +12 -242
  49. flwr/server/driver/grpc_driver.py +315 -0
  50. flwr/server/history.py +20 -20
  51. flwr/server/run_serverapp.py +18 -4
  52. flwr/server/server.py +2 -5
  53. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  54. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  55. flwr/server/superlink/driver/driver_servicer.py +1 -1
  56. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  57. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  58. flwr/server/superlink/fleet/vce/backend/raybackend.py +9 -6
  59. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  60. flwr/server/superlink/state/in_memory_state.py +76 -8
  61. flwr/server/superlink/state/sqlite_state.py +116 -11
  62. flwr/server/superlink/state/state.py +35 -3
  63. flwr/simulation/__init__.py +2 -2
  64. flwr/simulation/app.py +16 -1
  65. flwr/simulation/run_simulation.py +14 -9
  66. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/METADATA +3 -2
  67. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/RECORD +70 -55
  68. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/entry_points.txt +1 -1
  69. flwr/server/driver/abc_driver.py +0 -140
  70. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/LICENSE +0 -0
  71. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/WHEEL +0 -0
flwr/common/record/recordset.py CHANGED
@@ -24,6 +24,7 @@ from .parametersrecord import ParametersRecord
24
24
  from .typeddict import TypedDict
25
25
 
26
26
 
27
+ @dataclass
27
28
  class RecordSetData:
28
29
  """Inner data container for the RecordSet class."""
29
30
 
@@ -82,7 +83,6 @@ class RecordSetData:
82
83
  )
83
84
 
84
85
 
85
- @dataclass
86
86
  class RecordSet:
87
87
  """RecordSet stores groups of parameters, metrics and configs."""
88
88
 
@@ -97,22 +97,34 @@ class RecordSet:
97
97
  metrics_records=metrics_records,
98
98
  configs_records=configs_records,
99
99
  )
100
- setattr(self, "_data", data) # noqa
100
+ self.__dict__["_data"] = data
101
101
 
102
102
  @property
103
103
  def parameters_records(self) -> TypedDict[str, ParametersRecord]:
104
104
  """Dictionary holding ParametersRecord instances."""
105
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
105
+ data = cast(RecordSetData, self.__dict__["_data"])
106
106
  return data.parameters_records
107
107
 
108
108
  @property
109
109
  def metrics_records(self) -> TypedDict[str, MetricsRecord]:
110
110
  """Dictionary holding MetricsRecord instances."""
111
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
111
+ data = cast(RecordSetData, self.__dict__["_data"])
112
112
  return data.metrics_records
113
113
 
114
114
  @property
115
115
  def configs_records(self) -> TypedDict[str, ConfigsRecord]:
116
116
  """Dictionary holding ConfigsRecord instances."""
117
- data = cast(RecordSetData, getattr(self, "_data")) # noqa
117
+ data = cast(RecordSetData, self.__dict__["_data"])
118
118
  return data.configs_records
119
+
120
+ def __repr__(self) -> str:
121
+ """Return a string representation of this instance."""
122
+ flds = ("parameters_records", "metrics_records", "configs_records")
123
+ view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
124
+ return f"{self.__class__.__qualname__}({view})"
125
+
126
+ def __eq__(self, other: object) -> bool:
127
+ """Compare two instances of the class."""
128
+ if not isinstance(other, self.__class__):
129
+ raise NotImplementedError
130
+ return self.__dict__ == other.__dict__
flwr/common/secure_aggregation/crypto/symmetric_encryption.py CHANGED
@@ -18,8 +18,9 @@
18
18
  import base64
19
19
  from typing import Tuple, cast
20
20
 
21
+ from cryptography.exceptions import InvalidSignature
21
22
  from cryptography.fernet import Fernet
22
- from cryptography.hazmat.primitives import hashes, serialization
23
+ from cryptography.hazmat.primitives import hashes, hmac, serialization
23
24
  from cryptography.hazmat.primitives.asymmetric import ec
24
25
  from cryptography.hazmat.primitives.kdf.hkdf import HKDF
25
26
 
@@ -98,3 +99,36 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
98
99
  # The input key must be url safe
99
100
  fernet = Fernet(key)
100
101
  return fernet.decrypt(ciphertext)
102
+
103
+
104
def compute_hmac(key: bytes, message: bytes) -> bytes:
    """Return the HMAC-SHA256 tag of ``message`` computed with ``key``."""
    mac = hmac.HMAC(key, hashes.SHA256())
    mac.update(message)
    tag = mac.finalize()
    return tag
109
+
110
+
111
def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
    """Check ``hmac_value`` against the HMAC-SHA256 of ``message`` under ``key``.

    Returns
    -------
    bool
        ``True`` if the tag matches, ``False`` if ``HMAC.verify`` reports an
        ``InvalidSignature``.
    """
    verifier = hmac.HMAC(key, hashes.SHA256())
    verifier.update(message)
    try:
        verifier.verify(hmac_value)
    except InvalidSignature:
        return False
    return True
120
+
121
+
122
def ssh_types_to_elliptic_curve(
    private_key: serialization.SSHPrivateKeyTypes,
    public_key: serialization.SSHPublicKeyTypes,
) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
    """Narrow a loaded SSH key pair to elliptic-curve key types.

    Raises
    ------
    TypeError
        If the private key is not an ``EllipticCurvePrivateKey`` or the
        public key is not an ``EllipticCurvePublicKey``.
    """
    if not isinstance(private_key, ec.EllipticCurvePrivateKey) or not isinstance(
        public_key, ec.EllipticCurvePublicKey
    ):
        raise TypeError(
            "The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
        )
    return private_key, public_key
flwr/server/__init__.py CHANGED
@@ -24,7 +24,6 @@ from .app import start_server as start_server
24
24
  from .client_manager import ClientManager as ClientManager
25
25
  from .client_manager import SimpleClientManager as SimpleClientManager
26
26
  from .compat import LegacyContext as LegacyContext
27
- from .compat import start_driver as start_driver
28
27
  from .driver import Driver as Driver
29
28
  from .history import History as History
30
29
  from .run_serverapp import run_server_app as run_server_app
@@ -45,7 +44,6 @@ __all__ = [
45
44
  "ServerApp",
46
45
  "ServerConfig",
47
46
  "SimpleClientManager",
48
- "start_driver",
49
47
  "start_server",
50
48
  "strategy",
51
49
  "workflow",
flwr/server/app.py CHANGED
@@ -16,15 +16,21 @@
16
16
 
17
17
  import argparse
18
18
  import asyncio
19
+ import csv
19
20
  import importlib.util
20
21
  import sys
21
22
  import threading
22
23
  from logging import ERROR, INFO, WARN
23
24
  from os.path import isfile
24
25
  from pathlib import Path
25
- from typing import List, Optional, Tuple
26
+ from typing import List, Optional, Sequence, Set, Tuple
26
27
 
27
28
  import grpc
29
+ from cryptography.hazmat.primitives.asymmetric import ec
30
+ from cryptography.hazmat.primitives.serialization import (
31
+ load_ssh_private_key,
32
+ load_ssh_public_key,
33
+ )
28
34
 
29
35
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
30
36
  from flwr.common.address import parse_address
@@ -35,7 +41,12 @@ from flwr.common.constant import (
35
41
  TRANSPORT_TYPE_VCE,
36
42
  )
37
43
  from flwr.common.exit_handlers import register_exit_handlers
38
- from flwr.common.logger import log
44
+ from flwr.common.logger import log, warn_deprecated_feature
45
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
46
+ private_key_to_bytes,
47
+ public_key_to_bytes,
48
+ ssh_types_to_elliptic_curve,
49
+ )
39
50
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
40
51
  add_FleetServicer_to_server,
41
52
  )
@@ -51,6 +62,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
51
62
  start_grpc_server,
52
63
  )
53
64
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
65
+ from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
54
66
  from .superlink.fleet.vce import start_vce
55
67
  from .superlink.state import StateFactory
56
68
 
@@ -184,6 +196,9 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
184
196
  def run_driver_api() -> None:
185
197
  """Run Flower server (Driver API)."""
186
198
  log(INFO, "Starting Flower server (Driver API)")
199
+ # Running `flower-driver-api` is deprecated
200
+ warn_deprecated_feature("flower-driver-api")
201
+ log(WARN, "Use `flower-superlink` instead")
187
202
  event(EventType.RUN_DRIVER_API_ENTER)
188
203
  args = _parse_args_run_driver_api().parse_args()
189
204
 
@@ -221,6 +236,9 @@ def run_driver_api() -> None:
221
236
  def run_fleet_api() -> None:
222
237
  """Run Flower server (Fleet API)."""
223
238
  log(INFO, "Starting Flower server (Fleet API)")
239
+ # Running `flower-fleet-api` is deprecated
240
+ warn_deprecated_feature("flower-fleet-api")
241
+ log(WARN, "Use `flower-superlink` instead")
224
242
  event(EventType.RUN_FLEET_API_ENTER)
225
243
  args = _parse_args_run_fleet_api().parse_args()
226
244
 
@@ -354,10 +372,33 @@ def run_superlink() -> None:
354
372
  sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
355
373
  host, port, is_v6 = parsed_address
356
374
  address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
375
+
376
+ maybe_keys = _try_setup_client_authentication(args, certificates)
377
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
378
+ if maybe_keys is not None:
379
+ (
380
+ client_public_keys,
381
+ server_private_key,
382
+ server_public_key,
383
+ ) = maybe_keys
384
+ state = state_factory.state()
385
+ state.store_client_public_keys(client_public_keys)
386
+ state.store_server_private_public_key(
387
+ private_key_to_bytes(server_private_key),
388
+ public_key_to_bytes(server_public_key),
389
+ )
390
+ log(
391
+ INFO,
392
+ "Client authentication enabled with %d known public keys",
393
+ len(client_public_keys),
394
+ )
395
+ interceptors = [AuthenticateServerInterceptor(state)]
396
+
357
397
  fleet_server = _run_fleet_api_grpc_rere(
358
398
  address=address,
359
399
  state_factory=state_factory,
360
400
  certificates=certificates,
401
+ interceptors=interceptors,
361
402
  )
362
403
  grpc_servers.append(fleet_server)
363
404
  elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
@@ -390,6 +431,70 @@ def run_superlink() -> None:
390
431
  driver_server.wait_for_termination(timeout=1)
391
432
 
392
433
 
434
def _try_setup_client_authentication(
    args: argparse.Namespace,
    certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
    """Load the key material for client authentication, if it was requested.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments. ``args.require_client_authentication`` is falsy
        when authentication is disabled. Otherwise it holds three paths:
        (0) a CSV file of known client SSH public keys, (1) the server's SSH
        private key file and (2) the server's SSH public key file.
    certificates : Optional[Tuple[bytes, bytes, bytes]]
        TLS certificates. Client authentication is refused without them.

    Returns
    -------
    Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]
        ``None`` if client authentication was not requested. Otherwise a tuple
        of the serialized client public keys (via ``public_key_to_bytes``), the
        server's private key and the server's public key.

    Notes
    -----
    Every configuration problem terminates the process through ``sys.exit``
    with a user-facing message rather than an unhandled traceback.
    """
    if not args.require_client_authentication:
        return None

    if certificates is None:
        sys.exit(
            "Client authentication only works over secure connections. "
            "Please provide certificate paths using '--certificates' when "
            "enabling '--require-client-authentication'."
        )

    client_keys_file_path = Path(args.require_client_authentication[0])
    if not client_keys_file_path.exists():
        sys.exit(
            "The provided path to the client public keys CSV file does not exist: "
            f"{client_keys_file_path}. "
            "Please provide the CSV file path containing known client public keys "
            "to '--require-client-authentication'."
        )

    server_private_key, server_public_key = _load_server_key_pair(
        Path(args.require_client_authentication[1]),
        Path(args.require_client_authentication[2]),
    )
    client_public_keys = _load_client_public_keys(client_keys_file_path)

    return (
        client_public_keys,
        server_private_key,
        server_public_key,
    )


def _load_server_key_pair(
    private_key_path: Path, public_key_path: Path
) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
    """Read the server's SSH key files and return them as an EC key pair, or exit."""
    # Local import: only needed to recognise unsupported key formats here.
    # pylint: disable=import-outside-toplevel
    from cryptography.exceptions import UnsupportedAlgorithm

    try:
        private_key_bytes = private_key_path.read_bytes()
        public_key_bytes = public_key_path.read_bytes()
    except OSError as err:
        sys.exit(
            f"Unable to read the server key file '{err.filename}': {err.strerror}. "
            "Please provide readable private and public key file paths to "
            "'--require-client-authentication'."
        )

    try:
        # No password is passed, so the private key must be unencrypted; the
        # cryptography docs state a password-protected key raises TypeError.
        ssh_private_key = load_ssh_private_key(private_key_bytes, None)
        ssh_public_key = load_ssh_public_key(public_key_bytes)
        return ssh_types_to_elliptic_curve(ssh_private_key, ssh_public_key)
    except (ValueError, TypeError, UnsupportedAlgorithm):
        sys.exit(
            "The file paths provided could not be read as a private and public "
            "key pair. Client authentication requires an elliptic curve public and "
            "private key pair. Please provide the file paths containing elliptic "
            "curve private and public keys to '--require-client-authentication'."
        )


def _load_client_public_keys(client_keys_file_path: Path) -> Set[bytes]:
    """Parse the CSV of known client SSH public keys into serialized bytes, or exit.

    Every non-blank cell must contain one elliptic-curve SSH public key.
    """
    # pylint: disable=import-outside-toplevel
    from cryptography.exceptions import UnsupportedAlgorithm

    parse_error = (
        "Error: Unable to parse the public keys in the .csv "
        "file. Please ensure that the .csv file contains valid "
        "SSH public keys and try again."
    )
    client_public_keys: Set[bytes] = set()
    with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            for element in row:
                # Trim padding (e.g. a space after the comma separator) and skip
                # empty cells (e.g. from a trailing comma) instead of handing
                # them to the SSH key parser.
                key_text = element.strip()
                if not key_text:
                    continue
                try:
                    public_key = load_ssh_public_key(key_text.encode())
                except (ValueError, UnsupportedAlgorithm):
                    sys.exit(parse_error)
                if not isinstance(public_key, ec.EllipticCurvePublicKey):
                    sys.exit(parse_error)
                client_public_keys.add(public_key_to_bytes(public_key))
    return client_public_keys
496
+
497
+
393
498
  def _try_obtain_certificates(
394
499
  args: argparse.Namespace,
395
500
  ) -> Optional[Tuple[bytes, bytes, bytes]]:
@@ -417,6 +522,7 @@ def _run_fleet_api_grpc_rere(
417
522
  address: str,
418
523
  state_factory: StateFactory,
419
524
  certificates: Optional[Tuple[bytes, bytes, bytes]],
525
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
420
526
  ) -> grpc.Server:
421
527
  """Run Fleet API (gRPC, request-response)."""
422
528
  # Create Fleet API gRPC server
@@ -429,6 +535,7 @@ def _run_fleet_api_grpc_rere(
429
535
  server_address=address,
430
536
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
431
537
  certificates=certificates,
538
+ interceptors=interceptors,
432
539
  )
433
540
 
434
541
  log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
@@ -606,6 +713,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
606
713
  "Flower will just create a state in memory.",
607
714
  default=DATABASE,
608
715
  )
716
+ parser.add_argument(
717
+ "--require-client-authentication",
718
+ nargs=3,
719
+ metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
720
+ type=str,
721
+ help="Provide three file paths: (1) a .csv file containing a list of "
722
+ "known client public keys for authentication, (2) the server's private "
723
+ "key file, and (3) the server's public key file.",
724
+ )
609
725
 
610
726
 
611
727
  def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
flwr/server/compat/app.py CHANGED
@@ -15,14 +15,11 @@
15
15
  """Flower driver app."""
16
16
 
17
17
 
18
- import sys
19
18
  from logging import INFO
20
- from pathlib import Path
21
- from typing import Optional, Union
19
+ from typing import Optional
22
20
 
23
21
  from flwr.common import EventType, event
24
- from flwr.common.address import parse_address
25
- from flwr.common.logger import log, warn_deprecated_feature
22
+ from flwr.common.logger import log
26
23
  from flwr.server.client_manager import ClientManager
27
24
  from flwr.server.history import History
28
25
  from flwr.server.server import Server, init_defaults, run_fl
@@ -32,33 +29,21 @@ from flwr.server.strategy import Strategy
32
29
  from ..driver import Driver
33
30
  from .app_utils import start_update_client_manager_thread
34
31
 
35
- DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
36
-
37
- ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
38
- [Driver] Error: Not connected.
39
-
40
- Call `connect()` on the `Driver` instance before calling any of the other `Driver`
41
- methods.
42
- """
43
-
44
32
 
45
33
  def start_driver( # pylint: disable=too-many-arguments, too-many-locals
46
34
  *,
47
- server_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
35
+ driver: Driver,
48
36
  server: Optional[Server] = None,
49
37
  config: Optional[ServerConfig] = None,
50
38
  strategy: Optional[Strategy] = None,
51
39
  client_manager: Optional[ClientManager] = None,
52
- root_certificates: Optional[Union[bytes, str]] = None,
53
- driver: Optional[Driver] = None,
54
40
  ) -> History:
55
41
  """Start a Flower Driver API server.
56
42
 
57
43
  Parameters
58
44
  ----------
59
- server_address : Optional[str]
60
- The IPv4 or IPv6 address of the Driver API server.
61
- Defaults to `"[::]:8080"`.
45
+ driver : Driver
46
+ The Driver object to use.
62
47
  server : Optional[flwr.server.Server] (default: None)
63
48
  A server implementation, either `flwr.server.Server` or a subclass
64
49
  thereof. If no instance is provided, then `start_driver` will create
@@ -74,50 +59,14 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
74
59
  An implementation of the class `flwr.server.ClientManager`. If no
75
60
  implementation is provided, then `start_driver` will use
76
61
  `flwr.server.SimpleClientManager`.
77
- root_certificates : Optional[Union[bytes, str]] (default: None)
78
- The PEM-encoded root certificates as a byte string or a path string.
79
- If provided, a secure connection using the certificates will be
80
- established to an SSL-enabled Flower server.
81
- driver : Optional[Driver] (default: None)
82
- The Driver object to use.
83
62
 
84
63
  Returns
85
64
  -------
86
65
  hist : flwr.server.history.History
87
66
  Object containing training and evaluation metrics.
88
-
89
- Examples
90
- --------
91
- Starting a driver that connects to an insecure server:
92
-
93
- >>> start_driver()
94
-
95
- Starting a driver that connects to an SSL-enabled server:
96
-
97
- >>> start_driver(
98
- >>> root_certificates=Path("/crts/root.pem").read_bytes()
99
- >>> )
100
67
  """
101
68
  event(EventType.START_DRIVER_ENTER)
102
69
 
103
- if driver is None:
104
- # Not passing a `Driver` object is deprecated
105
- warn_deprecated_feature("start_driver")
106
-
107
- # Parse IP address
108
- parsed_address = parse_address(server_address)
109
- if not parsed_address:
110
- sys.exit(f"Server IP address ({server_address}) cannot be parsed.")
111
- host, port, is_v6 = parsed_address
112
- address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
113
-
114
- # Create the Driver
115
- if isinstance(root_certificates, str):
116
- root_certificates = Path(root_certificates).read_bytes()
117
- driver = Driver(
118
- driver_service_address=address, root_certificates=root_certificates
119
- )
120
-
121
70
  # Initialize the Driver API server and config
122
71
  initialized_server, initialized_config = init_defaults(
123
72
  server=server,
flwr/server/compat/app_utils.py CHANGED
@@ -89,7 +89,7 @@ def _update_client_manager(
89
89
  for node_id in new_nodes:
90
90
  client_proxy = DriverClientProxy(
91
91
  node_id=node_id,
92
- driver=driver.grpc_driver_helper, # type: ignore
92
+ driver=driver,
93
93
  anonymous=False,
94
94
  run_id=driver.run_id, # type: ignore
95
95
  )
flwr/server/compat/driver_client_proxy.py CHANGED
@@ -16,16 +16,14 @@
16
16
 
17
17
 
18
18
  import time
19
- from typing import List, Optional
19
+ from typing import Optional
20
20
 
21
21
  from flwr import common
22
- from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet
22
+ from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
23
23
  from flwr.common import recordset_compat as compat
24
- from flwr.common import serde
25
- from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
26
24
  from flwr.server.client_proxy import ClientProxy
27
25
 
28
- from ..driver.driver import GrpcDriverHelper
26
+ from ..driver.driver import Driver
29
27
 
30
28
  SLEEP_TIME = 1
31
29
 
@@ -33,9 +31,7 @@ SLEEP_TIME = 1
33
31
  class DriverClientProxy(ClientProxy):
34
32
  """Flower client proxy which delegates work using the Driver API."""
35
33
 
36
- def __init__(
37
- self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
38
- ):
34
+ def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
39
35
  super().__init__(str(node_id))
40
36
  self.node_id = node_id
41
37
  self.driver = driver
@@ -116,80 +112,39 @@ class DriverClientProxy(ClientProxy):
116
112
  timeout: Optional[float],
117
113
  group_id: Optional[int],
118
114
  ) -> RecordSet:
119
- task_ins = task_pb2.TaskIns( # pylint: disable=E1101
120
- task_id="",
121
- group_id=str(group_id) if group_id is not None else "",
122
- run_id=self.run_id,
123
- task=task_pb2.Task( # pylint: disable=E1101
124
- producer=node_pb2.Node( # pylint: disable=E1101
125
- node_id=0,
126
- anonymous=True,
127
- ),
128
- consumer=node_pb2.Node( # pylint: disable=E1101
129
- node_id=self.node_id,
130
- anonymous=self.anonymous,
131
- ),
132
- task_type=task_type,
133
- recordset=serde.recordset_to_proto(recordset),
134
- ttl=DEFAULT_TTL,
135
- ),
136
- )
137
-
138
- # This would normally be recorded upon common.Message creation
139
- # but this compatibility stack doesn't create Messages,
140
- # so we need to inject `created_at` manually (needed for
141
- # taskins validation by server.utils.validator)
142
- task_ins.task.created_at = time.time()
143
115
 
144
- push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
145
- task_ins_list=[task_ins]
116
+ # Create message
117
+ message = self.driver.create_message(
118
+ content=recordset,
119
+ message_type=task_type,
120
+ dst_node_id=self.node_id,
121
+ group_id=str(group_id) if group_id else "",
122
+ ttl=timeout,
146
123
  )
147
124
 
148
- # Send TaskIns to Driver API
149
- push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req)
150
-
151
- if len(push_task_ins_res.task_ids) != 1:
152
- raise ValueError("Unexpected number of task_ids")
125
+ # Push message
126
+ message_ids = list(self.driver.push_messages(messages=[message]))
127
+ if len(message_ids) != 1:
128
+ raise ValueError("Unexpected number of message_ids")
153
129
 
154
- task_id = push_task_ins_res.task_ids[0]
155
- if task_id == "":
156
- raise ValueError(f"Failed to schedule task for node {self.node_id}")
130
+ message_id = message_ids[0]
131
+ if message_id == "":
132
+ raise ValueError(f"Failed to send message to node {self.node_id}")
157
133
 
158
134
  if timeout:
159
135
  start_time = time.time()
160
136
 
161
137
  while True:
162
- pull_task_res_req = driver_pb2.PullTaskResRequest( # pylint: disable=E1101
163
- node=node_pb2.Node(node_id=0, anonymous=True), # pylint: disable=E1101
164
- task_ids=[task_id],
165
- )
166
-
167
- # Ask Driver API for TaskRes
168
- pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req)
169
-
170
- task_res_list: List[task_pb2.TaskRes] = list( # pylint: disable=E1101
171
- pull_task_res_res.task_res_list
172
- )
173
- if len(task_res_list) == 1:
174
- task_res = task_res_list[0]
175
-
176
- # This will raise an Exception if task_res carries an `error`
177
- validate_task_res(task_res=task_res)
178
-
179
- return serde.recordset_from_proto(task_res.task.recordset)
138
+ messages = list(self.driver.pull_messages(message_ids))
139
+ if len(messages) == 1:
140
+ msg: Message = messages[0]
141
+ if msg.has_error():
142
+ raise ValueError(
143
+ f"Message contains an Error (reason: {msg.error.reason}). "
144
+ "It originated during client-side execution of a message."
145
+ )
146
+ return msg.content
180
147
 
181
148
  if timeout is not None and time.time() > start_time + timeout:
182
149
  raise RuntimeError("Timeout reached")
183
150
  time.sleep(SLEEP_TIME)
184
-
185
-
186
- def validate_task_res(
187
- task_res: task_pb2.TaskRes, # pylint: disable=E1101
188
- ) -> None:
189
- """Validate if a TaskRes is empty or not."""
190
- if not task_res.HasField("task"):
191
- raise ValueError("Invalid TaskRes, field `task` missing")
192
- if task_res.task.HasField("error"):
193
- raise ValueError("Exception during client-side task execution")
194
- if not task_res.task.HasField("recordset"):
195
- raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
flwr/server/driver/__init__.py CHANGED
@@ -16,7 +16,10 @@
16
16
 
17
17
 
18
18
  from .driver import Driver
19
+ from .grpc_driver import GrpcDriver, GrpcDriverHelper
19
20
 
20
21
# Names re-exported as the public API of this package.
__all__ = ["Driver", "GrpcDriver", "GrpcDriverHelper"]