flwr-nightly 1.10.0.dev20240618__py3-none-any.whl → 1.10.0.dev20240620__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the package's registry page for more details.

Files changed (98)
  1. flwr/cli/app.py +3 -0
  2. flwr/cli/build.py +3 -7
  3. flwr/cli/new/new.py +1 -1
  4. flwr/cli/run/run.py +8 -1
  5. flwr/client/__init__.py +1 -1
  6. flwr/client/app.py +4 -0
  7. flwr/client/client_app.py +1 -1
  8. flwr/client/dpfedavg_numpy_client.py +1 -1
  9. flwr/client/grpc_rere_client/__init__.py +1 -1
  10. flwr/client/grpc_rere_client/connection.py +1 -1
  11. flwr/client/message_handler/__init__.py +1 -1
  12. flwr/client/message_handler/message_handler.py +1 -1
  13. flwr/client/mod/__init__.py +4 -4
  14. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  15. flwr/client/mod/utils.py +1 -1
  16. flwr/client/rest_client/__init__.py +1 -1
  17. flwr/client/rest_client/connection.py +1 -1
  18. flwr/client/supernode/app.py +29 -6
  19. flwr/common/__init__.py +12 -12
  20. flwr/common/address.py +1 -1
  21. flwr/common/config.py +8 -6
  22. flwr/common/constant.py +5 -1
  23. flwr/common/date.py +1 -1
  24. flwr/common/dp.py +1 -1
  25. flwr/common/grpc.py +1 -1
  26. flwr/common/object_ref.py +39 -5
  27. flwr/common/record/__init__.py +1 -1
  28. flwr/common/secure_aggregation/__init__.py +1 -1
  29. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  30. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  31. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  32. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  33. flwr/common/secure_aggregation/quantization.py +1 -1
  34. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  35. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  36. flwr/common/version.py +14 -0
  37. flwr/server/__init__.py +2 -2
  38. flwr/server/app.py +47 -7
  39. flwr/server/compat/app.py +1 -1
  40. flwr/server/compat/app_utils.py +1 -1
  41. flwr/server/compat/driver_client_proxy.py +1 -1
  42. flwr/server/driver/driver.py +6 -0
  43. flwr/server/driver/grpc_driver.py +85 -63
  44. flwr/server/driver/inmemory_driver.py +28 -26
  45. flwr/server/run_serverapp.py +15 -8
  46. flwr/server/strategy/__init__.py +2 -2
  47. flwr/server/strategy/bulyan.py +1 -1
  48. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  49. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  50. flwr/server/strategy/fedadagrad.py +1 -1
  51. flwr/server/strategy/fedadam.py +1 -1
  52. flwr/server/strategy/fedavg_android.py +1 -1
  53. flwr/server/strategy/fedavgm.py +1 -1
  54. flwr/server/strategy/fedmedian.py +1 -1
  55. flwr/server/strategy/fedopt.py +1 -1
  56. flwr/server/strategy/fedprox.py +1 -1
  57. flwr/server/strategy/fedxgb_bagging.py +1 -1
  58. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  59. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  60. flwr/server/strategy/fedyogi.py +1 -1
  61. flwr/server/strategy/krum.py +1 -1
  62. flwr/server/strategy/qfedavg.py +1 -1
  63. flwr/server/superlink/driver/__init__.py +1 -1
  64. flwr/server/superlink/driver/driver_grpc.py +1 -1
  65. flwr/server/superlink/driver/driver_servicer.py +15 -3
  66. flwr/server/superlink/fleet/__init__.py +1 -1
  67. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  68. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  69. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  70. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  71. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  72. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  73. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -1
  74. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  76. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  77. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  78. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  79. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  80. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  81. flwr/server/superlink/state/__init__.py +1 -1
  82. flwr/server/superlink/state/in_memory_state.py +1 -1
  83. flwr/server/superlink/state/sqlite_state.py +1 -1
  84. flwr/server/superlink/state/state.py +1 -1
  85. flwr/server/superlink/state/state_factory.py +11 -2
  86. flwr/server/utils/__init__.py +1 -1
  87. flwr/server/utils/tensorboard.py +1 -1
  88. flwr/simulation/__init__.py +5 -2
  89. flwr/simulation/app.py +1 -1
  90. flwr/simulation/ray_transport/__init__.py +1 -1
  91. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  92. flwr/simulation/run_simulation.py +15 -8
  93. flwr/superexec/app.py +1 -1
  94. {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/METADATA +2 -1
  95. {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/RECORD +98 -96
  96. {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/LICENSE +0 -0
  97. {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/WHEEL +0 -0
  98. {flwr_nightly-1.10.0.dev20240618.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/entry_points.txt +0 -0
flwr/server/app.py CHANGED
@@ -36,6 +36,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
36
36
  from flwr.common.address import parse_address
37
37
  from flwr.common.constant import (
38
38
  MISSING_EXTRA_REST,
39
+ TRANSPORT_TYPE_GRPC_ADAPTER,
39
40
  TRANSPORT_TYPE_GRPC_RERE,
40
41
  TRANSPORT_TYPE_REST,
41
42
  )
@@ -48,6 +49,7 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
48
49
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
49
50
  add_FleetServicer_to_server,
50
51
  )
52
+ from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
51
53
 
52
54
  from .client_manager import ClientManager
53
55
  from .history import History
@@ -55,6 +57,7 @@ from .server import Server, init_defaults, run_fl
55
57
  from .server_config import ServerConfig
56
58
  from .strategy import Strategy
57
59
  from .superlink.driver.driver_grpc import run_driver_api_grpc
60
+ from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
58
61
  from .superlink.fleet.grpc_bidi.grpc_server import (
59
62
  generic_create_grpc_server,
60
63
  start_grpc_server,
@@ -218,11 +221,13 @@ def run_superlink() -> None:
218
221
  grpc_servers = [driver_server]
219
222
  bckg_threads = []
220
223
  if not args.fleet_api_address:
221
- args.fleet_api_address = (
222
- ADDRESS_FLEET_API_GRPC_RERE
223
- if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE
224
- else ADDRESS_FLEET_API_REST
225
- )
224
+ if args.fleet_api_type in [
225
+ TRANSPORT_TYPE_GRPC_RERE,
226
+ TRANSPORT_TYPE_GRPC_ADAPTER,
227
+ ]:
228
+ args.fleet_api_address = ADDRESS_FLEET_API_GRPC_RERE
229
+ elif args.fleet_api_type == TRANSPORT_TYPE_REST:
230
+ args.fleet_api_address = ADDRESS_FLEET_API_REST
226
231
 
227
232
  fleet_address, host, port = _format_address(args.fleet_api_address)
228
233
 
@@ -293,6 +298,13 @@ def run_superlink() -> None:
293
298
  interceptors=interceptors,
294
299
  )
295
300
  grpc_servers.append(fleet_server)
301
+ elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
302
+ fleet_server = _run_fleet_api_grpc_adapter(
303
+ address=fleet_address,
304
+ state_factory=state_factory,
305
+ certificates=certificates,
306
+ )
307
+ grpc_servers.append(fleet_server)
296
308
  else:
297
309
  raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
298
310
 
@@ -419,7 +431,7 @@ def _try_obtain_certificates(
419
431
  log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
420
432
  return None
421
433
  # Check if certificates are provided
422
- if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
434
+ if args.fleet_api_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]:
423
435
  if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
424
436
  if not isfile(args.ssl_ca_certfile):
425
437
  sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
@@ -491,6 +503,30 @@ def _run_fleet_api_grpc_rere(
491
503
  return fleet_grpc_server
492
504
 
493
505
 
506
+ def _run_fleet_api_grpc_adapter(
507
+ address: str,
508
+ state_factory: StateFactory,
509
+ certificates: Optional[Tuple[bytes, bytes, bytes]],
510
+ ) -> grpc.Server:
511
+ """Run Fleet API (GrpcAdapter)."""
512
+ # Create Fleet API gRPC server
513
+ fleet_servicer = GrpcAdapterServicer(
514
+ state_factory=state_factory,
515
+ )
516
+ fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server
517
+ fleet_grpc_server = generic_create_grpc_server(
518
+ servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn),
519
+ server_address=address,
520
+ max_message_length=GRPC_MAX_MESSAGE_LENGTH,
521
+ certificates=certificates,
522
+ )
523
+
524
+ log(INFO, "Flower ECE: Starting Fleet API (GrpcAdapter) on %s", address)
525
+ fleet_grpc_server.start()
526
+
527
+ return fleet_grpc_server
528
+
529
+
494
530
  # pylint: disable=import-outside-toplevel,too-many-arguments
495
531
  def _run_fleet_api_rest(
496
532
  host: str,
@@ -606,7 +642,11 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
606
642
  "--fleet-api-type",
607
643
  default=TRANSPORT_TYPE_GRPC_RERE,
608
644
  type=str,
609
- choices=[TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST],
645
+ choices=[
646
+ TRANSPORT_TYPE_GRPC_RERE,
647
+ TRANSPORT_TYPE_GRPC_ADAPTER,
648
+ TRANSPORT_TYPE_REST,
649
+ ],
610
650
  help="Start a gRPC-rere or REST (experimental) Fleet API server.",
611
651
  )
612
652
  parser.add_argument(
flwr/server/compat/app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -91,7 +91,7 @@ def _update_client_manager(
91
91
  node_id=node_id,
92
92
  driver=driver,
93
93
  anonymous=False,
94
- run_id=driver.run_id, # type: ignore
94
+ run_id=driver.run.run_id,
95
95
  )
96
96
  if client_manager.register(client_proxy):
97
97
  registered_nodes[node_id] = client_proxy
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,11 +19,17 @@ from abc import ABC, abstractmethod
19
19
  from typing import Iterable, List, Optional
20
20
 
21
21
  from flwr.common import Message, RecordSet
22
+ from flwr.common.typing import Run
22
23
 
23
24
 
24
25
  class Driver(ABC):
25
26
  """Abstract base Driver class for the Driver API."""
26
27
 
28
+ @property
29
+ @abstractmethod
30
+ def run(self) -> Run:
31
+ """Run information."""
32
+
27
33
  @abstractmethod
28
34
  def create_message( # pylint: disable=too-many-arguments
29
35
  self,
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
17
17
  import time
18
18
  import warnings
19
19
  from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple
20
+ from typing import Iterable, List, Optional, Tuple, cast
21
21
 
22
22
  import grpc
23
23
 
@@ -25,6 +25,7 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, ev
25
25
  from flwr.common.grpc import create_channel
26
26
  from flwr.common.logger import log
27
27
  from flwr.common.serde import message_from_taskres, message_to_taskins
28
+ from flwr.common.typing import Run
28
29
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
29
30
  CreateRunRequest,
30
31
  CreateRunResponse,
@@ -37,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
37
38
  )
38
39
  from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
39
40
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
41
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
42
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
41
43
 
42
44
  from .driver import Driver
@@ -46,13 +48,24 @@ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
46
48
  ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
47
49
  [Driver] Error: Not connected.
48
50
 
49
- Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
50
- `GrpcDriverHelper` methods.
51
+ Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
52
+ `GrpcDriverStub` methods.
51
53
  """
52
54
 
53
55
 
54
- class GrpcDriverHelper:
55
- """`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
56
+ class GrpcDriverStub:
57
+ """`GrpcDriverStub` provides access to the gRPC Driver API/service.
58
+
59
+ Parameters
60
+ ----------
61
+ driver_service_address : Optional[str]
62
+ The IPv4 or IPv6 address of the Driver API server.
63
+ Defaults to `"[::]:9091"`.
64
+ root_certificates : Optional[bytes] (default: None)
65
+ The PEM-encoded root certificates as a byte string.
66
+ If provided, a secure connection using the certificates will be
67
+ established to an SSL-enabled Flower server.
68
+ """
56
69
 
57
70
  def __init__(
58
71
  self,
@@ -64,6 +77,10 @@ class GrpcDriverHelper:
64
77
  self.channel: Optional[grpc.Channel] = None
65
78
  self.stub: Optional[DriverStub] = None
66
79
 
80
+ def is_connected(self) -> bool:
81
+ """Return True if connected to the Driver API server, otherwise False."""
82
+ return self.channel is not None
83
+
67
84
  def connect(self) -> None:
68
85
  """Connect to the Driver API."""
69
86
  event(EventType.DRIVER_CONNECT)
@@ -95,18 +112,29 @@ class GrpcDriverHelper:
95
112
  # Check if channel is open
96
113
  if self.stub is None:
97
114
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
98
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
115
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
99
116
 
100
117
  # Call Driver API
101
118
  res: CreateRunResponse = self.stub.CreateRun(request=req)
102
119
  return res
103
120
 
121
+ def get_run(self, req: GetRunRequest) -> GetRunResponse:
122
+ """Get run information."""
123
+ # Check if channel is open
124
+ if self.stub is None:
125
+ log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
126
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
127
+
128
+ # Call gRPC Driver API
129
+ res: GetRunResponse = self.stub.GetRun(request=req)
130
+ return res
131
+
104
132
  def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
105
133
  """Get client IDs."""
106
134
  # Check if channel is open
107
135
  if self.stub is None:
108
136
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
109
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
137
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
110
138
 
111
139
  # Call gRPC Driver API
112
140
  res: GetNodesResponse = self.stub.GetNodes(request=req)
@@ -117,7 +145,7 @@ class GrpcDriverHelper:
117
145
  # Check if channel is open
118
146
  if self.stub is None:
119
147
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
120
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
148
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
121
149
 
122
150
  # Call gRPC Driver API
123
151
  res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
@@ -128,7 +156,7 @@ class GrpcDriverHelper:
128
156
  # Check if channel is open
129
157
  if self.stub is None:
130
158
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
131
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
159
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
132
160
 
133
161
  # Call Driver API
134
162
  res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
@@ -140,56 +168,52 @@ class GrpcDriver(Driver):
140
168
 
141
169
  Parameters
142
170
  ----------
143
- driver_service_address : Optional[str]
144
- The IPv4 or IPv6 address of the Driver API server.
145
- Defaults to `"[::]:9091"`.
146
- certificates : bytes (default: None)
147
- Tuple containing root certificate, server certificate, and private key
148
- to start a secure SSL-enabled server. The tuple is expected to have
149
- three bytes elements in the following order:
150
-
151
- * CA certificate.
152
- * server certificate.
153
- * server private key.
154
- fab_id : str (default: None)
155
- The identifier of the FAB used in the run.
156
- fab_version : str (default: None)
157
- The version of the FAB used in the run.
171
+ run_id : int
172
+ The identifier of the run.
173
+ stub : Optional[GrpcDriverStub] (default: None)
174
+ The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
175
+ If None, an instance connected to "[::]:9091" will be created.
158
176
  """
159
177
 
160
- def __init__(
178
+ def __init__( # pylint: disable=too-many-arguments
161
179
  self,
162
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
163
- root_certificates: Optional[bytes] = None,
164
- fab_id: Optional[str] = None,
165
- fab_version: Optional[str] = None,
180
+ run_id: int,
181
+ stub: Optional[GrpcDriverStub] = None,
166
182
  ) -> None:
167
- self.addr = driver_service_address
168
- self.root_certificates = root_certificates
169
- self.driver_helper: Optional[GrpcDriverHelper] = None
170
- self.run_id: Optional[int] = None
171
- self.fab_id = fab_id if fab_id is not None else ""
172
- self.fab_version = fab_version if fab_version is not None else ""
183
+ self._run_id = run_id
184
+ self._run: Optional[Run] = None
185
+ self.stub = stub if stub is not None else GrpcDriverStub()
173
186
  self.node = Node(node_id=0, anonymous=True)
174
187
 
175
- def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
176
- # Check if the GrpcDriverHelper is initialized
177
- if self.driver_helper is None or self.run_id is None:
178
- # Connect and create run
179
- self.driver_helper = GrpcDriverHelper(
180
- driver_service_address=self.addr,
181
- root_certificates=self.root_certificates,
188
+ @property
189
+ def run(self) -> Run:
190
+ """Run information."""
191
+ self._get_stub_and_run_id()
192
+ return Run(**vars(cast(Run, self._run)))
193
+
194
+ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
195
+ # Check if is initialized
196
+ if self._run is None:
197
+ # Connect
198
+ if not self.stub.is_connected():
199
+ self.stub.connect()
200
+ # Get the run info
201
+ req = GetRunRequest(run_id=self._run_id)
202
+ res = self.stub.get_run(req)
203
+ if not res.HasField("run"):
204
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
205
+ self._run = Run(
206
+ run_id=res.run.run_id,
207
+ fab_id=res.run.fab_id,
208
+ fab_version=res.run.fab_version,
182
209
  )
183
- self.driver_helper.connect()
184
- req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
185
- res = self.driver_helper.create_run(req)
186
- self.run_id = res.run_id
187
- return self.driver_helper, self.run_id
210
+
211
+ return self.stub, self._run.run_id
188
212
 
189
213
  def _check_message(self, message: Message) -> None:
190
214
  # Check if the message is valid
191
215
  if not (
192
- message.metadata.run_id == self.run_id
216
+ message.metadata.run_id == cast(Run, self._run).run_id
193
217
  and message.metadata.src_node_id == self.node.node_id
194
218
  and message.metadata.message_id == ""
195
219
  and message.metadata.reply_to_message == ""
@@ -210,7 +234,7 @@ class GrpcDriver(Driver):
210
234
  This method constructs a new `Message` with given content and metadata.
211
235
  The `run_id` and `src_node_id` will be set automatically.
212
236
  """
213
- _, run_id = self._get_grpc_driver_helper_and_run_id()
237
+ _, run_id = self._get_stub_and_run_id()
214
238
  if ttl:
215
239
  warnings.warn(
216
240
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -234,9 +258,9 @@ class GrpcDriver(Driver):
234
258
 
235
259
  def get_node_ids(self) -> List[int]:
236
260
  """Get node IDs."""
237
- grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
238
- # Call GrpcDriverHelper method
239
- res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
261
+ stub, run_id = self._get_stub_and_run_id()
262
+ # Call GrpcDriverStub method
263
+ res = stub.get_nodes(GetNodesRequest(run_id=run_id))
240
264
  return [node.node_id for node in res.nodes]
241
265
 
242
266
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -245,7 +269,7 @@ class GrpcDriver(Driver):
245
269
  This method takes an iterable of messages and sends each message
246
270
  to the node specified in `dst_node_id`.
247
271
  """
248
- grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
272
+ stub, _ = self._get_stub_and_run_id()
249
273
  # Construct TaskIns
250
274
  task_ins_list: List[TaskIns] = []
251
275
  for msg in messages:
@@ -255,10 +279,8 @@ class GrpcDriver(Driver):
255
279
  taskins = message_to_taskins(msg)
256
280
  # Add to list
257
281
  task_ins_list.append(taskins)
258
- # Call GrpcDriverHelper method
259
- res = grpc_driver_helper.push_task_ins(
260
- PushTaskInsRequest(task_ins_list=task_ins_list)
261
- )
282
+ # Call GrpcDriverStub method
283
+ res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
262
284
  return list(res.task_ids)
263
285
 
264
286
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -267,9 +289,9 @@ class GrpcDriver(Driver):
267
289
  This method is used to collect messages from the SuperLink that correspond to a
268
290
  set of given message IDs.
269
291
  """
270
- grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
292
+ stub, _ = self._get_stub_and_run_id()
271
293
  # Pull TaskRes
272
- res = grpc_driver.pull_task_res(
294
+ res = stub.pull_task_res(
273
295
  PullTaskResRequest(node=self.node, task_ids=message_ids)
274
296
  )
275
297
  # Convert TaskRes to Message
@@ -308,8 +330,8 @@ class GrpcDriver(Driver):
308
330
 
309
331
  def close(self) -> None:
310
332
  """Disconnect from the SuperLink if connected."""
311
- # Check if GrpcDriverHelper is initialized
312
- if self.driver_helper is None:
333
+ # Check if `connect` was called before
334
+ if not self.stub.is_connected():
313
335
  return
314
336
  # Disconnect
315
- self.driver_helper.disconnect()
337
+ self.stub.disconnect()
@@ -17,11 +17,12 @@
17
17
 
18
18
  import time
19
19
  import warnings
20
- from typing import Iterable, List, Optional
20
+ from typing import Iterable, List, Optional, cast
21
21
  from uuid import UUID
22
22
 
23
23
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
24
24
  from flwr.common.serde import message_from_taskres, message_to_taskins
25
+ from flwr.common.typing import Run
25
26
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
26
27
  from flwr.server.superlink.state import StateFactory
27
28
 
@@ -33,30 +34,27 @@ class InMemoryDriver(Driver):
33
34
 
34
35
  Parameters
35
36
  ----------
37
+ run_id : int
38
+ The identifier of the run.
36
39
  state_factory : StateFactory
37
40
  A StateFactory embedding a state that this driver can interface with.
38
- fab_id : str (default: None)
39
- The identifier of the FAB used in the run.
40
- fab_version : str (default: None)
41
- The version of the FAB used in the run.
42
41
  """
43
42
 
44
43
  def __init__(
45
44
  self,
45
+ run_id: int,
46
46
  state_factory: StateFactory,
47
- fab_id: Optional[str] = None,
48
- fab_version: Optional[str] = None,
49
47
  ) -> None:
50
- self.run_id: Optional[int] = None
51
- self.fab_id = fab_id if fab_id is not None else ""
52
- self.fab_version = fab_version if fab_version is not None else ""
53
- self.node = Node(node_id=0, anonymous=True)
48
+ self._run_id = run_id
49
+ self._run: Optional[Run] = None
54
50
  self.state = state_factory.state()
51
+ self.node = Node(node_id=0, anonymous=True)
55
52
 
56
53
  def _check_message(self, message: Message) -> None:
54
+ self._init_run()
57
55
  # Check if the message is valid
58
56
  if not (
59
- message.metadata.run_id == self.run_id
57
+ message.metadata.run_id == cast(Run, self._run).run_id
60
58
  and message.metadata.src_node_id == self.node.node_id
61
59
  and message.metadata.message_id == ""
62
60
  and message.metadata.reply_to_message == ""
@@ -64,16 +62,20 @@ class InMemoryDriver(Driver):
64
62
  ):
65
63
  raise ValueError(f"Invalid message: {message}")
66
64
 
67
- def _get_run_id(self) -> int:
68
- """Return run_id.
69
-
70
- If unset, create a new run.
71
- """
72
- if self.run_id is None:
73
- self.run_id = self.state.create_run(
74
- fab_id=self.fab_id, fab_version=self.fab_version
75
- )
76
- return self.run_id
65
+ def _init_run(self) -> None:
66
+ """Initialize the run."""
67
+ if self._run is not None:
68
+ return
69
+ run = self.state.get_run(self._run_id)
70
+ if run is None:
71
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
72
+ self._run = run
73
+
74
+ @property
75
+ def run(self) -> Run:
76
+ """Run ID."""
77
+ self._init_run()
78
+ return Run(**vars(cast(Run, self._run)))
77
79
 
78
80
  def create_message( # pylint: disable=too-many-arguments
79
81
  self,
@@ -88,7 +90,7 @@ class InMemoryDriver(Driver):
88
90
  This method constructs a new `Message` with given content and metadata.
89
91
  The `run_id` and `src_node_id` will be set automatically.
90
92
  """
91
- run_id = self._get_run_id()
93
+ self._init_run()
92
94
  if ttl:
93
95
  warnings.warn(
94
96
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -99,7 +101,7 @@ class InMemoryDriver(Driver):
99
101
  ttl_ = DEFAULT_TTL if ttl is None else ttl
100
102
 
101
103
  metadata = Metadata(
102
- run_id=run_id,
104
+ run_id=cast(Run, self._run).run_id,
103
105
  message_id="", # Will be set by the server
104
106
  src_node_id=self.node.node_id,
105
107
  dst_node_id=dst_node_id,
@@ -112,8 +114,8 @@ class InMemoryDriver(Driver):
112
114
 
113
115
  def get_node_ids(self) -> List[int]:
114
116
  """Get node IDs."""
115
- run_id = self._get_run_id()
116
- return list(self.state.get_nodes(run_id))
117
+ self._init_run()
118
+ return list(self.state.get_nodes(cast(Run, self._run).run_id))
117
119
 
118
120
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
119
121
  """Push messages to specified node IDs.
@@ -24,8 +24,10 @@ from typing import Optional
24
24
  from flwr.common import Context, EventType, RecordSet, event
25
25
  from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
26
26
  from flwr.common.object_ref import load_app
27
+ from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
27
28
 
28
- from .driver import Driver, GrpcDriver
29
+ from .driver import Driver
30
+ from .driver.grpc_driver import GrpcDriver, GrpcDriverStub
29
31
  from .server_app import LoadServerAppError, ServerApp
30
32
 
31
33
  ADDRESS_DRIVER_API = "0.0.0.0:9091"
@@ -50,7 +52,9 @@ def run(
50
52
  # Load ServerApp if needed
51
53
  def _load() -> ServerApp:
52
54
  if server_app_attr:
53
- server_app: ServerApp = load_app(server_app_attr, LoadServerAppError)
55
+ server_app: ServerApp = load_app(
56
+ server_app_attr, LoadServerAppError, server_app_dir
57
+ )
54
58
 
55
59
  if not isinstance(server_app, ServerApp):
56
60
  raise LoadServerAppError(
@@ -147,13 +151,16 @@ def run_server_app() -> None:
147
151
  server_app_dir = args.dir
148
152
  server_app_attr = getattr(args, "server-app")
149
153
 
150
- # Initialize GrpcDriver
151
- driver = GrpcDriver(
152
- driver_service_address=args.superlink,
153
- root_certificates=root_certificates,
154
- fab_id=args.fab_id,
155
- fab_version=args.fab_version,
154
+ # Create run
155
+ stub = GrpcDriverStub(
156
+ driver_service_address=args.superlink, root_certificates=root_certificates
156
157
  )
158
+ stub.connect()
159
+ req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
160
+ res = stub.create_run(req)
161
+
162
+ # Initialize GrpcDriver
163
+ driver = GrpcDriver(run_id=res.run_id, stub=stub)
157
164
 
158
165
  # Run the ServerApp with the Driver
159
166
  run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
@@ -53,9 +53,10 @@ __all__ = [
53
53
  "DPFedAvgAdaptive",
54
54
  "DPFedAvgFixed",
55
55
  "DifferentialPrivacyClientSideAdaptiveClipping",
56
- "DifferentialPrivacyServerSideAdaptiveClipping",
57
56
  "DifferentialPrivacyClientSideFixedClipping",
57
+ "DifferentialPrivacyServerSideAdaptiveClipping",
58
58
  "DifferentialPrivacyServerSideFixedClipping",
59
+ "FaultTolerantFedAvg",
59
60
  "FedAdagrad",
60
61
  "FedAdam",
61
62
  "FedAvg",
@@ -69,7 +70,6 @@ __all__ = [
69
70
  "FedXgbCyclic",
70
71
  "FedXgbNnAvg",
71
72
  "FedYogi",
72
- "FaultTolerantFedAvg",
73
73
  "Krum",
74
74
  "QFedAvg",
75
75
  "Strategy",
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2021 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2021 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2021 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.