flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -0,0 +1,54 @@
1
+ # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Driver gRPC API."""
16
+
17
+ from logging import INFO
18
+ from typing import Optional, Tuple
19
+
20
+ import grpc
21
+
22
+ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
23
+ from flwr.common.logger import log
24
+ from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
25
+ add_DriverServicer_to_server,
26
+ )
27
+ from flwr.server.superlink.state import StateFactory
28
+
29
+ from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
30
+ from .driver_servicer import DriverServicer
31
+
32
+
33
+ def run_driver_api_grpc(
34
+ address: str,
35
+ state_factory: StateFactory,
36
+ certificates: Optional[Tuple[bytes, bytes, bytes]],
37
+ ) -> grpc.Server:
38
+ """Run Driver API (gRPC, request-response)."""
39
+ # Create Driver API gRPC server
40
+ driver_servicer: grpc.Server = DriverServicer(
41
+ state_factory=state_factory,
42
+ )
43
+ driver_add_servicer_to_server_fn = add_DriverServicer_to_server
44
+ driver_grpc_server = generic_create_grpc_server(
45
+ servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn),
46
+ server_address=address,
47
+ max_message_length=GRPC_MAX_MESSAGE_LENGTH,
48
+ certificates=certificates,
49
+ )
50
+
51
+ log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
52
+ driver_grpc_server.start()
53
+
54
+ return driver_grpc_server
@@ -15,7 +15,7 @@
15
15
  """Driver API servicer."""
16
16
 
17
17
 
18
- from logging import INFO
18
+ from logging import DEBUG, INFO
19
19
  from typing import List, Optional, Set
20
20
  from uuid import UUID
21
21
 
@@ -70,7 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
70
70
  self, request: PushTaskInsRequest, context: grpc.ServicerContext
71
71
  ) -> PushTaskInsResponse:
72
72
  """Push a set of TaskIns."""
73
- log(INFO, "DriverServicer.PushTaskIns")
73
+ log(DEBUG, "DriverServicer.PushTaskIns")
74
74
 
75
75
  # Validate request
76
76
  _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
@@ -95,7 +95,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
95
95
  self, request: PullTaskResRequest, context: grpc.ServicerContext
96
96
  ) -> PullTaskResResponse:
97
97
  """Pull a set of TaskRes."""
98
- log(INFO, "DriverServicer.PullTaskRes")
98
+ log(DEBUG, "DriverServicer.PullTaskRes")
99
99
 
100
100
  # Convert each task_id str to UUID
101
101
  task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}
@@ -105,7 +105,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
105
105
 
106
106
  # Register callback
107
107
  def on_rpc_done() -> None:
108
- log(INFO, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
108
+ log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
109
109
 
110
110
  if context.is_active():
111
111
  return
@@ -46,6 +46,7 @@ class GrpcClientProxy(ClientProxy):
46
46
  self,
47
47
  ins: common.GetPropertiesIns,
48
48
  timeout: Optional[float],
49
+ group_id: Optional[int],
49
50
  ) -> common.GetPropertiesRes:
50
51
  """Request client's set of internal properties."""
51
52
  get_properties_msg = serde.get_properties_ins_to_proto(ins)
@@ -65,6 +66,7 @@ class GrpcClientProxy(ClientProxy):
65
66
  self,
66
67
  ins: common.GetParametersIns,
67
68
  timeout: Optional[float],
69
+ group_id: Optional[int],
68
70
  ) -> common.GetParametersRes:
69
71
  """Return the current local model parameters."""
70
72
  get_parameters_msg = serde.get_parameters_ins_to_proto(ins)
@@ -84,6 +86,7 @@ class GrpcClientProxy(ClientProxy):
84
86
  self,
85
87
  ins: common.FitIns,
86
88
  timeout: Optional[float],
89
+ group_id: Optional[int],
87
90
  ) -> common.FitRes:
88
91
  """Refine the provided parameters using the locally held dataset."""
89
92
  fit_ins_msg = serde.fit_ins_to_proto(ins)
@@ -102,6 +105,7 @@ class GrpcClientProxy(ClientProxy):
102
105
  self,
103
106
  ins: common.EvaluateIns,
104
107
  timeout: Optional[float],
108
+ group_id: Optional[int],
105
109
  ) -> common.EvaluateRes:
106
110
  """Evaluate the provided parameters using the locally held dataset."""
107
111
  evaluate_msg = serde.evaluate_ins_to_proto(ins)
@@ -119,6 +123,7 @@ class GrpcClientProxy(ClientProxy):
119
123
  self,
120
124
  ins: common.ReconnectIns,
121
125
  timeout: Optional[float],
126
+ group_id: Optional[int],
122
127
  ) -> common.DisconnectRes:
123
128
  """Disconnect and (optionally) reconnect later."""
124
129
  reconnect_ins_msg = serde.reconnect_ins_to_proto(ins)
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Fleet VirtualClientEngine side."""
15
+ """Fleet Simulation Engine side."""
16
16
 
17
17
  from .vce_api import start_vce
18
18
 
@@ -141,13 +141,13 @@ class RayBackend(Backend):
141
141
 
142
142
  Return output message and updated context.
143
143
  """
144
- node_id = message.metadata.dst_node_id
144
+ partition_id = message.metadata.partition_id
145
145
 
146
146
  try:
147
147
+ # Submit a task to the pool
148
148
  future = await self.pool.submit(
149
149
  lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
150
- (app, message, str(node_id), context),
150
+ (app, message, str(partition_id), context),
151
151
  )
152
152
 
153
153
  await future
@@ -163,10 +163,9 @@ class RayBackend(Backend):
163
163
  except LoadClientAppError as load_ex:
164
164
  log(
165
165
  ERROR,
166
- "An exception was raised when processing a message. Terminating %s",
166
+ "An exception was raised when processing a message by %s",
167
167
  self.__class__.__name__,
168
168
  )
169
- await self.terminate()
170
169
  raise load_ex
171
170
 
172
171
  async def terminate(self) -> None:
@@ -12,19 +12,23 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Fleet VirtualClientEngine API."""
15
+ """Fleet Simulation Engine API."""
16
+
16
17
 
17
18
  import asyncio
18
19
  import json
19
- from logging import ERROR, INFO
20
- from typing import Dict, Optional
20
+ import traceback
21
+ from logging import DEBUG, ERROR, INFO, WARN
22
+ from typing import Callable, Dict, List, Optional
21
23
 
22
- from flwr.client.client_app import ClientApp, load_client_app
24
+ from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
23
25
  from flwr.client.node_state import NodeState
24
26
  from flwr.common.logger import log
27
+ from flwr.common.serde import message_from_taskins, message_to_taskres
28
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
25
29
  from flwr.server.superlink.state import StateFactory
26
30
 
27
- from .backend import error_messages_backends, supported_backends
31
+ from .backend import Backend, error_messages_backends, supported_backends
28
32
 
29
33
  NodeToPartitionMapping = Dict[int, int]
30
34
 
@@ -42,21 +46,223 @@ def _register_nodes(
42
46
  return nodes_mapping
43
47
 
44
48
 
45
- # pylint: disable=too-many-arguments,unused-argument
49
+ # pylint: disable=too-many-arguments,too-many-locals
50
+ async def worker(
51
+ app_fn: Callable[[], ClientApp],
52
+ queue: "asyncio.Queue[TaskIns]",
53
+ node_states: Dict[int, NodeState],
54
+ state_factory: StateFactory,
55
+ nodes_mapping: NodeToPartitionMapping,
56
+ backend: Backend,
57
+ ) -> None:
58
+ """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
59
+ state = state_factory.state()
60
+ while True:
61
+ try:
62
+ task_ins: TaskIns = await queue.get()
63
+ node_id = task_ins.task.consumer.node_id
64
+
65
+ # Register and retrieve runstate
66
+ node_states[node_id].register_context(run_id=task_ins.run_id)
67
+ context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
68
+
69
+ # Convert TaskIns to Message
70
+ message = message_from_taskins(task_ins)
71
+ # Set partition_id
72
+ message.metadata.partition_id = nodes_mapping[node_id]
73
+
74
+ # Let backend process message
75
+ out_mssg, updated_context = await backend.process_message(
76
+ app_fn, message, context
77
+ )
78
+
79
+ # Update Context
80
+ node_states[node_id].update_context(
81
+ task_ins.run_id, context=updated_context
82
+ )
83
+
84
+ # Convert to TaskRes
85
+ task_res = message_to_taskres(out_mssg)
86
+ # Store TaskRes in state
87
+ state.store_task_res(task_res)
88
+
89
+ except asyncio.CancelledError as e:
90
+ log(DEBUG, "Async worker: %s", e)
91
+ break
92
+
93
+ except LoadClientAppError as app_ex:
94
+ log(ERROR, "Async worker: %s", app_ex)
95
+ log(ERROR, traceback.format_exc())
96
+ raise
97
+
98
+ except Exception as ex: # pylint: disable=broad-exception-caught
99
+ log(ERROR, ex)
100
+ log(ERROR, traceback.format_exc())
101
+ break
102
+
103
+
104
+ async def add_taskins_to_queue(
105
+ queue: "asyncio.Queue[TaskIns]",
106
+ state_factory: StateFactory,
107
+ nodes_mapping: NodeToPartitionMapping,
108
+ backend: Backend,
109
+ consumers: List["asyncio.Task[None]"],
110
+ f_stop: asyncio.Event,
111
+ ) -> None:
112
+ """Retrieve TaskIns and add it to the queue."""
113
+ state = state_factory.state()
114
+ num_initial_consumers = len(consumers)
115
+ while not f_stop.is_set():
116
+ for node_id in nodes_mapping.keys():
117
+ task_ins = state.get_task_ins(node_id=node_id, limit=1)
118
+ if task_ins:
119
+ await queue.put(task_ins[0])
120
+
121
+ # Count consumers that are running
122
+ num_active = sum(not (cc.done()) for cc in consumers)
123
+
124
+ # Alert if number of consumers decreased by half
125
+ if num_active < num_initial_consumers // 2:
126
+ log(
127
+ WARN,
128
+ "Number of active workers has more than halved: (%i/%i active)",
129
+ num_active,
130
+ num_initial_consumers,
131
+ )
132
+
133
+ # Break if consumers died
134
+ if num_active == 0:
135
+ raise RuntimeError("All workers have died. Ending Simulation.")
136
+
137
+ # Log some stats
138
+ log(
139
+ DEBUG,
140
+ "Simulation Engine stats: "
141
+ "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
142
+ num_active,
143
+ num_initial_consumers,
144
+ backend.__class__.__name__,
145
+ backend.num_workers,
146
+ queue.qsize(),
147
+ )
148
+ await asyncio.sleep(1.0)
149
+ log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
150
+
151
+
152
+ async def run(
153
+ app_fn: Callable[[], ClientApp],
154
+ backend_fn: Callable[[], Backend],
155
+ nodes_mapping: NodeToPartitionMapping,
156
+ state_factory: StateFactory,
157
+ node_states: Dict[int, NodeState],
158
+ f_stop: asyncio.Event,
159
+ ) -> None:
160
+ """Run the VCE async."""
161
+ queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)
162
+
163
+ try:
164
+
165
+ # Instantiate backend
166
+ backend = backend_fn()
167
+
168
+ # Build backend
169
+ await backend.build()
170
+
171
+ # Add workers (they submit Messages to Backend)
172
+ worker_tasks = [
173
+ asyncio.create_task(
174
+ worker(
175
+ app_fn, queue, node_states, state_factory, nodes_mapping, backend
176
+ )
177
+ )
178
+ for _ in range(backend.num_workers)
179
+ ]
180
+ # Create producer (adds TaskIns into Queue)
181
+ producer = asyncio.create_task(
182
+ add_taskins_to_queue(
183
+ queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
184
+ )
185
+ )
186
+
187
+ # Wait for producer to finish
188
+ # The producer runs forever until f_stop is set or until
189
+ # all worker (consumer) coroutines are completed. Workers
190
+ # also run forever and only end if an exception is raised.
191
+ await asyncio.gather(producer)
192
+
193
+ except Exception as ex:
194
+
195
+ log(ERROR, "An exception occurred!! %s", ex)
196
+ log(ERROR, traceback.format_exc())
197
+ log(WARN, "Stopping Simulation Engine.")
198
+
199
+ # Manually trigger stopping event
200
+ f_stop.set()
201
+
202
+ # Raise exception
203
+ raise RuntimeError("Simulation Engine crashed.") from ex
204
+
205
+ finally:
206
+ # Producer task terminated, now cancel worker tasks
207
+ for w_t in worker_tasks:
208
+ _ = w_t.cancel()
209
+
210
+ while not all(w_t.done() for w_t in worker_tasks):
211
+ log(DEBUG, "Terminating async workers...")
212
+ await asyncio.sleep(0.5)
213
+
214
+ await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
215
+
216
+ # Terminate backend
217
+ await backend.terminate()
218
+
219
+
220
+ # pylint: disable=too-many-arguments,unused-argument,too-many-locals
46
221
  def start_vce(
47
- num_supernodes: int,
48
222
  client_app_module_name: str,
49
223
  backend_name: str,
50
224
  backend_config_json_stream: str,
51
- state_factory: StateFactory,
52
225
  working_dir: str,
53
- f_stop: Optional[asyncio.Event] = None,
226
+ f_stop: asyncio.Event,
227
+ num_supernodes: Optional[int] = None,
228
+ state_factory: Optional[StateFactory] = None,
229
+ existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
54
230
  ) -> None:
55
- """Start Fleet API with the VirtualClientEngine (VCE)."""
56
- # Register SuperNodes
57
- nodes_mapping = _register_nodes(
58
- num_nodes=num_supernodes, state_factory=state_factory
59
- )
231
+ """Start Fleet API with the Simulation Engine."""
232
+ if num_supernodes is not None and existing_nodes_mapping is not None:
233
+ raise ValueError(
234
+ "Both `num_supernodes` and `existing_nodes_mapping` are provided, "
235
+ "but only one is allowed."
236
+ )
237
+ if num_supernodes is None:
238
+ if state_factory is None or existing_nodes_mapping is None:
239
+ raise ValueError(
240
+ "If not passing an existing `state_factory` and associated "
241
+ "`existing_nodes_mapping` you must supply `num_supernodes` to indicate "
242
+ "how many nodes to insert into a new StateFactory that will be created."
243
+ )
244
+ if existing_nodes_mapping:
245
+ if state_factory is None:
246
+ raise ValueError(
247
+ "`existing_nodes_mapping` was passed, but no `state_factory` was "
248
+ "passed."
249
+ )
250
+ log(INFO, "Using existing NodeToPartitionMapping and StateFactory.")
251
+ # Use mapping constructed externally. This also means nodes
252
+ # have previously been registered.
253
+ nodes_mapping = existing_nodes_mapping
254
+
255
+ if not state_factory:
256
+ log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
257
+ # Create an empty in-memory state factory
258
+ state_factory = StateFactory(":flwr-in-memory-state:")
259
+ log(INFO, "Created new %s.", state_factory.__class__.__name__)
260
+
261
+ if num_supernodes:
262
+ # Register SuperNodes
263
+ nodes_mapping = _register_nodes(
264
+ num_nodes=num_supernodes, state_factory=state_factory
265
+ )
60
266
 
61
267
  # Construct mapping of NodeStates
62
268
  node_states: Dict[int, NodeState] = {}
@@ -69,7 +275,6 @@ def start_vce(
69
275
 
70
276
  try:
71
277
  backend_type = supported_backends[backend_name]
72
- _ = backend_type(backend_config, work_dir=working_dir)
73
278
  except KeyError as ex:
74
279
  log(
75
280
  ERROR,
@@ -83,10 +288,25 @@ def start_vce(
83
288
 
84
289
  raise ex
85
290
 
291
+ def backend_fn() -> Backend:
292
+ """Instantiate a Backend."""
293
+ return backend_type(backend_config, work_dir=working_dir)
294
+
86
295
  log(INFO, "client_app_module_name = %s", client_app_module_name)
87
296
 
88
297
  def _load() -> ClientApp:
89
298
  app: ClientApp = load_client_app(client_app_module_name)
90
299
  return app
91
300
 
92
- # start backend
301
+ app_fn = _load
302
+
303
+ asyncio.run(
304
+ run(
305
+ app_fn,
306
+ backend_fn,
307
+ nodes_mapping,
308
+ state_factory,
309
+ node_states,
310
+ f_stop,
311
+ )
312
+ )
flwr/server/typing.py CHANGED
@@ -22,3 +22,4 @@ from flwr.common import Context
22
22
  from .driver import Driver
23
23
 
24
24
  ServerAppCallable = Callable[[Driver, Context], None]
25
+ Workflow = Callable[[Driver, Context], None]
@@ -0,0 +1,22 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Workflows."""
16
+
17
+
18
+ from .default_workflows import DefaultWorkflow
19
+
20
+ __all__ = [
21
+ "DefaultWorkflow",
22
+ ]