flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the registry page for this release for more details.

Files changed (66)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +19 -14
  4. flwr/cli/new/new.py +51 -22
  5. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  6. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  7. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  8. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  9. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
  10. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  11. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  13. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
  14. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  15. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  16. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  17. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  18. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  19. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  20. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  21. flwr/cli/run/run.py +1 -1
  22. flwr/cli/utils.py +18 -17
  23. flwr/client/__init__.py +3 -1
  24. flwr/client/app.py +20 -142
  25. flwr/client/grpc_client/connection.py +8 -2
  26. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  27. flwr/client/grpc_rere_client/connection.py +33 -4
  28. flwr/client/mod/centraldp_mods.py +4 -2
  29. flwr/client/mod/localdp_mod.py +9 -3
  30. flwr/client/rest_client/connection.py +92 -169
  31. flwr/client/supernode/__init__.py +24 -0
  32. flwr/client/supernode/app.py +281 -0
  33. flwr/common/grpc.py +5 -1
  34. flwr/common/logger.py +37 -4
  35. flwr/common/message.py +105 -86
  36. flwr/common/record/parametersrecord.py +0 -1
  37. flwr/common/record/recordset.py +78 -27
  38. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  39. flwr/common/telemetry.py +4 -0
  40. flwr/server/app.py +116 -6
  41. flwr/server/compat/app.py +2 -2
  42. flwr/server/compat/app_utils.py +1 -1
  43. flwr/server/compat/driver_client_proxy.py +27 -70
  44. flwr/server/driver/__init__.py +2 -1
  45. flwr/server/driver/driver.py +12 -139
  46. flwr/server/driver/grpc_driver.py +199 -13
  47. flwr/server/run_serverapp.py +18 -4
  48. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  49. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  50. flwr/server/superlink/driver/driver_servicer.py +1 -1
  51. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  52. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  53. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
  54. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  55. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  56. flwr/server/superlink/state/in_memory_state.py +89 -12
  57. flwr/server/superlink/state/sqlite_state.py +133 -16
  58. flwr/server/superlink/state/state.py +56 -6
  59. flwr/simulation/__init__.py +2 -2
  60. flwr/simulation/app.py +16 -1
  61. flwr/simulation/run_simulation.py +10 -7
  62. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
  63. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
  64. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
  65. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
  66. {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
flwr/server/app.py CHANGED
@@ -16,15 +16,21 @@
16
16
 
17
17
  import argparse
18
18
  import asyncio
19
+ import csv
19
20
  import importlib.util
20
21
  import sys
21
22
  import threading
22
23
  from logging import ERROR, INFO, WARN
23
24
  from os.path import isfile
24
25
  from pathlib import Path
25
- from typing import List, Optional, Tuple
26
+ from typing import List, Optional, Sequence, Set, Tuple
26
27
 
27
28
  import grpc
29
+ from cryptography.hazmat.primitives.asymmetric import ec
30
+ from cryptography.hazmat.primitives.serialization import (
31
+ load_ssh_private_key,
32
+ load_ssh_public_key,
33
+ )
28
34
 
29
35
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
30
36
  from flwr.common.address import parse_address
@@ -36,6 +42,11 @@ from flwr.common.constant import (
36
42
  )
37
43
  from flwr.common.exit_handlers import register_exit_handlers
38
44
  from flwr.common.logger import log
45
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
46
+ private_key_to_bytes,
47
+ public_key_to_bytes,
48
+ ssh_types_to_elliptic_curve,
49
+ )
39
50
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
40
51
  add_FleetServicer_to_server,
41
52
  )
@@ -51,6 +62,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
51
62
  start_grpc_server,
52
63
  )
53
64
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
65
+ from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
54
66
  from .superlink.fleet.vce import start_vce
55
67
  from .superlink.state import StateFactory
56
68
 
@@ -291,9 +303,11 @@ def run_fleet_api() -> None:
291
303
 
292
304
  # pylint: disable=too-many-branches, too-many-locals, too-many-statements
293
305
  def run_superlink() -> None:
294
- """Run Flower server (Driver API and Fleet API)."""
295
- log(INFO, "Starting Flower server")
306
+ """Run Flower SuperLink (Driver API and Fleet API)."""
307
+ log(INFO, "Starting Flower SuperLink")
308
+
296
309
  event(EventType.RUN_SUPERLINK_ENTER)
310
+
297
311
  args = _parse_args_run_superlink().parse_args()
298
312
 
299
313
  # Parse IP address
@@ -352,10 +366,33 @@ def run_superlink() -> None:
352
366
  sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
353
367
  host, port, is_v6 = parsed_address
354
368
  address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
369
+
370
+ maybe_keys = _try_setup_client_authentication(args, certificates)
371
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
372
+ if maybe_keys is not None:
373
+ (
374
+ client_public_keys,
375
+ server_private_key,
376
+ server_public_key,
377
+ ) = maybe_keys
378
+ state = state_factory.state()
379
+ state.store_client_public_keys(client_public_keys)
380
+ state.store_server_private_public_key(
381
+ private_key_to_bytes(server_private_key),
382
+ public_key_to_bytes(server_public_key),
383
+ )
384
+ log(
385
+ INFO,
386
+ "Client authentication enabled with %d known public keys",
387
+ len(client_public_keys),
388
+ )
389
+ interceptors = [AuthenticateServerInterceptor(state)]
390
+
355
391
  fleet_server = _run_fleet_api_grpc_rere(
356
392
  address=address,
357
393
  state_factory=state_factory,
358
394
  certificates=certificates,
395
+ interceptors=interceptors,
359
396
  )
360
397
  grpc_servers.append(fleet_server)
361
398
  elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
@@ -388,6 +425,70 @@ def run_superlink() -> None:
388
425
  driver_server.wait_for_termination(timeout=1)
389
426
 
390
427
 
428
+ def _try_setup_client_authentication(
429
+ args: argparse.Namespace,
430
+ certificates: Optional[Tuple[bytes, bytes, bytes]],
431
+ ) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
432
+ if not args.require_client_authentication:
433
+ return None
434
+
435
+ if certificates is None:
436
+ sys.exit(
437
+ "Client authentication only works over secure connections. "
438
+ "Please provide certificate paths using '--certificates' when "
439
+ "enabling '--require-client-authentication'."
440
+ )
441
+
442
+ client_keys_file_path = Path(args.require_client_authentication[0])
443
+ if not client_keys_file_path.exists():
444
+ sys.exit(
445
+ "The provided path to the client public keys CSV file does not exist: "
446
+ f"{client_keys_file_path}. "
447
+ "Please provide the CSV file path containing known client public keys "
448
+ "to '--require-client-authentication'."
449
+ )
450
+
451
+ client_public_keys: Set[bytes] = set()
452
+ ssh_private_key = load_ssh_private_key(
453
+ Path(args.require_client_authentication[1]).read_bytes(),
454
+ None,
455
+ )
456
+ ssh_public_key = load_ssh_public_key(
457
+ Path(args.require_client_authentication[2]).read_bytes()
458
+ )
459
+
460
+ try:
461
+ server_private_key, server_public_key = ssh_types_to_elliptic_curve(
462
+ ssh_private_key, ssh_public_key
463
+ )
464
+ except TypeError:
465
+ sys.exit(
466
+ "The file paths provided could not be read as a private and public "
467
+ "key pair. Client authentication requires an elliptic curve public and "
468
+ "private key pair. Please provide the file paths containing elliptic "
469
+ "curve private and public keys to '--require-client-authentication'."
470
+ )
471
+
472
+ with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
473
+ reader = csv.reader(csvfile)
474
+ for row in reader:
475
+ for element in row:
476
+ public_key = load_ssh_public_key(element.encode())
477
+ if isinstance(public_key, ec.EllipticCurvePublicKey):
478
+ client_public_keys.add(public_key_to_bytes(public_key))
479
+ else:
480
+ sys.exit(
481
+ "Error: Unable to parse the public keys in the .csv "
482
+ "file. Please ensure that the .csv file contains valid "
483
+ "SSH public keys and try again."
484
+ )
485
+ return (
486
+ client_public_keys,
487
+ server_private_key,
488
+ server_public_key,
489
+ )
490
+
491
+
391
492
  def _try_obtain_certificates(
392
493
  args: argparse.Namespace,
393
494
  ) -> Optional[Tuple[bytes, bytes, bytes]]:
@@ -415,6 +516,7 @@ def _run_fleet_api_grpc_rere(
415
516
  address: str,
416
517
  state_factory: StateFactory,
417
518
  certificates: Optional[Tuple[bytes, bytes, bytes]],
519
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
418
520
  ) -> grpc.Server:
419
521
  """Run Fleet API (gRPC, request-response)."""
420
522
  # Create Fleet API gRPC server
@@ -427,6 +529,7 @@ def _run_fleet_api_grpc_rere(
427
529
  server_address=address,
428
530
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
429
531
  certificates=certificates,
532
+ interceptors=interceptors,
430
533
  )
431
534
 
432
535
  log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
@@ -568,9 +671,7 @@ def _parse_args_run_fleet_api() -> argparse.ArgumentParser:
568
671
  def _parse_args_run_superlink() -> argparse.ArgumentParser:
569
672
  """Parse command line arguments for both Driver API and Fleet API."""
570
673
  parser = argparse.ArgumentParser(
571
- description="This will start a Flower server "
572
- "(meaning, a Driver API and a Fleet API), "
573
- "that clients will be able to connect to.",
674
+ description="Start a Flower SuperLink",
574
675
  )
575
676
 
576
677
  _add_args_common(parser=parser)
@@ -606,6 +707,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
606
707
  "Flower will just create a state in memory.",
607
708
  default=DATABASE,
608
709
  )
710
+ parser.add_argument(
711
+ "--require-client-authentication",
712
+ nargs=3,
713
+ metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
714
+ type=str,
715
+ help="Provide three file paths: (1) a .csv file containing a list of "
716
+ "known client public keys for authentication, (2) the server's private "
717
+ "key file, and (3) the server's public key file.",
718
+ )
609
719
 
610
720
 
611
721
  def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
flwr/server/compat/app.py CHANGED
@@ -29,7 +29,7 @@ from flwr.server.server import Server, init_defaults, run_fl
29
29
  from flwr.server.server_config import ServerConfig
30
30
  from flwr.server.strategy import Strategy
31
31
 
32
- from ..driver import Driver
32
+ from ..driver import Driver, GrpcDriver
33
33
  from .app_utils import start_update_client_manager_thread
34
34
 
35
35
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
@@ -114,7 +114,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
114
114
  # Create the Driver
115
115
  if isinstance(root_certificates, str):
116
116
  root_certificates = Path(root_certificates).read_bytes()
117
- driver = Driver(
117
+ driver = GrpcDriver(
118
118
  driver_service_address=address, root_certificates=root_certificates
119
119
  )
120
120
 
@@ -89,7 +89,7 @@ def _update_client_manager(
89
89
  for node_id in new_nodes:
90
90
  client_proxy = DriverClientProxy(
91
91
  node_id=node_id,
92
- driver=driver.grpc_driver, # type: ignore
92
+ driver=driver,
93
93
  anonymous=False,
94
94
  run_id=driver.run_id, # type: ignore
95
95
  )
@@ -16,16 +16,14 @@
16
16
 
17
17
 
18
18
  import time
19
- from typing import List, Optional
19
+ from typing import Optional
20
20
 
21
21
  from flwr import common
22
- from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet
22
+ from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
23
23
  from flwr.common import recordset_compat as compat
24
- from flwr.common import serde
25
- from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
26
24
  from flwr.server.client_proxy import ClientProxy
27
25
 
28
- from ..driver.grpc_driver import GrpcDriver
26
+ from ..driver.driver import Driver
29
27
 
30
28
  SLEEP_TIME = 1
31
29
 
@@ -33,7 +31,7 @@ SLEEP_TIME = 1
33
31
  class DriverClientProxy(ClientProxy):
34
32
  """Flower client proxy which delegates work using the Driver API."""
35
33
 
36
- def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: int):
34
+ def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
37
35
  super().__init__(str(node_id))
38
36
  self.node_id = node_id
39
37
  self.driver = driver
@@ -114,80 +112,39 @@ class DriverClientProxy(ClientProxy):
114
112
  timeout: Optional[float],
115
113
  group_id: Optional[int],
116
114
  ) -> RecordSet:
117
- task_ins = task_pb2.TaskIns( # pylint: disable=E1101
118
- task_id="",
119
- group_id=str(group_id) if group_id is not None else "",
120
- run_id=self.run_id,
121
- task=task_pb2.Task( # pylint: disable=E1101
122
- producer=node_pb2.Node( # pylint: disable=E1101
123
- node_id=0,
124
- anonymous=True,
125
- ),
126
- consumer=node_pb2.Node( # pylint: disable=E1101
127
- node_id=self.node_id,
128
- anonymous=self.anonymous,
129
- ),
130
- task_type=task_type,
131
- recordset=serde.recordset_to_proto(recordset),
132
- ttl=DEFAULT_TTL,
133
- ),
134
- )
135
-
136
- # This would normally be recorded upon common.Message creation
137
- # but this compatibility stack doesn't create Messages,
138
- # so we need to inject `created_at` manually (needed for
139
- # taskins validation by server.utils.validator)
140
- task_ins.task.created_at = time.time()
141
115
 
142
- push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
143
- task_ins_list=[task_ins]
116
+ # Create message
117
+ message = self.driver.create_message(
118
+ content=recordset,
119
+ message_type=task_type,
120
+ dst_node_id=self.node_id,
121
+ group_id=str(group_id) if group_id else "",
122
+ ttl=timeout,
144
123
  )
145
124
 
146
- # Send TaskIns to Driver API
147
- push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req)
148
-
149
- if len(push_task_ins_res.task_ids) != 1:
150
- raise ValueError("Unexpected number of task_ids")
125
+ # Push message
126
+ message_ids = list(self.driver.push_messages(messages=[message]))
127
+ if len(message_ids) != 1:
128
+ raise ValueError("Unexpected number of message_ids")
151
129
 
152
- task_id = push_task_ins_res.task_ids[0]
153
- if task_id == "":
154
- raise ValueError(f"Failed to schedule task for node {self.node_id}")
130
+ message_id = message_ids[0]
131
+ if message_id == "":
132
+ raise ValueError(f"Failed to send message to node {self.node_id}")
155
133
 
156
134
  if timeout:
157
135
  start_time = time.time()
158
136
 
159
137
  while True:
160
- pull_task_res_req = driver_pb2.PullTaskResRequest( # pylint: disable=E1101
161
- node=node_pb2.Node(node_id=0, anonymous=True), # pylint: disable=E1101
162
- task_ids=[task_id],
163
- )
164
-
165
- # Ask Driver API for TaskRes
166
- pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req)
167
-
168
- task_res_list: List[task_pb2.TaskRes] = list( # pylint: disable=E1101
169
- pull_task_res_res.task_res_list
170
- )
171
- if len(task_res_list) == 1:
172
- task_res = task_res_list[0]
173
-
174
- # This will raise an Exception if task_res carries an `error`
175
- validate_task_res(task_res=task_res)
176
-
177
- return serde.recordset_from_proto(task_res.task.recordset)
138
+ messages = list(self.driver.pull_messages(message_ids))
139
+ if len(messages) == 1:
140
+ msg: Message = messages[0]
141
+ if msg.has_error():
142
+ raise ValueError(
143
+ f"Message contains an Error (reason: {msg.error.reason}). "
144
+ "It originated during client-side execution of a message."
145
+ )
146
+ return msg.content
178
147
 
179
148
  if timeout is not None and time.time() > start_time + timeout:
180
149
  raise RuntimeError("Timeout reached")
181
150
  time.sleep(SLEEP_TIME)
182
-
183
-
184
- def validate_task_res(
185
- task_res: task_pb2.TaskRes, # pylint: disable=E1101
186
- ) -> None:
187
- """Validate if a TaskRes is empty or not."""
188
- if not task_res.HasField("task"):
189
- raise ValueError("Invalid TaskRes, field `task` missing")
190
- if task_res.task.HasField("error"):
191
- raise ValueError("Exception during client-side task execution")
192
- if not task_res.task.HasField("recordset"):
193
- raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
@@ -16,9 +16,10 @@
16
16
 
17
17
 
18
18
  from .driver import Driver
19
- from .grpc_driver import GrpcDriver
19
+ from .grpc_driver import GrpcDriver, GrpcDriverHelper
20
20
 
21
21
  __all__ = [
22
22
  "Driver",
23
23
  "GrpcDriver",
24
+ "GrpcDriverHelper",
24
25
  ]
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,79 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service client."""
15
+ """Driver (abstract base class)."""
16
16
 
17
- import time
18
- import warnings
19
- from typing import Iterable, List, Optional, Tuple
20
17
 
21
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
22
- from flwr.common.serde import message_from_taskres, message_to_taskins
23
- from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
24
- CreateRunRequest,
25
- GetNodesRequest,
26
- PullTaskResRequest,
27
- PushTaskInsRequest,
28
- )
29
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
30
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
18
+ from abc import ABC, abstractmethod
19
+ from typing import Iterable, List, Optional
31
20
 
32
- from .grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
21
+ from flwr.common import Message, RecordSet
33
22
 
34
23
 
35
- class Driver:
36
- """`Driver` class provides an interface to the Driver API.
37
-
38
- Parameters
39
- ----------
40
- driver_service_address : Optional[str]
41
- The IPv4 or IPv6 address of the Driver API server.
42
- Defaults to `"[::]:9091"`.
43
- certificates : bytes (default: None)
44
- Tuple containing root certificate, server certificate, and private key
45
- to start a secure SSL-enabled server. The tuple is expected to have
46
- three bytes elements in the following order:
47
-
48
- * CA certificate.
49
- * server certificate.
50
- * server private key.
51
- """
52
-
53
- def __init__(
54
- self,
55
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
56
- root_certificates: Optional[bytes] = None,
57
- ) -> None:
58
- self.addr = driver_service_address
59
- self.root_certificates = root_certificates
60
- self.grpc_driver: Optional[GrpcDriver] = None
61
- self.run_id: Optional[int] = None
62
- self.node = Node(node_id=0, anonymous=True)
63
-
64
- def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]:
65
- # Check if the GrpcDriver is initialized
66
- if self.grpc_driver is None or self.run_id is None:
67
- # Connect and create run
68
- self.grpc_driver = GrpcDriver(
69
- driver_service_address=self.addr,
70
- root_certificates=self.root_certificates,
71
- )
72
- self.grpc_driver.connect()
73
- res = self.grpc_driver.create_run(CreateRunRequest())
74
- self.run_id = res.run_id
75
- return self.grpc_driver, self.run_id
76
-
77
- def _check_message(self, message: Message) -> None:
78
- # Check if the message is valid
79
- if not (
80
- message.metadata.run_id == self.run_id
81
- and message.metadata.src_node_id == self.node.node_id
82
- and message.metadata.message_id == ""
83
- and message.metadata.reply_to_message == ""
84
- and message.metadata.ttl > 0
85
- ):
86
- raise ValueError(f"Invalid message: {message}")
24
+ class Driver(ABC):
25
+ """Abstract base Driver class for the Driver API."""
87
26
 
27
+ @abstractmethod
88
28
  def create_message( # pylint: disable=too-many-arguments
89
29
  self,
90
30
  content: RecordSet,
@@ -122,35 +62,12 @@ class Driver:
122
62
  message : Message
123
63
  A new `Message` instance with the specified content and metadata.
124
64
  """
125
- _, run_id = self._get_grpc_driver_and_run_id()
126
- if ttl:
127
- warnings.warn(
128
- "A custom TTL was set, but note that the SuperLink does not enforce "
129
- "the TTL yet. The SuperLink will start enforcing the TTL in a future "
130
- "version of Flower.",
131
- stacklevel=2,
132
- )
133
-
134
- ttl_ = DEFAULT_TTL if ttl is None else ttl
135
- metadata = Metadata(
136
- run_id=run_id,
137
- message_id="", # Will be set by the server
138
- src_node_id=self.node.node_id,
139
- dst_node_id=dst_node_id,
140
- reply_to_message="",
141
- group_id=group_id,
142
- ttl=ttl_,
143
- message_type=message_type,
144
- )
145
- return Message(metadata=metadata, content=content)
146
65
 
66
+ @abstractmethod
147
67
  def get_node_ids(self) -> List[int]:
148
68
  """Get node IDs."""
149
- grpc_driver, run_id = self._get_grpc_driver_and_run_id()
150
- # Call GrpcDriver method
151
- res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id))
152
- return [node.node_id for node in res.nodes]
153
69
 
70
+ @abstractmethod
154
71
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
155
72
  """Push messages to specified node IDs.
156
73
 
@@ -168,20 +85,8 @@ class Driver:
168
85
  An iterable of IDs for the messages that were sent, which can be used
169
86
  to pull replies.
170
87
  """
171
- grpc_driver, _ = self._get_grpc_driver_and_run_id()
172
- # Construct TaskIns
173
- task_ins_list: List[TaskIns] = []
174
- for msg in messages:
175
- # Check message
176
- self._check_message(msg)
177
- # Convert Message to TaskIns
178
- taskins = message_to_taskins(msg)
179
- # Add to list
180
- task_ins_list.append(taskins)
181
- # Call GrpcDriver method
182
- res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
183
- return list(res.task_ids)
184
88
 
89
+ @abstractmethod
185
90
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
186
91
  """Pull messages based on message IDs.
187
92
 
@@ -198,15 +103,8 @@ class Driver:
198
103
  messages : Iterable[Message]
199
104
  An iterable of messages received.
200
105
  """
201
- grpc_driver, _ = self._get_grpc_driver_and_run_id()
202
- # Pull TaskRes
203
- res = grpc_driver.pull_task_res(
204
- PullTaskResRequest(node=self.node, task_ids=message_ids)
205
- )
206
- # Convert TaskRes to Message
207
- msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
208
- return msgs
209
106
 
107
+ @abstractmethod
210
108
  def send_and_receive(
211
109
  self,
212
110
  messages: Iterable[Message],
@@ -240,28 +138,3 @@ class Driver:
240
138
  replies for all sent messages. A message remains valid until its TTL,
241
139
  which is not affected by `timeout`.
242
140
  """
243
- # Push messages
244
- msg_ids = set(self.push_messages(messages))
245
-
246
- # Pull messages
247
- end_time = time.time() + (timeout if timeout is not None else 0.0)
248
- ret: List[Message] = []
249
- while timeout is None or time.time() < end_time:
250
- res_msgs = self.pull_messages(msg_ids)
251
- ret.extend(res_msgs)
252
- msg_ids.difference_update(
253
- {msg.metadata.reply_to_message for msg in res_msgs}
254
- )
255
- if len(msg_ids) == 0:
256
- break
257
- # Sleep
258
- time.sleep(3)
259
- return ret
260
-
261
- def close(self) -> None:
262
- """Disconnect from the SuperLink if connected."""
263
- # Check if GrpcDriver is initialized
264
- if self.grpc_driver is None:
265
- return
266
- # Disconnect
267
- self.grpc_driver.disconnect()