flwr-nightly 1.9.0.dev20240416__py3-none-any.whl → 1.9.0.dev20240420__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flwr-nightly might be problematic (the original page links to more details).

Files changed (36):
  1. flwr/cli/{flower_toml.py → config_utils.py} +40 -7
  2. flwr/cli/new/new.py +9 -5
  3. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  4. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +56 -0
  5. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +18 -0
  6. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +4 -0
  7. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -0
  8. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +4 -0
  9. flwr/cli/run/run.py +2 -2
  10. flwr/client/__init__.py +2 -0
  11. flwr/client/app.py +7 -53
  12. flwr/client/grpc_client/connection.py +2 -1
  13. flwr/client/grpc_rere_client/connection.py +16 -2
  14. flwr/client/rest_client/connection.py +87 -168
  15. flwr/client/supernode/__init__.py +22 -0
  16. flwr/client/supernode/app.py +107 -0
  17. flwr/common/record/recordset.py +67 -28
  18. flwr/common/telemetry.py +4 -0
  19. flwr/server/app.py +5 -5
  20. flwr/server/compat/app_utils.py +1 -1
  21. flwr/server/compat/driver_client_proxy.py +4 -2
  22. flwr/server/driver/__init__.py +0 -2
  23. flwr/server/driver/abc_driver.py +140 -0
  24. flwr/server/driver/driver.py +124 -21
  25. flwr/server/superlink/driver/driver_servicer.py +1 -1
  26. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
  27. flwr/server/superlink/state/in_memory_state.py +13 -4
  28. flwr/server/superlink/state/sqlite_state.py +17 -5
  29. flwr/server/superlink/state/state.py +21 -3
  30. {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/METADATA +1 -1
  31. {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/RECORD +34 -32
  32. {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/entry_points.txt +1 -0
  33. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  34. flwr/server/driver/grpc_driver.py +0 -129
  35. {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/LICENSE +0 -0
  36. {flwr_nightly-1.9.0.dev20240416.dist-info → flwr_nightly-1.9.0.dev20240420.dist-info}/WHEEL +0 -0
@@ -25,7 +25,7 @@ from flwr.common import serde
25
25
  from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
26
26
  from flwr.server.client_proxy import ClientProxy
27
27
 
28
- from ..driver.grpc_driver import GrpcDriver
28
+ from ..driver.driver import GrpcDriverHelper
29
29
 
30
30
  SLEEP_TIME = 1
31
31
 
@@ -33,7 +33,9 @@ SLEEP_TIME = 1
33
33
  class DriverClientProxy(ClientProxy):
34
34
  """Flower client proxy which delegates work using the Driver API."""
35
35
 
36
- def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: int):
36
+ def __init__(
37
+ self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
38
+ ):
37
39
  super().__init__(str(node_id))
38
40
  self.node_id = node_id
39
41
  self.driver = driver
@@ -16,9 +16,7 @@
16
16
 
17
17
 
18
18
  from .driver import Driver
19
- from .grpc_driver import GrpcDriver
20
19
 
21
20
  __all__ = [
22
21
  "Driver",
23
- "GrpcDriver",
24
22
  ]
@@ -0,0 +1,140 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Driver (abstract base class)."""
16
+
17
+
18
+ from abc import ABC, abstractmethod
19
+ from typing import Iterable, List, Optional
20
+
21
+ from flwr.common import Message, RecordSet
22
+
23
+
24
+ class Driver(ABC):
25
+ """Abstract base Driver class for the Driver API."""
26
+
27
+ @abstractmethod
28
+ def create_message( # pylint: disable=too-many-arguments
29
+ self,
30
+ content: RecordSet,
31
+ message_type: str,
32
+ dst_node_id: int,
33
+ group_id: str,
34
+ ttl: Optional[float] = None,
35
+ ) -> Message:
36
+ """Create a new message with specified parameters.
37
+
38
+ This method constructs a new `Message` with given content and metadata.
39
+ The `run_id` and `src_node_id` will be set automatically.
40
+
41
+ Parameters
42
+ ----------
43
+ content : RecordSet
44
+ The content for the new message. This holds records that are to be sent
45
+ to the destination node.
46
+ message_type : str
47
+ The type of the message, defining the action to be executed on
48
+ the receiving end.
49
+ dst_node_id : int
50
+ The ID of the destination node to which the message is being sent.
51
+ group_id : str
52
+ The ID of the group to which this message is associated. In some settings,
53
+ this is used as the FL round.
54
+ ttl : Optional[float] (default: None)
55
+ Time-to-live for the round trip of this message, i.e., the time from sending
56
+ this message to receiving a reply. It specifies in seconds the duration for
57
+ which the message and its potential reply are considered valid. If unset,
58
+ the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
59
+
60
+ Returns
61
+ -------
62
+ message : Message
63
+ A new `Message` instance with the specified content and metadata.
64
+ """
65
+
66
+ @abstractmethod
67
+ def get_node_ids(self) -> List[int]:
68
+ """Get node IDs."""
69
+
70
+ @abstractmethod
71
+ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
72
+ """Push messages to specified node IDs.
73
+
74
+ This method takes an iterable of messages and sends each message
75
+ to the node specified in `dst_node_id`.
76
+
77
+ Parameters
78
+ ----------
79
+ messages : Iterable[Message]
80
+ An iterable of messages to be sent.
81
+
82
+ Returns
83
+ -------
84
+ message_ids : Iterable[str]
85
+ An iterable of IDs for the messages that were sent, which can be used
86
+ to pull replies.
87
+ """
88
+
89
+ @abstractmethod
90
+ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
91
+ """Pull messages based on message IDs.
92
+
93
+ This method is used to collect messages from the SuperLink
94
+ that correspond to a set of given message IDs.
95
+
96
+ Parameters
97
+ ----------
98
+ message_ids : Iterable[str]
99
+ An iterable of message IDs for which reply messages are to be retrieved.
100
+
101
+ Returns
102
+ -------
103
+ messages : Iterable[Message]
104
+ An iterable of messages received.
105
+ """
106
+
107
+ @abstractmethod
108
+ def send_and_receive(
109
+ self,
110
+ messages: Iterable[Message],
111
+ *,
112
+ timeout: Optional[float] = None,
113
+ ) -> Iterable[Message]:
114
+ """Push messages to specified node IDs and pull the reply messages.
115
+
116
+ This method sends a list of messages to their destination node IDs and then
117
+ waits for the replies. It continues to pull replies until either all
118
+ replies are received or the specified timeout duration is exceeded.
119
+
120
+ Parameters
121
+ ----------
122
+ messages : Iterable[Message]
123
+ An iterable of messages to be sent.
124
+ timeout : Optional[float] (default: None)
125
+ The timeout duration in seconds. If specified, the method will wait for
126
+ replies for this duration. If `None`, there is no time limit and the method
127
+ will wait until replies for all messages are received.
128
+
129
+ Returns
130
+ -------
131
+ replies : Iterable[Message]
132
+ An iterable of reply messages received from the SuperLink.
133
+
134
+ Notes
135
+ -----
136
+ This method uses `push_messages` to send the messages and `pull_messages`
137
+ to collect the replies. If `timeout` is set, the method may not return
138
+ replies for all sent messages. A message remains valid until its TTL,
139
+ which is not affected by `timeout`.
140
+ """
@@ -16,20 +16,121 @@
16
16
 
17
17
  import time
18
18
  import warnings
19
+ from logging import DEBUG, ERROR, WARNING
19
20
  from typing import Iterable, List, Optional, Tuple
20
21
 
21
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
22
+ import grpc
23
+
24
+ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
25
+ from flwr.common.grpc import create_channel
26
+ from flwr.common.logger import log
22
27
  from flwr.common.serde import message_from_taskres, message_to_taskins
23
28
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
24
29
  CreateRunRequest,
30
+ CreateRunResponse,
25
31
  GetNodesRequest,
32
+ GetNodesResponse,
26
33
  PullTaskResRequest,
34
+ PullTaskResResponse,
27
35
  PushTaskInsRequest,
36
+ PushTaskInsResponse,
28
37
  )
38
+ from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
29
39
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
30
40
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
31
41
 
32
- from .grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
42
+ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
43
+
44
+ ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
45
+ [Driver] Error: Not connected.
46
+
47
+ Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
48
+ `GrpcDriverHelper` methods.
49
+ """
50
+
51
+
52
+ class GrpcDriverHelper:
53
+ """`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
54
+
55
+ def __init__(
56
+ self,
57
+ driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
58
+ root_certificates: Optional[bytes] = None,
59
+ ) -> None:
60
+ self.driver_service_address = driver_service_address
61
+ self.root_certificates = root_certificates
62
+ self.channel: Optional[grpc.Channel] = None
63
+ self.stub: Optional[DriverStub] = None
64
+
65
+ def connect(self) -> None:
66
+ """Connect to the Driver API."""
67
+ event(EventType.DRIVER_CONNECT)
68
+ if self.channel is not None or self.stub is not None:
69
+ log(WARNING, "Already connected")
70
+ return
71
+ self.channel = create_channel(
72
+ server_address=self.driver_service_address,
73
+ insecure=(self.root_certificates is None),
74
+ root_certificates=self.root_certificates,
75
+ )
76
+ self.stub = DriverStub(self.channel)
77
+ log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
78
+
79
+ def disconnect(self) -> None:
80
+ """Disconnect from the Driver API."""
81
+ event(EventType.DRIVER_DISCONNECT)
82
+ if self.channel is None or self.stub is None:
83
+ log(DEBUG, "Already disconnected")
84
+ return
85
+ channel = self.channel
86
+ self.channel = None
87
+ self.stub = None
88
+ channel.close()
89
+ log(DEBUG, "[Driver] Disconnected")
90
+
91
+ def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
92
+ """Request for run ID."""
93
+ # Check if channel is open
94
+ if self.stub is None:
95
+ log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
96
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
97
+
98
+ # Call Driver API
99
+ res: CreateRunResponse = self.stub.CreateRun(request=req)
100
+ return res
101
+
102
+ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
103
+ """Get client IDs."""
104
+ # Check if channel is open
105
+ if self.stub is None:
106
+ log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
107
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
108
+
109
+ # Call gRPC Driver API
110
+ res: GetNodesResponse = self.stub.GetNodes(request=req)
111
+ return res
112
+
113
+ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
114
+ """Schedule tasks."""
115
+ # Check if channel is open
116
+ if self.stub is None:
117
+ log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
118
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
119
+
120
+ # Call gRPC Driver API
121
+ res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
122
+ return res
123
+
124
+ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
125
+ """Get task results."""
126
+ # Check if channel is open
127
+ if self.stub is None:
128
+ log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
129
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
130
+
131
+ # Call Driver API
132
+ res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
133
+ return res
33
134
 
34
135
 
35
136
  class Driver:
@@ -57,22 +158,22 @@ class Driver:
57
158
  ) -> None:
58
159
  self.addr = driver_service_address
59
160
  self.root_certificates = root_certificates
60
- self.grpc_driver: Optional[GrpcDriver] = None
161
+ self.grpc_driver_helper: Optional[GrpcDriverHelper] = None
61
162
  self.run_id: Optional[int] = None
62
163
  self.node = Node(node_id=0, anonymous=True)
63
164
 
64
- def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]:
65
- # Check if the GrpcDriver is initialized
66
- if self.grpc_driver is None or self.run_id is None:
165
+ def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
166
+ # Check if the GrpcDriverHelper is initialized
167
+ if self.grpc_driver_helper is None or self.run_id is None:
67
168
  # Connect and create run
68
- self.grpc_driver = GrpcDriver(
169
+ self.grpc_driver_helper = GrpcDriverHelper(
69
170
  driver_service_address=self.addr,
70
171
  root_certificates=self.root_certificates,
71
172
  )
72
- self.grpc_driver.connect()
73
- res = self.grpc_driver.create_run(CreateRunRequest())
173
+ self.grpc_driver_helper.connect()
174
+ res = self.grpc_driver_helper.create_run(CreateRunRequest())
74
175
  self.run_id = res.run_id
75
- return self.grpc_driver, self.run_id
176
+ return self.grpc_driver_helper, self.run_id
76
177
 
77
178
  def _check_message(self, message: Message) -> None:
78
179
  # Check if the message is valid
@@ -122,7 +223,7 @@ class Driver:
122
223
  message : Message
123
224
  A new `Message` instance with the specified content and metadata.
124
225
  """
125
- _, run_id = self._get_grpc_driver_and_run_id()
226
+ _, run_id = self._get_grpc_driver_helper_and_run_id()
126
227
  if ttl:
127
228
  warnings.warn(
128
229
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -146,9 +247,9 @@ class Driver:
146
247
 
147
248
  def get_node_ids(self) -> List[int]:
148
249
  """Get node IDs."""
149
- grpc_driver, run_id = self._get_grpc_driver_and_run_id()
150
- # Call GrpcDriver method
151
- res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id))
250
+ grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
251
+ # Call GrpcDriverHelper method
252
+ res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
152
253
  return [node.node_id for node in res.nodes]
153
254
 
154
255
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -168,7 +269,7 @@ class Driver:
168
269
  An iterable of IDs for the messages that were sent, which can be used
169
270
  to pull replies.
170
271
  """
171
- grpc_driver, _ = self._get_grpc_driver_and_run_id()
272
+ grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
172
273
  # Construct TaskIns
173
274
  task_ins_list: List[TaskIns] = []
174
275
  for msg in messages:
@@ -178,8 +279,10 @@ class Driver:
178
279
  taskins = message_to_taskins(msg)
179
280
  # Add to list
180
281
  task_ins_list.append(taskins)
181
- # Call GrpcDriver method
182
- res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
282
+ # Call GrpcDriverHelper method
283
+ res = grpc_driver_helper.push_task_ins(
284
+ PushTaskInsRequest(task_ins_list=task_ins_list)
285
+ )
183
286
  return list(res.task_ids)
184
287
 
185
288
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -198,7 +301,7 @@ class Driver:
198
301
  messages : Iterable[Message]
199
302
  An iterable of messages received.
200
303
  """
201
- grpc_driver, _ = self._get_grpc_driver_and_run_id()
304
+ grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
202
305
  # Pull TaskRes
203
306
  res = grpc_driver.pull_task_res(
204
307
  PullTaskResRequest(node=self.node, task_ids=message_ids)
@@ -260,8 +363,8 @@ class Driver:
260
363
 
261
364
  def close(self) -> None:
262
365
  """Disconnect from the SuperLink if connected."""
263
- # Check if GrpcDriver is initialized
264
- if self.grpc_driver is None:
366
+ # Check if GrpcDriverHelper is initialized
367
+ if self.grpc_driver_helper is None:
265
368
  return
266
369
  # Disconnect
267
- self.grpc_driver.disconnect()
370
+ self.grpc_driver_helper.disconnect()
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
64
64
  """Create run ID."""
65
65
  log(INFO, "DriverServicer.CreateRun")
66
66
  state: State = self.state_factory.state()
67
- run_id = state.create_run()
67
+ run_id = state.create_run("None/None", "None")
68
68
  return CreateRunResponse(run_id=run_id)
69
69
 
70
70
  def PushTaskIns(
@@ -33,6 +33,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
33
  PushTaskResRequest,
34
34
  PushTaskResResponse,
35
35
  Reconnect,
36
+ Run,
36
37
  )
37
38
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
38
39
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
@@ -109,4 +110,6 @@ def get_run(
109
110
  request: GetRunRequest, state: State # pylint: disable=W0613
110
111
  ) -> GetRunResponse:
111
112
  """Get run information."""
112
- return GetRunResponse()
113
+ run_id, fab_id, fab_version = state.get_run(request.run_id)
114
+ run = Run(run_id=run_id, fab_id=fab_id, fab_version=fab_version)
115
+ return GetRunResponse(run=run)
@@ -36,7 +36,8 @@ class InMemoryState(State):
36
36
  def __init__(self) -> None:
37
37
  # Map node_id to (online_until, ping_interval)
38
38
  self.node_ids: Dict[int, Tuple[float, float]] = {}
39
- self.run_ids: Set[int] = set()
39
+ # Map run_id to (fab_id, fab_version)
40
+ self.run_ids: Dict[int, Tuple[str, str]] = {}
40
41
  self.task_ins_store: Dict[UUID, TaskIns] = {}
41
42
  self.task_res_store: Dict[UUID, TaskRes] = {}
42
43
  self.lock = threading.Lock()
@@ -238,18 +239,26 @@ class InMemoryState(State):
238
239
  if online_until > current_time
239
240
  }
240
241
 
241
- def create_run(self) -> int:
242
- """Create one run."""
242
+ def create_run(self, fab_id: str, fab_version: str) -> int:
243
+ """Create a new run for the specified `fab_id` and `fab_version`."""
243
244
  # Sample a random int64 as run_id
244
245
  with self.lock:
245
246
  run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
246
247
 
247
248
  if run_id not in self.run_ids:
248
- self.run_ids.add(run_id)
249
+ self.run_ids[run_id] = (fab_id, fab_version)
249
250
  return run_id
250
251
  log(ERROR, "Unexpected run creation failure.")
251
252
  return 0
252
253
 
254
+ def get_run(self, run_id: int) -> Tuple[int, str, str]:
255
+ """Retrieve information about the run with the specified `run_id`."""
256
+ with self.lock:
257
+ if run_id not in self.run_ids:
258
+ log(ERROR, "`run_id` is invalid")
259
+ return 0, "", ""
260
+ return run_id, *self.run_ids[run_id]
261
+
253
262
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
254
263
  """Acknowledge a ping received from a node, serving as a heartbeat."""
255
264
  with self.lock:
@@ -46,7 +46,9 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
46
46
 
47
47
  SQL_CREATE_TABLE_RUN = """
48
48
  CREATE TABLE IF NOT EXISTS run(
49
- run_id INTEGER UNIQUE
49
+ run_id INTEGER UNIQUE,
50
+ fab_id TEXT,
51
+ fab_version TEXT
50
52
  );
51
53
  """
52
54
 
@@ -558,8 +560,8 @@ class SqliteState(State):
558
560
  result: Set[int] = {row["node_id"] for row in rows}
559
561
  return result
560
562
 
561
- def create_run(self) -> int:
562
- """Create one run and store it in state."""
563
+ def create_run(self, fab_id: str, fab_version: str) -> int:
564
+ """Create a new run for the specified `fab_id` and `fab_version`."""
563
565
  # Sample a random int64 as run_id
564
566
  run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
565
567
 
@@ -567,12 +569,22 @@ class SqliteState(State):
567
569
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
568
570
  # If run_id does not exist
569
571
  if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
570
- query = "INSERT INTO run VALUES(:run_id);"
571
- self.query(query, {"run_id": run_id})
572
+ query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
573
+ self.query(query, (run_id, fab_id, fab_version))
572
574
  return run_id
573
575
  log(ERROR, "Unexpected run creation failure.")
574
576
  return 0
575
577
 
578
+ def get_run(self, run_id: int) -> Tuple[int, str, str]:
579
+ """Retrieve information about the run with the specified `run_id`."""
580
+ query = "SELECT * FROM run WHERE run_id = ?;"
581
+ try:
582
+ row = self.query(query, (run_id,))[0]
583
+ return run_id, row["fab_id"], row["fab_version"]
584
+ except sqlite3.IntegrityError:
585
+ log(ERROR, "`run_id` does not exist.")
586
+ return 0, "", ""
587
+
576
588
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
577
589
  """Acknowledge a ping received from a node, serving as a heartbeat."""
578
590
  # Update `online_until` and `ping_interval` for the given `node_id`
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import abc
19
- from typing import List, Optional, Set
19
+ from typing import List, Optional, Set, Tuple
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
@@ -150,8 +150,26 @@ class State(abc.ABC):
150
150
  """
151
151
 
152
152
  @abc.abstractmethod
153
- def create_run(self) -> int:
154
- """Create one run."""
153
+ def create_run(self, fab_id: str, fab_version: str) -> int:
154
+ """Create a new run for the specified `fab_id` and `fab_version`."""
155
+
156
+ @abc.abstractmethod
157
+ def get_run(self, run_id: int) -> Tuple[int, str, str]:
158
+ """Retrieve information about the run with the specified `run_id`.
159
+
160
+ Parameters
161
+ ----------
162
+ run_id : int
163
+ The identifier of the run.
164
+
165
+ Returns
166
+ -------
167
+ Tuple[int, str, str]
168
+ A tuple containing three elements:
169
+ - `run_id`: The identifier of the run, same as the specified `run_id`.
170
+ - `fab_id`: The identifier of the FAB used in the specified run.
171
+ - `fab_version`: The version of the FAB used in the specified run.
172
+ """
155
173
 
156
174
  @abc.abstractmethod
157
175
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.9.0.dev20240416
3
+ Version: 1.9.0.dev20240420
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0