flwr-nightly 1.10.0.dev20240705__py3-none-any.whl → 1.10.0.dev20240708__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (37) hide show
  1. flwr/client/app.py +9 -6
  2. flwr/client/grpc_adapter_client/connection.py +2 -1
  3. flwr/client/grpc_client/connection.py +2 -1
  4. flwr/client/grpc_rere_client/connection.py +9 -3
  5. flwr/client/rest_client/connection.py +10 -4
  6. flwr/common/config.py +75 -2
  7. flwr/common/context.py +8 -2
  8. flwr/common/logger.py +13 -0
  9. flwr/common/typing.py +1 -0
  10. flwr/proto/common_pb2.py +24 -0
  11. flwr/proto/common_pb2.pyi +7 -0
  12. flwr/proto/common_pb2_grpc.py +4 -0
  13. flwr/proto/common_pb2_grpc.pyi +4 -0
  14. flwr/proto/driver_pb2.py +23 -19
  15. flwr/proto/driver_pb2.pyi +18 -1
  16. flwr/proto/exec_pb2.py +15 -11
  17. flwr/proto/exec_pb2.pyi +19 -1
  18. flwr/proto/run_pb2.py +11 -7
  19. flwr/proto/run_pb2.pyi +19 -1
  20. flwr/server/driver/grpc_driver.py +77 -139
  21. flwr/server/run_serverapp.py +20 -12
  22. flwr/server/superlink/driver/driver_servicer.py +5 -1
  23. flwr/server/superlink/state/in_memory_state.py +10 -2
  24. flwr/server/superlink/state/sqlite_state.py +22 -7
  25. flwr/server/superlink/state/state.py +7 -2
  26. flwr/simulation/app.py +39 -24
  27. flwr/simulation/ray_transport/ray_client_proxy.py +15 -7
  28. flwr/simulation/run_simulation.py +1 -1
  29. flwr/superexec/app.py +1 -0
  30. flwr/superexec/deployment.py +16 -5
  31. flwr/superexec/exec_servicer.py +4 -1
  32. flwr/superexec/executor.py +2 -3
  33. {flwr_nightly-1.10.0.dev20240705.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/METADATA +1 -1
  34. {flwr_nightly-1.10.0.dev20240705.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/RECORD +37 -33
  35. {flwr_nightly-1.10.0.dev20240705.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/LICENSE +0 -0
  36. {flwr_nightly-1.10.0.dev20240705.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/WHEEL +0 -0
  37. {flwr_nightly-1.10.0.dev20240705.dist-info → flwr_nightly-1.10.0.dev20240708.dist-info}/entry_points.txt +0 -0
@@ -16,8 +16,8 @@
16
16
 
17
17
  import time
18
18
  import warnings
19
- from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple, cast
19
+ from logging import DEBUG, WARNING
20
+ from typing import Iterable, List, Optional, cast
21
21
 
22
22
  import grpc
23
23
 
@@ -27,8 +27,6 @@ from flwr.common.logger import log
27
27
  from flwr.common.serde import message_from_taskres, message_to_taskins
28
28
  from flwr.common.typing import Run
29
29
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
30
- CreateRunRequest,
31
- CreateRunResponse,
32
30
  GetNodesRequest,
33
31
  GetNodesResponse,
34
32
  PullTaskResRequest,
@@ -53,167 +51,103 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
53
51
  """
54
52
 
55
53
 
56
- class GrpcDriverStub:
57
- """`GrpcDriverStub` provides access to the gRPC Driver API/service.
54
+ class GrpcDriver(Driver):
55
+ """`GrpcDriver` provides an interface to the Driver API.
58
56
 
59
57
  Parameters
60
58
  ----------
61
- driver_service_address : Optional[str]
62
- The IPv4 or IPv6 address of the Driver API server.
63
- Defaults to `"[::]:9091"`.
59
+ run_id : int
60
+ The identifier of the run.
61
+ driver_service_address : str (default: "[::]:9091")
62
+ The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
64
63
  root_certificates : Optional[bytes] (default: None)
65
64
  The PEM-encoded root certificates as a byte string.
66
65
  If provided, a secure connection using the certificates will be
67
66
  established to an SSL-enabled Flower server.
68
67
  """
69
68
 
70
- def __init__(
69
+ def __init__( # pylint: disable=too-many-arguments
71
70
  self,
71
+ run_id: int,
72
72
  driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
73
73
  root_certificates: Optional[bytes] = None,
74
74
  ) -> None:
75
- self.driver_service_address = driver_service_address
76
- self.root_certificates = root_certificates
77
- self.channel: Optional[grpc.Channel] = None
78
- self.stub: Optional[DriverStub] = None
75
+ self._run_id = run_id
76
+ self._addr = driver_service_address
77
+ self._cert = root_certificates
78
+ self._run: Optional[Run] = None
79
+ self._grpc_stub: Optional[DriverStub] = None
80
+ self._channel: Optional[grpc.Channel] = None
81
+ self.node = Node(node_id=0, anonymous=True)
82
+
83
+ @property
84
+ def _is_connected(self) -> bool:
85
+ """Check if connected to the Driver API server."""
86
+ return self._channel is not None
79
87
 
80
- def is_connected(self) -> bool:
81
- """Return True if connected to the Driver API server, otherwise False."""
82
- return self.channel is not None
88
+ def _connect(self) -> None:
89
+ """Connect to the Driver API.
83
90
 
84
- def connect(self) -> None:
85
- """Connect to the Driver API."""
91
+ This will not call GetRun.
92
+ """
86
93
  event(EventType.DRIVER_CONNECT)
87
- if self.channel is not None or self.stub is not None:
94
+ if self._is_connected:
88
95
  log(WARNING, "Already connected")
89
96
  return
90
- self.channel = create_channel(
91
- server_address=self.driver_service_address,
92
- insecure=(self.root_certificates is None),
93
- root_certificates=self.root_certificates,
97
+ self._channel = create_channel(
98
+ server_address=self._addr,
99
+ insecure=(self._cert is None),
100
+ root_certificates=self._cert,
94
101
  )
95
- self.stub = DriverStub(self.channel)
96
- log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
102
+ self._grpc_stub = DriverStub(self._channel)
103
+ log(DEBUG, "[Driver] Connected to %s", self._addr)
97
104
 
98
- def disconnect(self) -> None:
105
+ def _disconnect(self) -> None:
99
106
  """Disconnect from the Driver API."""
100
107
  event(EventType.DRIVER_DISCONNECT)
101
- if self.channel is None or self.stub is None:
108
+ if not self._is_connected:
102
109
  log(DEBUG, "Already disconnected")
103
110
  return
104
- channel = self.channel
105
- self.channel = None
106
- self.stub = None
111
+ channel: grpc.Channel = self._channel
112
+ self._channel = None
113
+ self._grpc_stub = None
107
114
  channel.close()
108
115
  log(DEBUG, "[Driver] Disconnected")
109
116
 
110
- def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
111
- """Request for run ID."""
112
- # Check if channel is open
113
- if self.stub is None:
114
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
115
- raise ConnectionError("`GrpcDriverStub` instance not connected")
116
-
117
- # Call Driver API
118
- res: CreateRunResponse = self.stub.CreateRun(request=req)
119
- return res
120
-
121
- def get_run(self, req: GetRunRequest) -> GetRunResponse:
122
- """Get run information."""
123
- # Check if channel is open
124
- if self.stub is None:
125
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
126
- raise ConnectionError("`GrpcDriverStub` instance not connected")
127
-
128
- # Call gRPC Driver API
129
- res: GetRunResponse = self.stub.GetRun(request=req)
130
- return res
131
-
132
- def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
133
- """Get client IDs."""
134
- # Check if channel is open
135
- if self.stub is None:
136
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
137
- raise ConnectionError("`GrpcDriverStub` instance not connected")
138
-
139
- # Call gRPC Driver API
140
- res: GetNodesResponse = self.stub.GetNodes(request=req)
141
- return res
142
-
143
- def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
144
- """Schedule tasks."""
145
- # Check if channel is open
146
- if self.stub is None:
147
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
148
- raise ConnectionError("`GrpcDriverStub` instance not connected")
149
-
150
- # Call gRPC Driver API
151
- res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
152
- return res
153
-
154
- def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
155
- """Get task results."""
156
- # Check if channel is open
157
- if self.stub is None:
158
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
159
- raise ConnectionError("`GrpcDriverStub` instance not connected")
160
-
161
- # Call Driver API
162
- res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
163
- return res
164
-
165
-
166
- class GrpcDriver(Driver):
167
- """`Driver` class provides an interface to the Driver API.
168
-
169
- Parameters
170
- ----------
171
- run_id : int
172
- The identifier of the run.
173
- stub : Optional[GrpcDriverStub] (default: None)
174
- The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
175
- If None, an instance connected to "[::]:9091" will be created.
176
- """
177
-
178
- def __init__( # pylint: disable=too-many-arguments
179
- self,
180
- run_id: int,
181
- stub: Optional[GrpcDriverStub] = None,
182
- ) -> None:
183
- self._run_id = run_id
184
- self._run: Optional[Run] = None
185
- self.stub = stub if stub is not None else GrpcDriverStub()
186
- self.node = Node(node_id=0, anonymous=True)
117
+ def _init_run(self) -> None:
118
+ # Check if is initialized
119
+ if self._run is not None:
120
+ return
121
+ # Get the run info
122
+ req = GetRunRequest(run_id=self._run_id)
123
+ res: GetRunResponse = self._stub.GetRun(req)
124
+ if not res.HasField("run"):
125
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
126
+ self._run = Run(
127
+ run_id=res.run.run_id,
128
+ fab_id=res.run.fab_id,
129
+ fab_version=res.run.fab_version,
130
+ override_config=dict(res.run.override_config.items()),
131
+ )
187
132
 
188
133
  @property
189
134
  def run(self) -> Run:
190
135
  """Run information."""
191
- self._get_stub_and_run_id()
192
- return Run(**vars(cast(Run, self._run)))
136
+ self._init_run()
137
+ return Run(**vars(self._run))
193
138
 
194
- def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
195
- # Check if is initialized
196
- if self._run is None:
197
- # Connect
198
- if not self.stub.is_connected():
199
- self.stub.connect()
200
- # Get the run info
201
- req = GetRunRequest(run_id=self._run_id)
202
- res = self.stub.get_run(req)
203
- if not res.HasField("run"):
204
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
205
- self._run = Run(
206
- run_id=res.run.run_id,
207
- fab_id=res.run.fab_id,
208
- fab_version=res.run.fab_version,
209
- )
210
-
211
- return self.stub, self._run.run_id
139
+ @property
140
+ def _stub(self) -> DriverStub:
141
+ """Driver stub."""
142
+ if not self._is_connected:
143
+ self._connect()
144
+ return cast(DriverStub, self._grpc_stub)
212
145
 
213
146
  def _check_message(self, message: Message) -> None:
214
147
  # Check if the message is valid
215
148
  if not (
216
- message.metadata.run_id == cast(Run, self._run).run_id
149
+ # Assume self._run being initialized
150
+ message.metadata.run_id == self._run_id
217
151
  and message.metadata.src_node_id == self.node.node_id
218
152
  and message.metadata.message_id == ""
219
153
  and message.metadata.reply_to_message == ""
@@ -234,7 +168,7 @@ class GrpcDriver(Driver):
234
168
  This method constructs a new `Message` with given content and metadata.
235
169
  The `run_id` and `src_node_id` will be set automatically.
236
170
  """
237
- _, run_id = self._get_stub_and_run_id()
171
+ self._init_run()
238
172
  if ttl:
239
173
  warnings.warn(
240
174
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -245,7 +179,7 @@ class GrpcDriver(Driver):
245
179
 
246
180
  ttl_ = DEFAULT_TTL if ttl is None else ttl
247
181
  metadata = Metadata(
248
- run_id=run_id,
182
+ run_id=self._run_id,
249
183
  message_id="", # Will be set by the server
250
184
  src_node_id=self.node.node_id,
251
185
  dst_node_id=dst_node_id,
@@ -258,9 +192,11 @@ class GrpcDriver(Driver):
258
192
 
259
193
  def get_node_ids(self) -> List[int]:
260
194
  """Get node IDs."""
261
- stub, run_id = self._get_stub_and_run_id()
195
+ self._init_run()
262
196
  # Call GrpcDriverStub method
263
- res = stub.get_nodes(GetNodesRequest(run_id=run_id))
197
+ res: GetNodesResponse = self._stub.GetNodes(
198
+ GetNodesRequest(run_id=self._run_id)
199
+ )
264
200
  return [node.node_id for node in res.nodes]
265
201
 
266
202
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -269,7 +205,7 @@ class GrpcDriver(Driver):
269
205
  This method takes an iterable of messages and sends each message
270
206
  to the node specified in `dst_node_id`.
271
207
  """
272
- stub, _ = self._get_stub_and_run_id()
208
+ self._init_run()
273
209
  # Construct TaskIns
274
210
  task_ins_list: List[TaskIns] = []
275
211
  for msg in messages:
@@ -280,7 +216,9 @@ class GrpcDriver(Driver):
280
216
  # Add to list
281
217
  task_ins_list.append(taskins)
282
218
  # Call GrpcDriverStub method
283
- res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
219
+ res: PushTaskInsResponse = self._stub.PushTaskIns(
220
+ PushTaskInsRequest(task_ins_list=task_ins_list)
221
+ )
284
222
  return list(res.task_ids)
285
223
 
286
224
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -289,9 +227,9 @@ class GrpcDriver(Driver):
289
227
  This method is used to collect messages from the SuperLink that correspond to a
290
228
  set of given message IDs.
291
229
  """
292
- stub, _ = self._get_stub_and_run_id()
230
+ self._init_run()
293
231
  # Pull TaskRes
294
- res = stub.pull_task_res(
232
+ res: PullTaskResResponse = self._stub.PullTaskRes(
295
233
  PullTaskResRequest(node=self.node, task_ids=message_ids)
296
234
  )
297
235
  # Convert TaskRes to Message
@@ -331,7 +269,7 @@ class GrpcDriver(Driver):
331
269
  def close(self) -> None:
332
270
  """Disconnect from the SuperLink if connected."""
333
271
  # Check if `connect` was called before
334
- if not self.stub.is_connected():
272
+ if not self._is_connected:
335
273
  return
336
274
  # Disconnect
337
- self.stub.disconnect()
275
+ self._disconnect()
@@ -25,10 +25,13 @@ from flwr.common import Context, EventType, RecordSet, event
25
25
  from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
26
26
  from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
27
27
  from flwr.common.object_ref import load_app
28
- from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
28
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
29
+ CreateRunRequest,
30
+ CreateRunResponse,
31
+ )
29
32
 
30
33
  from .driver import Driver
31
- from .driver.grpc_driver import GrpcDriver, GrpcDriverStub
34
+ from .driver.grpc_driver import GrpcDriver
32
35
  from .server_app import LoadServerAppError, ServerApp
33
36
 
34
37
  ADDRESS_DRIVER_API = "0.0.0.0:9091"
@@ -144,22 +147,27 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
144
147
  "For more details, use: ``flower-server-app -h``"
145
148
  )
146
149
 
147
- stub = GrpcDriverStub(
148
- driver_service_address=args.superlink, root_certificates=root_certificates
149
- )
150
+ # Initialize GrpcDriver
150
151
  if args.run_id is not None:
151
152
  # User provided `--run-id`, but not `server-app`
152
- run_id = args.run_id
153
+ driver = GrpcDriver(
154
+ run_id=args.run_id,
155
+ driver_service_address=args.superlink,
156
+ root_certificates=root_certificates,
157
+ )
153
158
  else:
154
159
  # User provided `server-app`, but not `--run-id`
155
160
  # Create run if run_id is not provided
156
- stub.connect()
161
+ driver = GrpcDriver(
162
+ run_id=0, # Will be overwritten
163
+ driver_service_address=args.superlink,
164
+ root_certificates=root_certificates,
165
+ )
166
+ # Create run
157
167
  req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
158
- res = stub.create_run(req)
159
- run_id = res.run_id
160
-
161
- # Initialize GrpcDriver
162
- driver = GrpcDriver(run_id=run_id, stub=stub)
168
+ res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
169
+ # Overwrite driver._run_id
170
+ driver._run_id = res.run_id # pylint: disable=W0212
163
171
 
164
172
  # Dynamically obtain ServerApp path based on run_id
165
173
  if args.run_id is not None:
@@ -69,7 +69,11 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
69
69
  """Create run ID."""
70
70
  log(DEBUG, "DriverServicer.CreateRun")
71
71
  state: State = self.state_factory.state()
72
- run_id = state.create_run(request.fab_id, request.fab_version)
72
+ run_id = state.create_run(
73
+ request.fab_id,
74
+ request.fab_version,
75
+ dict(request.override_config.items()),
76
+ )
73
77
  return CreateRunResponse(run_id=run_id)
74
78
 
75
79
  def PushTaskIns(
@@ -275,7 +275,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
275
275
  """Retrieve stored `node_id` filtered by `client_public_keys`."""
276
276
  return self.public_key_to_node_id.get(client_public_key)
277
277
 
278
- def create_run(self, fab_id: str, fab_version: str) -> int:
278
+ def create_run(
279
+ self,
280
+ fab_id: str,
281
+ fab_version: str,
282
+ override_config: Dict[str, str],
283
+ ) -> int:
279
284
  """Create a new run for the specified `fab_id` and `fab_version`."""
280
285
  # Sample a random int64 as run_id
281
286
  with self.lock:
@@ -283,7 +288,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
283
288
 
284
289
  if run_id not in self.run_ids:
285
290
  self.run_ids[run_id] = Run(
286
- run_id=run_id, fab_id=fab_id, fab_version=fab_version
291
+ run_id=run_id,
292
+ fab_id=fab_id,
293
+ fab_version=fab_version,
294
+ override_config=override_config,
287
295
  )
288
296
  return run_id
289
297
  log(ERROR, "Unexpected run creation failure.")
@@ -15,6 +15,7 @@
15
15
  """SQLite based implemenation of server state."""
16
16
 
17
17
 
18
+ import json
18
19
  import re
19
20
  import sqlite3
20
21
  import time
@@ -61,9 +62,10 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
61
62
 
62
63
  SQL_CREATE_TABLE_RUN = """
63
64
  CREATE TABLE IF NOT EXISTS run(
64
- run_id INTEGER UNIQUE,
65
- fab_id TEXT,
66
- fab_version TEXT
65
+ run_id INTEGER UNIQUE,
66
+ fab_id TEXT,
67
+ fab_version TEXT,
68
+ override_config TEXT
67
69
  );
68
70
  """
69
71
 
@@ -613,7 +615,12 @@ class SqliteState(State): # pylint: disable=R0904
613
615
  return node_id
614
616
  return None
615
617
 
616
- def create_run(self, fab_id: str, fab_version: str) -> int:
618
+ def create_run(
619
+ self,
620
+ fab_id: str,
621
+ fab_version: str,
622
+ override_config: Dict[str, str],
623
+ ) -> int:
617
624
  """Create a new run for the specified `fab_id` and `fab_version`."""
618
625
  # Sample a random int64 as run_id
619
626
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
@@ -622,8 +629,13 @@ class SqliteState(State): # pylint: disable=R0904
622
629
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
623
630
  # If run_id does not exist
624
631
  if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
625
- query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
626
- self.query(query, (run_id, fab_id, fab_version))
632
+ query = (
633
+ "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
634
+ "VALUES (?, ?, ?, ?);"
635
+ )
636
+ self.query(
637
+ query, (run_id, fab_id, fab_version, json.dumps(override_config))
638
+ )
627
639
  return run_id
628
640
  log(ERROR, "Unexpected run creation failure.")
629
641
  return 0
@@ -687,7 +699,10 @@ class SqliteState(State): # pylint: disable=R0904
687
699
  try:
688
700
  row = self.query(query, (run_id,))[0]
689
701
  return Run(
690
- run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
702
+ run_id=run_id,
703
+ fab_id=row["fab_id"],
704
+ fab_version=row["fab_version"],
705
+ override_config=json.loads(row["override_config"]),
691
706
  )
692
707
  except sqlite3.IntegrityError:
693
708
  log(ERROR, "`run_id` does not exist.")
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import abc
19
- from typing import List, Optional, Set
19
+ from typing import Dict, List, Optional, Set
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common.typing import Run
@@ -157,7 +157,12 @@ class State(abc.ABC): # pylint: disable=R0904
157
157
  """Retrieve stored `node_id` filtered by `client_public_keys`."""
158
158
 
159
159
  @abc.abstractmethod
160
- def create_run(self, fab_id: str, fab_version: str) -> int:
160
+ def create_run(
161
+ self,
162
+ fab_id: str,
163
+ fab_version: str,
164
+ override_config: Dict[str, str],
165
+ ) -> int:
161
166
  """Create a new run for the specified `fab_id` and `fab_version`."""
162
167
 
163
168
  @abc.abstractmethod
flwr/simulation/app.py CHANGED
@@ -29,12 +29,14 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
29
29
 
30
30
  from flwr.client import ClientFnExt
31
31
  from flwr.common import EventType, event
32
- from flwr.common.logger import log, set_logger_propagation
32
+ from flwr.common.constant import NODE_ID_NUM_BYTES
33
+ from flwr.common.logger import log, set_logger_propagation, warn_unsupported_feature
33
34
  from flwr.server.client_manager import ClientManager
34
35
  from flwr.server.history import History
35
36
  from flwr.server.server import Server, init_defaults, run_fl
36
37
  from flwr.server.server_config import ServerConfig
37
38
  from flwr.server.strategy import Strategy
39
+ from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
38
40
  from flwr.simulation.ray_transport.ray_actor import (
39
41
  ClientAppActor,
40
42
  VirtualClientEngineActor,
@@ -51,7 +53,7 @@ Invalid Arguments in method:
51
53
  `start_simulation(
52
54
  *,
53
55
  client_fn: ClientFn,
54
- num_clients: Optional[int] = None,
56
+ num_clients: int,
55
57
  clients_ids: Optional[List[str]] = None,
56
58
  client_resources: Optional[Dict[str, float]] = None,
57
59
  server: Optional[Server] = None,
@@ -70,13 +72,29 @@ REASON:
70
72
 
71
73
  """
72
74
 
75
+ NodeToPartitionMapping = Dict[int, int]
76
+
77
+
78
+ def _create_node_id_to_partition_mapping(
79
+ num_clients: int,
80
+ ) -> NodeToPartitionMapping:
81
+ """Generate a node_id:partition_id mapping."""
82
+ nodes_mapping: NodeToPartitionMapping = {} # {node-id; partition-id}
83
+ for i in range(num_clients):
84
+ while True:
85
+ node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
86
+ if node_id not in nodes_mapping:
87
+ break
88
+ nodes_mapping[node_id] = i
89
+ return nodes_mapping
90
+
73
91
 
74
92
  # pylint: disable=too-many-arguments,too-many-statements,too-many-branches
75
93
  def start_simulation(
76
94
  *,
77
95
  client_fn: ClientFnExt,
78
- num_clients: Optional[int] = None,
79
- clients_ids: Optional[List[str]] = None,
96
+ num_clients: int,
97
+ clients_ids: Optional[List[str]] = None, # UNSUPPORTED, WILL BE REMOVED
80
98
  client_resources: Optional[Dict[str, float]] = None,
81
99
  server: Optional[Server] = None,
82
100
  config: Optional[ServerConfig] = None,
@@ -102,13 +120,14 @@ def start_simulation(
102
120
  (model, dataset, hyperparameters, ...) should be (re-)created in either the
103
121
  call to `client_fn` or the call to any of the client methods (e.g., load
104
122
  evaluation data in the `evaluate` method itself).
105
- num_clients : Optional[int]
106
- The total number of clients in this simulation. This must be set if
107
- `clients_ids` is not set and vice-versa.
123
+ num_clients : int
124
+ The total number of clients in this simulation.
108
125
  clients_ids : Optional[List[str]]
126
+ UNSUPPORTED, WILL BE REMOVED. USE `num_clients` INSTEAD.
109
127
  List `client_id`s for each client. This is only required if
110
128
  `num_clients` is not set. Setting both `num_clients` and `clients_ids`
111
129
  with `len(clients_ids)` not equal to `num_clients` generates an error.
130
+ Using this argument will raise an error.
112
131
  client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`)
113
132
  CPU and GPU resources for a single client. Supported keys
114
133
  are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by
@@ -158,7 +177,6 @@ def start_simulation(
158
177
  is an advanced feature. For all details, please refer to the Ray documentation:
159
178
  https://docs.ray.io/en/latest/ray-core/scheduling/index.html
160
179
 
161
-
162
180
  Returns
163
181
  -------
164
182
  hist : flwr.server.history.History
@@ -170,6 +188,14 @@ def start_simulation(
170
188
  {"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
171
189
  )
172
190
 
191
+ if clients_ids is not None:
192
+ warn_unsupported_feature(
193
+ "Passing `clients_ids` to `start_simulation` is deprecated and not longer "
194
+ "used by `start_simulation`. Use `num_clients` exclusively instead."
195
+ )
196
+ log(ERROR, "`clients_ids` argument used.")
197
+ sys.exit()
198
+
173
199
  # Set logger propagation
174
200
  loop: Optional[asyncio.AbstractEventLoop] = None
175
201
  try:
@@ -196,20 +222,8 @@ def start_simulation(
196
222
  initialized_config,
197
223
  )
198
224
 
199
- # clients_ids takes precedence
200
- cids: List[str]
201
- if clients_ids is not None:
202
- if (num_clients is not None) and (len(clients_ids) != num_clients):
203
- log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
204
- sys.exit()
205
- else:
206
- cids = clients_ids
207
- else:
208
- if num_clients is None:
209
- log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
210
- sys.exit()
211
- else:
212
- cids = [str(x) for x in range(num_clients)]
225
+ # Create node-id to partition-id mapping
226
+ nodes_mapping = _create_node_id_to_partition_mapping(num_clients)
213
227
 
214
228
  # Default arguments for Ray initialization
215
229
  if not ray_init_args:
@@ -308,10 +322,11 @@ def start_simulation(
308
322
  )
309
323
 
310
324
  # Register one RayClientProxy object for each client with the ClientManager
311
- for cid in cids:
325
+ for node_id, partition_id in nodes_mapping.items():
312
326
  client_proxy = RayActorClientProxy(
313
327
  client_fn=client_fn,
314
- cid=cid,
328
+ node_id=node_id,
329
+ partition_id=partition_id,
315
330
  actor_pool=pool,
316
331
  )
317
332
  initialized_server.client_manager().register(client=client_proxy)