flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240708__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; more details are available on the registry's page for this release.

Files changed (34)
  1. flwr/client/app.py +9 -6
  2. flwr/client/grpc_adapter_client/connection.py +2 -1
  3. flwr/client/grpc_client/connection.py +2 -1
  4. flwr/client/grpc_rere_client/connection.py +9 -3
  5. flwr/client/rest_client/connection.py +10 -4
  6. flwr/common/config.py +75 -2
  7. flwr/common/context.py +8 -2
  8. flwr/common/typing.py +1 -0
  9. flwr/proto/common_pb2.py +24 -0
  10. flwr/proto/common_pb2.pyi +7 -0
  11. flwr/proto/common_pb2_grpc.py +4 -0
  12. flwr/proto/common_pb2_grpc.pyi +4 -0
  13. flwr/proto/driver_pb2.py +23 -19
  14. flwr/proto/driver_pb2.pyi +18 -1
  15. flwr/proto/exec_pb2.py +15 -11
  16. flwr/proto/exec_pb2.pyi +19 -1
  17. flwr/proto/run_pb2.py +11 -7
  18. flwr/proto/run_pb2.pyi +19 -1
  19. flwr/server/driver/grpc_driver.py +77 -139
  20. flwr/server/run_serverapp.py +20 -12
  21. flwr/server/superlink/driver/driver_servicer.py +5 -1
  22. flwr/server/superlink/state/in_memory_state.py +10 -2
  23. flwr/server/superlink/state/sqlite_state.py +22 -7
  24. flwr/server/superlink/state/state.py +7 -2
  25. flwr/simulation/run_simulation.py +1 -1
  26. flwr/superexec/app.py +1 -0
  27. flwr/superexec/deployment.py +16 -5
  28. flwr/superexec/exec_servicer.py +4 -1
  29. flwr/superexec/executor.py +2 -3
  30. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/METADATA +1 -1
  31. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/RECORD +34 -30
  32. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/LICENSE +0 -0
  33. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/WHEEL +0 -0
  34. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/entry_points.txt +0 -0
flwr/server/driver/grpc_driver.py CHANGED
@@ -16,8 +16,8 @@
16
16
 
17
17
  import time
18
18
  import warnings
19
- from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple, cast
19
+ from logging import DEBUG, WARNING
20
+ from typing import Iterable, List, Optional, cast
21
21
 
22
22
  import grpc
23
23
 
@@ -27,8 +27,6 @@ from flwr.common.logger import log
27
27
  from flwr.common.serde import message_from_taskres, message_to_taskins
28
28
  from flwr.common.typing import Run
29
29
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
30
- CreateRunRequest,
31
- CreateRunResponse,
32
30
  GetNodesRequest,
33
31
  GetNodesResponse,
34
32
  PullTaskResRequest,
@@ -53,167 +51,103 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
53
51
  """
54
52
 
55
53
 
56
- class GrpcDriverStub:
57
- """`GrpcDriverStub` provides access to the gRPC Driver API/service.
54
+ class GrpcDriver(Driver):
55
+ """`GrpcDriver` provides an interface to the Driver API.
58
56
 
59
57
  Parameters
60
58
  ----------
61
- driver_service_address : Optional[str]
62
- The IPv4 or IPv6 address of the Driver API server.
63
- Defaults to `"[::]:9091"`.
59
+ run_id : int
60
+ The identifier of the run.
61
+ driver_service_address : str (default: "[::]:9091")
62
+ The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
64
63
  root_certificates : Optional[bytes] (default: None)
65
64
  The PEM-encoded root certificates as a byte string.
66
65
  If provided, a secure connection using the certificates will be
67
66
  established to an SSL-enabled Flower server.
68
67
  """
69
68
 
70
- def __init__(
69
+ def __init__( # pylint: disable=too-many-arguments
71
70
  self,
71
+ run_id: int,
72
72
  driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
73
73
  root_certificates: Optional[bytes] = None,
74
74
  ) -> None:
75
- self.driver_service_address = driver_service_address
76
- self.root_certificates = root_certificates
77
- self.channel: Optional[grpc.Channel] = None
78
- self.stub: Optional[DriverStub] = None
75
+ self._run_id = run_id
76
+ self._addr = driver_service_address
77
+ self._cert = root_certificates
78
+ self._run: Optional[Run] = None
79
+ self._grpc_stub: Optional[DriverStub] = None
80
+ self._channel: Optional[grpc.Channel] = None
81
+ self.node = Node(node_id=0, anonymous=True)
82
+
83
+ @property
84
+ def _is_connected(self) -> bool:
85
+ """Check if connected to the Driver API server."""
86
+ return self._channel is not None
79
87
 
80
- def is_connected(self) -> bool:
81
- """Return True if connected to the Driver API server, otherwise False."""
82
- return self.channel is not None
88
+ def _connect(self) -> None:
89
+ """Connect to the Driver API.
83
90
 
84
- def connect(self) -> None:
85
- """Connect to the Driver API."""
91
+ This will not call GetRun.
92
+ """
86
93
  event(EventType.DRIVER_CONNECT)
87
- if self.channel is not None or self.stub is not None:
94
+ if self._is_connected:
88
95
  log(WARNING, "Already connected")
89
96
  return
90
- self.channel = create_channel(
91
- server_address=self.driver_service_address,
92
- insecure=(self.root_certificates is None),
93
- root_certificates=self.root_certificates,
97
+ self._channel = create_channel(
98
+ server_address=self._addr,
99
+ insecure=(self._cert is None),
100
+ root_certificates=self._cert,
94
101
  )
95
- self.stub = DriverStub(self.channel)
96
- log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
102
+ self._grpc_stub = DriverStub(self._channel)
103
+ log(DEBUG, "[Driver] Connected to %s", self._addr)
97
104
 
98
- def disconnect(self) -> None:
105
+ def _disconnect(self) -> None:
99
106
  """Disconnect from the Driver API."""
100
107
  event(EventType.DRIVER_DISCONNECT)
101
- if self.channel is None or self.stub is None:
108
+ if not self._is_connected:
102
109
  log(DEBUG, "Already disconnected")
103
110
  return
104
- channel = self.channel
105
- self.channel = None
106
- self.stub = None
111
+ channel: grpc.Channel = self._channel
112
+ self._channel = None
113
+ self._grpc_stub = None
107
114
  channel.close()
108
115
  log(DEBUG, "[Driver] Disconnected")
109
116
 
110
- def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
111
- """Request for run ID."""
112
- # Check if channel is open
113
- if self.stub is None:
114
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
115
- raise ConnectionError("`GrpcDriverStub` instance not connected")
116
-
117
- # Call Driver API
118
- res: CreateRunResponse = self.stub.CreateRun(request=req)
119
- return res
120
-
121
- def get_run(self, req: GetRunRequest) -> GetRunResponse:
122
- """Get run information."""
123
- # Check if channel is open
124
- if self.stub is None:
125
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
126
- raise ConnectionError("`GrpcDriverStub` instance not connected")
127
-
128
- # Call gRPC Driver API
129
- res: GetRunResponse = self.stub.GetRun(request=req)
130
- return res
131
-
132
- def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
133
- """Get client IDs."""
134
- # Check if channel is open
135
- if self.stub is None:
136
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
137
- raise ConnectionError("`GrpcDriverStub` instance not connected")
138
-
139
- # Call gRPC Driver API
140
- res: GetNodesResponse = self.stub.GetNodes(request=req)
141
- return res
142
-
143
- def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
144
- """Schedule tasks."""
145
- # Check if channel is open
146
- if self.stub is None:
147
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
148
- raise ConnectionError("`GrpcDriverStub` instance not connected")
149
-
150
- # Call gRPC Driver API
151
- res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
152
- return res
153
-
154
- def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
155
- """Get task results."""
156
- # Check if channel is open
157
- if self.stub is None:
158
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
159
- raise ConnectionError("`GrpcDriverStub` instance not connected")
160
-
161
- # Call Driver API
162
- res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
163
- return res
164
-
165
-
166
- class GrpcDriver(Driver):
167
- """`Driver` class provides an interface to the Driver API.
168
-
169
- Parameters
170
- ----------
171
- run_id : int
172
- The identifier of the run.
173
- stub : Optional[GrpcDriverStub] (default: None)
174
- The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
175
- If None, an instance connected to "[::]:9091" will be created.
176
- """
177
-
178
- def __init__( # pylint: disable=too-many-arguments
179
- self,
180
- run_id: int,
181
- stub: Optional[GrpcDriverStub] = None,
182
- ) -> None:
183
- self._run_id = run_id
184
- self._run: Optional[Run] = None
185
- self.stub = stub if stub is not None else GrpcDriverStub()
186
- self.node = Node(node_id=0, anonymous=True)
117
+ def _init_run(self) -> None:
118
+ # Check if is initialized
119
+ if self._run is not None:
120
+ return
121
+ # Get the run info
122
+ req = GetRunRequest(run_id=self._run_id)
123
+ res: GetRunResponse = self._stub.GetRun(req)
124
+ if not res.HasField("run"):
125
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
126
+ self._run = Run(
127
+ run_id=res.run.run_id,
128
+ fab_id=res.run.fab_id,
129
+ fab_version=res.run.fab_version,
130
+ override_config=dict(res.run.override_config.items()),
131
+ )
187
132
 
188
133
  @property
189
134
  def run(self) -> Run:
190
135
  """Run information."""
191
- self._get_stub_and_run_id()
192
- return Run(**vars(cast(Run, self._run)))
136
+ self._init_run()
137
+ return Run(**vars(self._run))
193
138
 
194
- def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
195
- # Check if is initialized
196
- if self._run is None:
197
- # Connect
198
- if not self.stub.is_connected():
199
- self.stub.connect()
200
- # Get the run info
201
- req = GetRunRequest(run_id=self._run_id)
202
- res = self.stub.get_run(req)
203
- if not res.HasField("run"):
204
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
205
- self._run = Run(
206
- run_id=res.run.run_id,
207
- fab_id=res.run.fab_id,
208
- fab_version=res.run.fab_version,
209
- )
210
-
211
- return self.stub, self._run.run_id
139
+ @property
140
+ def _stub(self) -> DriverStub:
141
+ """Driver stub."""
142
+ if not self._is_connected:
143
+ self._connect()
144
+ return cast(DriverStub, self._grpc_stub)
212
145
 
213
146
  def _check_message(self, message: Message) -> None:
214
147
  # Check if the message is valid
215
148
  if not (
216
- message.metadata.run_id == cast(Run, self._run).run_id
149
+ # Assume self._run being initialized
150
+ message.metadata.run_id == self._run_id
217
151
  and message.metadata.src_node_id == self.node.node_id
218
152
  and message.metadata.message_id == ""
219
153
  and message.metadata.reply_to_message == ""
@@ -234,7 +168,7 @@ class GrpcDriver(Driver):
234
168
  This method constructs a new `Message` with given content and metadata.
235
169
  The `run_id` and `src_node_id` will be set automatically.
236
170
  """
237
- _, run_id = self._get_stub_and_run_id()
171
+ self._init_run()
238
172
  if ttl:
239
173
  warnings.warn(
240
174
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -245,7 +179,7 @@ class GrpcDriver(Driver):
245
179
 
246
180
  ttl_ = DEFAULT_TTL if ttl is None else ttl
247
181
  metadata = Metadata(
248
- run_id=run_id,
182
+ run_id=self._run_id,
249
183
  message_id="", # Will be set by the server
250
184
  src_node_id=self.node.node_id,
251
185
  dst_node_id=dst_node_id,
@@ -258,9 +192,11 @@ class GrpcDriver(Driver):
258
192
 
259
193
  def get_node_ids(self) -> List[int]:
260
194
  """Get node IDs."""
261
- stub, run_id = self._get_stub_and_run_id()
195
+ self._init_run()
262
196
  # Call GrpcDriverStub method
263
- res = stub.get_nodes(GetNodesRequest(run_id=run_id))
197
+ res: GetNodesResponse = self._stub.GetNodes(
198
+ GetNodesRequest(run_id=self._run_id)
199
+ )
264
200
  return [node.node_id for node in res.nodes]
265
201
 
266
202
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -269,7 +205,7 @@ class GrpcDriver(Driver):
269
205
  This method takes an iterable of messages and sends each message
270
206
  to the node specified in `dst_node_id`.
271
207
  """
272
- stub, _ = self._get_stub_and_run_id()
208
+ self._init_run()
273
209
  # Construct TaskIns
274
210
  task_ins_list: List[TaskIns] = []
275
211
  for msg in messages:
@@ -280,7 +216,9 @@ class GrpcDriver(Driver):
280
216
  # Add to list
281
217
  task_ins_list.append(taskins)
282
218
  # Call GrpcDriverStub method
283
- res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
219
+ res: PushTaskInsResponse = self._stub.PushTaskIns(
220
+ PushTaskInsRequest(task_ins_list=task_ins_list)
221
+ )
284
222
  return list(res.task_ids)
285
223
 
286
224
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -289,9 +227,9 @@ class GrpcDriver(Driver):
289
227
  This method is used to collect messages from the SuperLink that correspond to a
290
228
  set of given message IDs.
291
229
  """
292
- stub, _ = self._get_stub_and_run_id()
230
+ self._init_run()
293
231
  # Pull TaskRes
294
- res = stub.pull_task_res(
232
+ res: PullTaskResResponse = self._stub.PullTaskRes(
295
233
  PullTaskResRequest(node=self.node, task_ids=message_ids)
296
234
  )
297
235
  # Convert TaskRes to Message
@@ -331,7 +269,7 @@ class GrpcDriver(Driver):
331
269
  def close(self) -> None:
332
270
  """Disconnect from the SuperLink if connected."""
333
271
  # Check if `connect` was called before
334
- if not self.stub.is_connected():
272
+ if not self._is_connected:
335
273
  return
336
274
  # Disconnect
337
- self.stub.disconnect()
275
+ self._disconnect()
flwr/server/run_serverapp.py CHANGED
@@ -25,10 +25,13 @@ from flwr.common import Context, EventType, RecordSet, event
25
25
  from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
26
26
  from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
27
27
  from flwr.common.object_ref import load_app
28
- from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
28
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
29
+ CreateRunRequest,
30
+ CreateRunResponse,
31
+ )
29
32
 
30
33
  from .driver import Driver
31
- from .driver.grpc_driver import GrpcDriver, GrpcDriverStub
34
+ from .driver.grpc_driver import GrpcDriver
32
35
  from .server_app import LoadServerAppError, ServerApp
33
36
 
34
37
  ADDRESS_DRIVER_API = "0.0.0.0:9091"
@@ -144,22 +147,27 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
144
147
  "For more details, use: ``flower-server-app -h``"
145
148
  )
146
149
 
147
- stub = GrpcDriverStub(
148
- driver_service_address=args.superlink, root_certificates=root_certificates
149
- )
150
+ # Initialize GrpcDriver
150
151
  if args.run_id is not None:
151
152
  # User provided `--run-id`, but not `server-app`
152
- run_id = args.run_id
153
+ driver = GrpcDriver(
154
+ run_id=args.run_id,
155
+ driver_service_address=args.superlink,
156
+ root_certificates=root_certificates,
157
+ )
153
158
  else:
154
159
  # User provided `server-app`, but not `--run-id`
155
160
  # Create run if run_id is not provided
156
- stub.connect()
161
+ driver = GrpcDriver(
162
+ run_id=0, # Will be overwritten
163
+ driver_service_address=args.superlink,
164
+ root_certificates=root_certificates,
165
+ )
166
+ # Create run
157
167
  req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
158
- res = stub.create_run(req)
159
- run_id = res.run_id
160
-
161
- # Initialize GrpcDriver
162
- driver = GrpcDriver(run_id=run_id, stub=stub)
168
+ res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
169
+ # Overwrite driver._run_id
170
+ driver._run_id = res.run_id # pylint: disable=W0212
163
171
 
164
172
  # Dynamically obtain ServerApp path based on run_id
165
173
  if args.run_id is not None:
flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -69,7 +69,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
69
69
  """Create run ID."""
70
70
  log(DEBUG, "DriverServicer.CreateRun")
71
71
  state: State = self.state_factory.state()
72
- run_id = state.create_run(request.fab_id, request.fab_version)
72
+ run_id = state.create_run(
73
+ request.fab_id,
74
+ request.fab_version,
75
+ dict(request.override_config.items()),
76
+ )
73
77
  return CreateRunResponse(run_id=run_id)
74
78
 
75
79
  def PushTaskIns(
flwr/server/superlink/state/in_memory_state.py CHANGED
@@ -275,7 +275,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
275
275
  """Retrieve stored `node_id` filtered by `client_public_keys`."""
276
276
  return self.public_key_to_node_id.get(client_public_key)
277
277
 
278
- def create_run(self, fab_id: str, fab_version: str) -> int:
278
+ def create_run(
279
+ self,
280
+ fab_id: str,
281
+ fab_version: str,
282
+ override_config: Dict[str, str],
283
+ ) -> int:
279
284
  """Create a new run for the specified `fab_id` and `fab_version`."""
280
285
  # Sample a random int64 as run_id
281
286
  with self.lock:
@@ -283,7 +288,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
283
288
 
284
289
  if run_id not in self.run_ids:
285
290
  self.run_ids[run_id] = Run(
286
- run_id=run_id, fab_id=fab_id, fab_version=fab_version
291
+ run_id=run_id,
292
+ fab_id=fab_id,
293
+ fab_version=fab_version,
294
+ override_config=override_config,
287
295
  )
288
296
  return run_id
289
297
  log(ERROR, "Unexpected run creation failure.")
flwr/server/superlink/state/sqlite_state.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """SQLite based implemenation of server state."""
16
16
 
17
17
 
18
+ import json
18
19
  import re
19
20
  import sqlite3
20
21
  import time
@@ -61,9 +62,10 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
61
62
 
62
63
  SQL_CREATE_TABLE_RUN = """
63
64
  CREATE TABLE IF NOT EXISTS run(
64
- run_id INTEGER UNIQUE,
65
- fab_id TEXT,
66
- fab_version TEXT
65
+ run_id INTEGER UNIQUE,
66
+ fab_id TEXT,
67
+ fab_version TEXT,
68
+ override_config TEXT
67
69
  );
68
70
  """
69
71
 
@@ -613,7 +615,12 @@ class SqliteState(State): # pylint: disable=R0904
613
615
  return node_id
614
616
  return None
615
617
 
616
- def create_run(self, fab_id: str, fab_version: str) -> int:
618
+ def create_run(
619
+ self,
620
+ fab_id: str,
621
+ fab_version: str,
622
+ override_config: Dict[str, str],
623
+ ) -> int:
617
624
  """Create a new run for the specified `fab_id` and `fab_version`."""
618
625
  # Sample a random int64 as run_id
619
626
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
@@ -622,8 +629,13 @@ class SqliteState(State): # pylint: disable=R0904
622
629
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
623
630
  # If run_id does not exist
624
631
  if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
625
- query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
626
- self.query(query, (run_id, fab_id, fab_version))
632
+ query = (
633
+ "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
634
+ "VALUES (?, ?, ?, ?);"
635
+ )
636
+ self.query(
637
+ query, (run_id, fab_id, fab_version, json.dumps(override_config))
638
+ )
627
639
  return run_id
628
640
  log(ERROR, "Unexpected run creation failure.")
629
641
  return 0
@@ -687,7 +699,10 @@ class SqliteState(State): # pylint: disable=R0904
687
699
  try:
688
700
  row = self.query(query, (run_id,))[0]
689
701
  return Run(
690
- run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
702
+ run_id=run_id,
703
+ fab_id=row["fab_id"],
704
+ fab_version=row["fab_version"],
705
+ override_config=json.loads(row["override_config"]),
691
706
  )
692
707
  except sqlite3.IntegrityError:
693
708
  log(ERROR, "`run_id` does not exist.")
flwr/server/superlink/state/state.py CHANGED
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import abc
19
- from typing import List, Optional, Set
19
+ from typing import Dict, List, Optional, Set
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common.typing import Run
@@ -157,7 +157,12 @@ class State(abc.ABC): # pylint: disable=R0904
157
157
  """Retrieve stored `node_id` filtered by `client_public_keys`."""
158
158
 
159
159
  @abc.abstractmethod
160
- def create_run(self, fab_id: str, fab_version: str) -> int:
160
+ def create_run(
161
+ self,
162
+ fab_id: str,
163
+ fab_version: str,
164
+ override_config: Dict[str, str],
165
+ ) -> int:
161
166
  """Create a new run for the specified `fab_id` and `fab_version`."""
162
167
 
163
168
  @abc.abstractmethod
flwr/simulation/run_simulation.py CHANGED
@@ -209,7 +209,7 @@ def _main_loop(
209
209
  serverapp_th = None
210
210
  try:
211
211
  # Create run (with empty fab_id and fab_version)
212
- run_id_ = state_factory.state().create_run("", "")
212
+ run_id_ = state_factory.state().create_run("", "", {})
213
213
 
214
214
  if run_id:
215
215
  _override_run_id(state_factory, run_id_to_replace=run_id_, run_id=run_id)
flwr/superexec/app.py CHANGED
@@ -77,6 +77,7 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser:
77
77
  parser.add_argument(
78
78
  "executor",
79
79
  help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.",
80
+ default="flwr.superexec.deployment:executor",
80
81
  )
81
82
  parser.add_argument(
82
83
  "--address",
flwr/superexec/deployment.py CHANGED
@@ -17,7 +17,7 @@
17
17
  import subprocess
18
18
  import sys
19
19
  from logging import ERROR, INFO
20
- from typing import Optional
20
+ from typing import Dict, Optional
21
21
 
22
22
  from typing_extensions import override
23
23
 
@@ -53,18 +53,29 @@ class DeploymentEngine(Executor):
53
53
  )
54
54
  self.stub = DriverStub(channel)
55
55
 
56
- def _create_run(self, fab_id: str, fab_version: str) -> int:
56
+ def _create_run(
57
+ self,
58
+ fab_id: str,
59
+ fab_version: str,
60
+ override_config: Dict[str, str],
61
+ ) -> int:
57
62
  if self.stub is None:
58
63
  self._connect()
59
64
 
60
65
  assert self.stub is not None
61
66
 
62
- req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
67
+ req = CreateRunRequest(
68
+ fab_id=fab_id,
69
+ fab_version=fab_version,
70
+ override_config=override_config,
71
+ )
63
72
  res = self.stub.CreateRun(request=req)
64
73
  return int(res.run_id)
65
74
 
66
75
  @override
67
- def start_run(self, fab_file: bytes) -> Optional[RunTracker]:
76
+ def start_run(
77
+ self, fab_file: bytes, override_config: Dict[str, str]
78
+ ) -> Optional[RunTracker]:
68
79
  """Start run using the Flower Deployment Engine."""
69
80
  try:
70
81
  # Install FAB to flwr dir
@@ -79,7 +90,7 @@ class DeploymentEngine(Executor):
79
90
  )
80
91
 
81
92
  # Call SuperLink to create run
82
- run_id: int = self._create_run(fab_id, fab_version)
93
+ run_id: int = self._create_run(fab_id, fab_version, override_config)
83
94
  log(INFO, "Created run %s", str(run_id))
84
95
 
85
96
  # Start ServerApp
flwr/superexec/exec_servicer.py CHANGED
@@ -45,7 +45,10 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
45
45
  """Create run ID."""
46
46
  log(INFO, "ExecServicer.StartRun")
47
47
 
48
- run = self.executor.start_run(request.fab_file)
48
+ run = self.executor.start_run(
49
+ request.fab_file,
50
+ dict(request.override_config.items()),
51
+ )
49
52
 
50
53
  if run is None:
51
54
  log(ERROR, "Executor failed to start run")
flwr/superexec/executor.py CHANGED
@@ -17,7 +17,7 @@
17
17
  from abc import ABC, abstractmethod
18
18
  from dataclasses import dataclass
19
19
  from subprocess import Popen
20
- from typing import Optional
20
+ from typing import Dict, Optional
21
21
 
22
22
 
23
23
  @dataclass
@@ -33,8 +33,7 @@ class Executor(ABC):
33
33
 
34
34
  @abstractmethod
35
35
  def start_run(
36
- self,
37
- fab_file: bytes,
36
+ self, fab_file: bytes, override_config: Dict[str, str]
38
37
  ) -> Optional[RunTracker]:
39
38
  """Start a run using the given Flower FAB ID and version.
40
39
 
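{flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@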
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.10.0.dev20240707
3
+ Version: 1.10.0.dev20240708
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0