flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package registry's advisory page for more details.

Files changed (64)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +151 -0
  3. flwr/cli/config_utils.py +18 -46
  4. flwr/cli/new/new.py +42 -18
  5. flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
  6. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
  7. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
  8. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
  9. flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
  10. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
  12. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
  13. flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
  14. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
  20. flwr/cli/run/run.py +1 -1
  21. flwr/cli/utils.py +18 -17
  22. flwr/client/__init__.py +1 -1
  23. flwr/client/app.py +17 -93
  24. flwr/client/grpc_client/connection.py +6 -1
  25. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  26. flwr/client/grpc_rere_client/connection.py +17 -2
  27. flwr/client/mod/centraldp_mods.py +4 -2
  28. flwr/client/mod/localdp_mod.py +9 -3
  29. flwr/client/rest_client/connection.py +5 -1
  30. flwr/client/supernode/__init__.py +2 -0
  31. flwr/client/supernode/app.py +181 -7
  32. flwr/common/grpc.py +5 -1
  33. flwr/common/logger.py +37 -4
  34. flwr/common/message.py +105 -86
  35. flwr/common/record/parametersrecord.py +0 -1
  36. flwr/common/record/recordset.py +17 -5
  37. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
  38. flwr/server/app.py +111 -1
  39. flwr/server/compat/app.py +2 -2
  40. flwr/server/compat/app_utils.py +1 -1
  41. flwr/server/compat/driver_client_proxy.py +27 -72
  42. flwr/server/driver/__init__.py +3 -0
  43. flwr/server/driver/driver.py +12 -242
  44. flwr/server/driver/grpc_driver.py +315 -0
  45. flwr/server/run_serverapp.py +18 -4
  46. flwr/server/strategy/dp_adaptive_clipping.py +5 -3
  47. flwr/server/strategy/dp_fixed_clipping.py +6 -3
  48. flwr/server/superlink/driver/driver_servicer.py +1 -1
  49. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  50. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
  51. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  52. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  53. flwr/server/superlink/state/in_memory_state.py +76 -8
  54. flwr/server/superlink/state/sqlite_state.py +116 -11
  55. flwr/server/superlink/state/state.py +35 -3
  56. flwr/simulation/__init__.py +2 -2
  57. flwr/simulation/app.py +16 -1
  58. flwr/simulation/run_simulation.py +10 -7
  59. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
  60. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
  61. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
  62. flwr/server/driver/abc_driver.py +0 -140
  63. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
  64. {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
@@ -16,16 +16,14 @@
16
16
 
17
17
 
18
18
  import time
19
- from typing import List, Optional
19
+ from typing import Optional
20
20
 
21
21
  from flwr import common
22
- from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet
22
+ from flwr.common import Message, MessageType, MessageTypeLegacy, RecordSet
23
23
  from flwr.common import recordset_compat as compat
24
- from flwr.common import serde
25
- from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
26
24
  from flwr.server.client_proxy import ClientProxy
27
25
 
28
- from ..driver.driver import GrpcDriverHelper
26
+ from ..driver.driver import Driver
29
27
 
30
28
  SLEEP_TIME = 1
31
29
 
@@ -33,9 +31,7 @@ SLEEP_TIME = 1
33
31
  class DriverClientProxy(ClientProxy):
34
32
  """Flower client proxy which delegates work using the Driver API."""
35
33
 
36
- def __init__(
37
- self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
38
- ):
34
+ def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
39
35
  super().__init__(str(node_id))
40
36
  self.node_id = node_id
41
37
  self.driver = driver
@@ -116,80 +112,39 @@ class DriverClientProxy(ClientProxy):
116
112
  timeout: Optional[float],
117
113
  group_id: Optional[int],
118
114
  ) -> RecordSet:
119
- task_ins = task_pb2.TaskIns( # pylint: disable=E1101
120
- task_id="",
121
- group_id=str(group_id) if group_id is not None else "",
122
- run_id=self.run_id,
123
- task=task_pb2.Task( # pylint: disable=E1101
124
- producer=node_pb2.Node( # pylint: disable=E1101
125
- node_id=0,
126
- anonymous=True,
127
- ),
128
- consumer=node_pb2.Node( # pylint: disable=E1101
129
- node_id=self.node_id,
130
- anonymous=self.anonymous,
131
- ),
132
- task_type=task_type,
133
- recordset=serde.recordset_to_proto(recordset),
134
- ttl=DEFAULT_TTL,
135
- ),
136
- )
137
-
138
- # This would normally be recorded upon common.Message creation
139
- # but this compatibility stack doesn't create Messages,
140
- # so we need to inject `created_at` manually (needed for
141
- # taskins validation by server.utils.validator)
142
- task_ins.task.created_at = time.time()
143
115
 
144
- push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101
145
- task_ins_list=[task_ins]
116
+ # Create message
117
+ message = self.driver.create_message(
118
+ content=recordset,
119
+ message_type=task_type,
120
+ dst_node_id=self.node_id,
121
+ group_id=str(group_id) if group_id else "",
122
+ ttl=timeout,
146
123
  )
147
124
 
148
- # Send TaskIns to Driver API
149
- push_task_ins_res = self.driver.push_task_ins(req=push_task_ins_req)
150
-
151
- if len(push_task_ins_res.task_ids) != 1:
152
- raise ValueError("Unexpected number of task_ids")
125
+ # Push message
126
+ message_ids = list(self.driver.push_messages(messages=[message]))
127
+ if len(message_ids) != 1:
128
+ raise ValueError("Unexpected number of message_ids")
153
129
 
154
- task_id = push_task_ins_res.task_ids[0]
155
- if task_id == "":
156
- raise ValueError(f"Failed to schedule task for node {self.node_id}")
130
+ message_id = message_ids[0]
131
+ if message_id == "":
132
+ raise ValueError(f"Failed to send message to node {self.node_id}")
157
133
 
158
134
  if timeout:
159
135
  start_time = time.time()
160
136
 
161
137
  while True:
162
- pull_task_res_req = driver_pb2.PullTaskResRequest( # pylint: disable=E1101
163
- node=node_pb2.Node(node_id=0, anonymous=True), # pylint: disable=E1101
164
- task_ids=[task_id],
165
- )
166
-
167
- # Ask Driver API for TaskRes
168
- pull_task_res_res = self.driver.pull_task_res(req=pull_task_res_req)
169
-
170
- task_res_list: List[task_pb2.TaskRes] = list( # pylint: disable=E1101
171
- pull_task_res_res.task_res_list
172
- )
173
- if len(task_res_list) == 1:
174
- task_res = task_res_list[0]
175
-
176
- # This will raise an Exception if task_res carries an `error`
177
- validate_task_res(task_res=task_res)
178
-
179
- return serde.recordset_from_proto(task_res.task.recordset)
138
+ messages = list(self.driver.pull_messages(message_ids))
139
+ if len(messages) == 1:
140
+ msg: Message = messages[0]
141
+ if msg.has_error():
142
+ raise ValueError(
143
+ f"Message contains an Error (reason: {msg.error.reason}). "
144
+ "It originated during client-side execution of a message."
145
+ )
146
+ return msg.content
180
147
 
181
148
  if timeout is not None and time.time() > start_time + timeout:
182
149
  raise RuntimeError("Timeout reached")
183
150
  time.sleep(SLEEP_TIME)
184
-
185
-
186
- def validate_task_res(
187
- task_res: task_pb2.TaskRes, # pylint: disable=E1101
188
- ) -> None:
189
- """Validate if a TaskRes is empty or not."""
190
- if not task_res.HasField("task"):
191
- raise ValueError("Invalid TaskRes, field `task` missing")
192
- if task_res.task.HasField("error"):
193
- raise ValueError("Exception during client-side task execution")
194
- if not task_res.task.HasField("recordset"):
195
- raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
@@ -16,7 +16,10 @@
16
16
 
17
17
 
18
18
  from .driver import Driver
19
+ from .grpc_driver import GrpcDriver, GrpcDriverHelper
19
20
 
20
21
  __all__ = [
21
22
  "Driver",
23
+ "GrpcDriver",
24
+ "GrpcDriverHelper",
22
25
  ]
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,180 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service client."""
15
+ """Driver (abstract base class)."""
16
16
 
17
- import time
18
- import warnings
19
- from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple
21
17
 
22
- import grpc
18
+ from abc import ABC, abstractmethod
19
+ from typing import Iterable, List, Optional
23
20
 
24
- from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
25
- from flwr.common.grpc import create_channel
26
- from flwr.common.logger import log
27
- from flwr.common.serde import message_from_taskres, message_to_taskins
28
- from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
29
- CreateRunRequest,
30
- CreateRunResponse,
31
- GetNodesRequest,
32
- GetNodesResponse,
33
- PullTaskResRequest,
34
- PullTaskResResponse,
35
- PushTaskInsRequest,
36
- PushTaskInsResponse,
37
- )
38
- from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
39
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
40
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
21
+ from flwr.common import Message, RecordSet
41
22
 
42
- DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
43
23
 
44
- ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
45
- [Driver] Error: Not connected.
46
-
47
- Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
48
- `GrpcDriverHelper` methods.
49
- """
50
-
51
-
52
- class GrpcDriverHelper:
53
- """`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
54
-
55
- def __init__(
56
- self,
57
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
58
- root_certificates: Optional[bytes] = None,
59
- ) -> None:
60
- self.driver_service_address = driver_service_address
61
- self.root_certificates = root_certificates
62
- self.channel: Optional[grpc.Channel] = None
63
- self.stub: Optional[DriverStub] = None
64
-
65
- def connect(self) -> None:
66
- """Connect to the Driver API."""
67
- event(EventType.DRIVER_CONNECT)
68
- if self.channel is not None or self.stub is not None:
69
- log(WARNING, "Already connected")
70
- return
71
- self.channel = create_channel(
72
- server_address=self.driver_service_address,
73
- insecure=(self.root_certificates is None),
74
- root_certificates=self.root_certificates,
75
- )
76
- self.stub = DriverStub(self.channel)
77
- log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
78
-
79
- def disconnect(self) -> None:
80
- """Disconnect from the Driver API."""
81
- event(EventType.DRIVER_DISCONNECT)
82
- if self.channel is None or self.stub is None:
83
- log(DEBUG, "Already disconnected")
84
- return
85
- channel = self.channel
86
- self.channel = None
87
- self.stub = None
88
- channel.close()
89
- log(DEBUG, "[Driver] Disconnected")
90
-
91
- def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
92
- """Request for run ID."""
93
- # Check if channel is open
94
- if self.stub is None:
95
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
96
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
97
-
98
- # Call Driver API
99
- res: CreateRunResponse = self.stub.CreateRun(request=req)
100
- return res
101
-
102
- def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
103
- """Get client IDs."""
104
- # Check if channel is open
105
- if self.stub is None:
106
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
107
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
108
-
109
- # Call gRPC Driver API
110
- res: GetNodesResponse = self.stub.GetNodes(request=req)
111
- return res
112
-
113
- def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
114
- """Schedule tasks."""
115
- # Check if channel is open
116
- if self.stub is None:
117
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
118
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
119
-
120
- # Call gRPC Driver API
121
- res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
122
- return res
123
-
124
- def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
125
- """Get task results."""
126
- # Check if channel is open
127
- if self.stub is None:
128
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
129
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
130
-
131
- # Call Driver API
132
- res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
133
- return res
134
-
135
-
136
- class Driver:
137
- """`Driver` class provides an interface to the Driver API.
138
-
139
- Parameters
140
- ----------
141
- driver_service_address : Optional[str]
142
- The IPv4 or IPv6 address of the Driver API server.
143
- Defaults to `"[::]:9091"`.
144
- certificates : bytes (default: None)
145
- Tuple containing root certificate, server certificate, and private key
146
- to start a secure SSL-enabled server. The tuple is expected to have
147
- three bytes elements in the following order:
148
-
149
- * CA certificate.
150
- * server certificate.
151
- * server private key.
152
- """
153
-
154
- def __init__(
155
- self,
156
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
157
- root_certificates: Optional[bytes] = None,
158
- ) -> None:
159
- self.addr = driver_service_address
160
- self.root_certificates = root_certificates
161
- self.grpc_driver_helper: Optional[GrpcDriverHelper] = None
162
- self.run_id: Optional[int] = None
163
- self.node = Node(node_id=0, anonymous=True)
164
-
165
- def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
166
- # Check if the GrpcDriverHelper is initialized
167
- if self.grpc_driver_helper is None or self.run_id is None:
168
- # Connect and create run
169
- self.grpc_driver_helper = GrpcDriverHelper(
170
- driver_service_address=self.addr,
171
- root_certificates=self.root_certificates,
172
- )
173
- self.grpc_driver_helper.connect()
174
- res = self.grpc_driver_helper.create_run(CreateRunRequest())
175
- self.run_id = res.run_id
176
- return self.grpc_driver_helper, self.run_id
177
-
178
- def _check_message(self, message: Message) -> None:
179
- # Check if the message is valid
180
- if not (
181
- message.metadata.run_id == self.run_id
182
- and message.metadata.src_node_id == self.node.node_id
183
- and message.metadata.message_id == ""
184
- and message.metadata.reply_to_message == ""
185
- and message.metadata.ttl > 0
186
- ):
187
- raise ValueError(f"Invalid message: {message}")
24
+ class Driver(ABC):
25
+ """Abstract base Driver class for the Driver API."""
188
26
 
27
+ @abstractmethod
189
28
  def create_message( # pylint: disable=too-many-arguments
190
29
  self,
191
30
  content: RecordSet,
@@ -223,35 +62,12 @@ class Driver:
223
62
  message : Message
224
63
  A new `Message` instance with the specified content and metadata.
225
64
  """
226
- _, run_id = self._get_grpc_driver_helper_and_run_id()
227
- if ttl:
228
- warnings.warn(
229
- "A custom TTL was set, but note that the SuperLink does not enforce "
230
- "the TTL yet. The SuperLink will start enforcing the TTL in a future "
231
- "version of Flower.",
232
- stacklevel=2,
233
- )
234
-
235
- ttl_ = DEFAULT_TTL if ttl is None else ttl
236
- metadata = Metadata(
237
- run_id=run_id,
238
- message_id="", # Will be set by the server
239
- src_node_id=self.node.node_id,
240
- dst_node_id=dst_node_id,
241
- reply_to_message="",
242
- group_id=group_id,
243
- ttl=ttl_,
244
- message_type=message_type,
245
- )
246
- return Message(metadata=metadata, content=content)
247
65
 
66
+ @abstractmethod
248
67
  def get_node_ids(self) -> List[int]:
249
68
  """Get node IDs."""
250
- grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
251
- # Call GrpcDriverHelper method
252
- res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
253
- return [node.node_id for node in res.nodes]
254
69
 
70
+ @abstractmethod
255
71
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
256
72
  """Push messages to specified node IDs.
257
73
 
@@ -269,22 +85,8 @@ class Driver:
269
85
  An iterable of IDs for the messages that were sent, which can be used
270
86
  to pull replies.
271
87
  """
272
- grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
273
- # Construct TaskIns
274
- task_ins_list: List[TaskIns] = []
275
- for msg in messages:
276
- # Check message
277
- self._check_message(msg)
278
- # Convert Message to TaskIns
279
- taskins = message_to_taskins(msg)
280
- # Add to list
281
- task_ins_list.append(taskins)
282
- # Call GrpcDriverHelper method
283
- res = grpc_driver_helper.push_task_ins(
284
- PushTaskInsRequest(task_ins_list=task_ins_list)
285
- )
286
- return list(res.task_ids)
287
88
 
89
+ @abstractmethod
288
90
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
289
91
  """Pull messages based on message IDs.
290
92
 
@@ -301,15 +103,8 @@ class Driver:
301
103
  messages : Iterable[Message]
302
104
  An iterable of messages received.
303
105
  """
304
- grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
305
- # Pull TaskRes
306
- res = grpc_driver.pull_task_res(
307
- PullTaskResRequest(node=self.node, task_ids=message_ids)
308
- )
309
- # Convert TaskRes to Message
310
- msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
311
- return msgs
312
106
 
107
+ @abstractmethod
313
108
  def send_and_receive(
314
109
  self,
315
110
  messages: Iterable[Message],
@@ -343,28 +138,3 @@ class Driver:
343
138
  replies for all sent messages. A message remains valid until its TTL,
344
139
  which is not affected by `timeout`.
345
140
  """
346
- # Push messages
347
- msg_ids = set(self.push_messages(messages))
348
-
349
- # Pull messages
350
- end_time = time.time() + (timeout if timeout is not None else 0.0)
351
- ret: List[Message] = []
352
- while timeout is None or time.time() < end_time:
353
- res_msgs = self.pull_messages(msg_ids)
354
- ret.extend(res_msgs)
355
- msg_ids.difference_update(
356
- {msg.metadata.reply_to_message for msg in res_msgs}
357
- )
358
- if len(msg_ids) == 0:
359
- break
360
- # Sleep
361
- time.sleep(3)
362
- return ret
363
-
364
- def close(self) -> None:
365
- """Disconnect from the SuperLink if connected."""
366
- # Check if GrpcDriverHelper is initialized
367
- if self.grpc_driver_helper is None:
368
- return
369
- # Disconnect
370
- self.grpc_driver_helper.disconnect()