flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (42) hide show
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -0,0 +1,54 @@
1
+ # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Driver gRPC API."""
16
+
17
+ from logging import INFO
18
+ from typing import Optional, Tuple
19
+
20
+ import grpc
21
+
22
+ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
23
+ from flwr.common.logger import log
24
+ from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
25
+ add_DriverServicer_to_server,
26
+ )
27
+ from flwr.server.superlink.state import StateFactory
28
+
29
+ from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
30
+ from .driver_servicer import DriverServicer
31
+
32
+
33
def run_driver_api_grpc(
    address: str,
    state_factory: StateFactory,
    certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> grpc.Server:
    """Run Driver API (gRPC, request-response).

    Parameters
    ----------
    address : str
        Host:port the gRPC server will bind to.
    state_factory : StateFactory
        Factory providing the State instance the servicer operates on.
    certificates : Optional[Tuple[bytes, bytes, bytes]]
        (CA cert, server cert, server key) for TLS, or None for insecure mode.

    Returns
    -------
    grpc.Server
        The started gRPC server (caller is responsible for shutdown).
    """
    # Create Driver API gRPC servicer.
    # NOTE: originally annotated as `grpc.Server`, which is wrong —
    # DriverServicer is a request handler, not a server.
    driver_servicer: DriverServicer = DriverServicer(
        state_factory=state_factory,
    )
    driver_add_servicer_to_server_fn = add_DriverServicer_to_server
    driver_grpc_server = generic_create_grpc_server(
        servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn),
        server_address=address,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        certificates=certificates,
    )

    log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
    driver_grpc_server.start()

    return driver_grpc_server
@@ -15,7 +15,7 @@
15
15
  """Driver API servicer."""
16
16
 
17
17
 
18
- from logging import INFO
18
+ from logging import DEBUG, INFO
19
19
  from typing import List, Optional, Set
20
20
  from uuid import UUID
21
21
 
@@ -70,7 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
70
70
  self, request: PushTaskInsRequest, context: grpc.ServicerContext
71
71
  ) -> PushTaskInsResponse:
72
72
  """Push a set of TaskIns."""
73
- log(INFO, "DriverServicer.PushTaskIns")
73
+ log(DEBUG, "DriverServicer.PushTaskIns")
74
74
 
75
75
  # Validate request
76
76
  _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
@@ -95,7 +95,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
95
95
  self, request: PullTaskResRequest, context: grpc.ServicerContext
96
96
  ) -> PullTaskResResponse:
97
97
  """Pull a set of TaskRes."""
98
- log(INFO, "DriverServicer.PullTaskRes")
98
+ log(DEBUG, "DriverServicer.PullTaskRes")
99
99
 
100
100
  # Convert each task_id str to UUID
101
101
  task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}
@@ -105,7 +105,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
105
105
 
106
106
  # Register callback
107
107
  def on_rpc_done() -> None:
108
- log(INFO, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
108
+ log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
109
109
 
110
110
  if context.is_active():
111
111
  return
@@ -46,6 +46,7 @@ class GrpcClientProxy(ClientProxy):
46
46
  self,
47
47
  ins: common.GetPropertiesIns,
48
48
  timeout: Optional[float],
49
+ group_id: Optional[int],
49
50
  ) -> common.GetPropertiesRes:
50
51
  """Request client's set of internal properties."""
51
52
  get_properties_msg = serde.get_properties_ins_to_proto(ins)
@@ -65,6 +66,7 @@ class GrpcClientProxy(ClientProxy):
65
66
  self,
66
67
  ins: common.GetParametersIns,
67
68
  timeout: Optional[float],
69
+ group_id: Optional[int],
68
70
  ) -> common.GetParametersRes:
69
71
  """Return the current local model parameters."""
70
72
  get_parameters_msg = serde.get_parameters_ins_to_proto(ins)
@@ -84,6 +86,7 @@ class GrpcClientProxy(ClientProxy):
84
86
  self,
85
87
  ins: common.FitIns,
86
88
  timeout: Optional[float],
89
+ group_id: Optional[int],
87
90
  ) -> common.FitRes:
88
91
  """Refine the provided parameters using the locally held dataset."""
89
92
  fit_ins_msg = serde.fit_ins_to_proto(ins)
@@ -102,6 +105,7 @@ class GrpcClientProxy(ClientProxy):
102
105
  self,
103
106
  ins: common.EvaluateIns,
104
107
  timeout: Optional[float],
108
+ group_id: Optional[int],
105
109
  ) -> common.EvaluateRes:
106
110
  """Evaluate the provided parameters using the locally held dataset."""
107
111
  evaluate_msg = serde.evaluate_ins_to_proto(ins)
@@ -119,6 +123,7 @@ class GrpcClientProxy(ClientProxy):
119
123
  self,
120
124
  ins: common.ReconnectIns,
121
125
  timeout: Optional[float],
126
+ group_id: Optional[int],
122
127
  ) -> common.DisconnectRes:
123
128
  """Disconnect and (optionally) reconnect later."""
124
129
  reconnect_ins_msg = serde.reconnect_ins_to_proto(ins)
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Fleet VirtualClientEngine side."""
15
+ """Fleet Simulation Engine side."""
16
16
 
17
17
  from .vce_api import start_vce
18
18
 
@@ -141,13 +141,13 @@ class RayBackend(Backend):
141
141
 
142
142
  Return output message and updated context.
143
143
  """
144
- node_id = message.metadata.dst_node_id
144
+ partition_id = message.metadata.partition_id
145
145
 
146
146
  try:
147
147
  # Submite a task to the pool
148
148
  future = await self.pool.submit(
149
149
  lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
150
- (app, message, str(node_id), context),
150
+ (app, message, str(partition_id), context),
151
151
  )
152
152
 
153
153
  await future
@@ -163,10 +163,9 @@ class RayBackend(Backend):
163
163
  except LoadClientAppError as load_ex:
164
164
  log(
165
165
  ERROR,
166
- "An exception was raised when processing a message. Terminating %s",
166
+ "An exception was raised when processing a message by %s",
167
167
  self.__class__.__name__,
168
168
  )
169
- await self.terminate()
170
169
  raise load_ex
171
170
 
172
171
  async def terminate(self) -> None:
@@ -12,19 +12,23 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Fleet VirtualClientEngine API."""
15
+ """Fleet Simulation Engine API."""
16
+
16
17
 
17
18
  import asyncio
18
19
  import json
19
- from logging import ERROR, INFO
20
- from typing import Dict, Optional
20
+ import traceback
21
+ from logging import DEBUG, ERROR, INFO, WARN
22
+ from typing import Callable, Dict, List, Optional
21
23
 
22
- from flwr.client.client_app import ClientApp, load_client_app
24
+ from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
23
25
  from flwr.client.node_state import NodeState
24
26
  from flwr.common.logger import log
27
+ from flwr.common.serde import message_from_taskins, message_to_taskres
28
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
25
29
  from flwr.server.superlink.state import StateFactory
26
30
 
27
- from .backend import error_messages_backends, supported_backends
31
+ from .backend import Backend, error_messages_backends, supported_backends
28
32
 
29
33
  NodeToPartitionMapping = Dict[int, int]
30
34
 
@@ -42,21 +46,223 @@ def _register_nodes(
42
46
  return nodes_mapping
43
47
 
44
48
 
45
- # pylint: disable=too-many-arguments,unused-argument
49
# pylint: disable=too-many-arguments,too-many-locals
async def worker(
    app_fn: Callable[[], ClientApp],
    queue: "asyncio.Queue[TaskIns]",
    node_states: Dict[int, NodeState],
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    backend: Backend,
) -> None:
    """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
    state = state_factory.state()
    while True:
        try:
            task_ins: TaskIns = await queue.get()
            node_id = task_ins.task.consumer.node_id

            # Ensure a Context exists for this run, then fetch it
            node_state = node_states[node_id]
            node_state.register_context(run_id=task_ins.run_id)
            context = node_state.retrieve_context(run_id=task_ins.run_id)

            # Rebuild the Message and tag it with the data partition
            # this virtual node is mapped to
            message = message_from_taskins(task_ins)
            message.metadata.partition_id = nodes_mapping[node_id]

            # Hand the message to the backend for execution
            out_mssg, updated_context = await backend.process_message(
                app_fn, message, context
            )

            # Persist the (possibly) updated Context
            node_state.update_context(task_ins.run_id, context=updated_context)

            # Serialize the reply and record it in state
            state.store_task_res(message_to_taskres(out_mssg))

        except asyncio.CancelledError as e:
            # Normal shutdown path: producer cancelled this worker
            log(DEBUG, "Async worker: %s", e)
            break

        except LoadClientAppError as app_ex:
            # ClientApp could not be loaded — fatal, propagate upwards
            log(ERROR, "Async worker: %s", app_ex)
            log(ERROR, traceback.format_exc())
            raise

        except Exception as ex:  # pylint: disable=broad-exception-caught
            # Any other failure stops this worker (but not the engine)
            log(ERROR, ex)
            log(ERROR, traceback.format_exc())
            break
102
+
103
+
104
async def add_taskins_to_queue(
    queue: "asyncio.Queue[TaskIns]",
    state_factory: StateFactory,
    nodes_mapping: NodeToPartitionMapping,
    backend: Backend,
    consumers: List["asyncio.Task[None]"],
    f_stop: asyncio.Event,
) -> None:
    """Retrieve TaskIns and add it to the queue."""
    state = state_factory.state()
    num_initial_consumers = len(consumers)
    while not f_stop.is_set():
        # Pull at most one pending TaskIns per registered node
        for node_id in nodes_mapping:
            pending = state.get_task_ins(node_id=node_id, limit=1)
            if pending:
                await queue.put(pending[0])

        # How many worker tasks are still running?
        num_active = len(consumers) - sum(cc.done() for cc in consumers)

        # Warn when the worker pool has shrunk below half its initial size
        if num_active < num_initial_consumers // 2:
            log(
                WARN,
                "Number of active workers has more than halved: (%i/%i active)",
                num_active,
                num_initial_consumers,
            )

        # No workers left — the simulation cannot make progress
        if num_active == 0:
            raise RuntimeError("All workers have died. Ending Simulation.")

        # Periodic engine statistics
        log(
            DEBUG,
            "Simulation Engine stats: "
            "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
            num_active,
            num_initial_consumers,
            backend.__class__.__name__,
            backend.num_workers,
            queue.qsize(),
        )
        await asyncio.sleep(1.0)
    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
150
+
151
+
152
async def run(
    app_fn: Callable[[], ClientApp],
    backend_fn: Callable[[], Backend],
    nodes_mapping: NodeToPartitionMapping,
    state_factory: StateFactory,
    node_states: Dict[int, NodeState],
    f_stop: asyncio.Event,
) -> None:
    """Run the VCE async.

    Builds the backend, spawns one worker per backend worker slot, and a
    producer that feeds TaskIns into the shared queue. Runs until `f_stop`
    is set or all workers die; always cancels workers and terminates the
    backend on the way out.
    """
    queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)

    # Initialize up front so the `finally` block can always reference
    # them — the original raised NameError/UnboundLocalError here if
    # `backend_fn()` or `backend.build()` failed, masking the real error.
    worker_tasks: List["asyncio.Task[None]"] = []
    backend: Optional[Backend] = None

    try:
        # Instantiate and build backend
        backend = backend_fn()
        await backend.build()

        # Add workers (they submit Messages to Backend)
        worker_tasks = [
            asyncio.create_task(
                worker(
                    app_fn, queue, node_states, state_factory, nodes_mapping, backend
                )
            )
            for _ in range(backend.num_workers)
        ]
        # Create producer (adds TaskIns into Queue)
        producer = asyncio.create_task(
            add_taskins_to_queue(
                queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
            )
        )

        # Wait for producer to finish
        # The producer runs forever until f_stop is set or until
        # all worker (consumer) coroutines are completed. Workers
        # also run forever and only end if an exception is raised.
        await asyncio.gather(producer)

    except Exception as ex:  # pylint: disable=broad-exception-caught

        log(ERROR, "An exception occurred!! %s", ex)
        log(ERROR, traceback.format_exc())
        log(WARN, "Stopping Simulation Engine.")

        # Manually trigger stopping event
        f_stop.set()

        # Raise exception
        raise RuntimeError("Simulation Engine crashed.") from ex

    finally:
        # Producer task terminated, now cancel worker tasks
        for w_t in worker_tasks:
            _ = w_t.cancel()

        while not all(w_t.done() for w_t in worker_tasks):
            log(DEBUG, "Terminating async workers...")
            await asyncio.sleep(0.5)

        # (The original gathered the not-yet-done workers here, but after
        # the loop above that list is always empty — removed as dead code.)

        # Terminate backend only if it was successfully created
        if backend is not None:
            await backend.terminate()
218
+
219
+
220
+ # pylint: disable=too-many-arguments,unused-argument,too-many-locals
46
221
  def start_vce(
47
- num_supernodes: int,
48
222
  client_app_module_name: str,
49
223
  backend_name: str,
50
224
  backend_config_json_stream: str,
51
- state_factory: StateFactory,
52
225
  working_dir: str,
53
- f_stop: Optional[asyncio.Event] = None,
226
+ f_stop: asyncio.Event,
227
+ num_supernodes: Optional[int] = None,
228
+ state_factory: Optional[StateFactory] = None,
229
+ existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
54
230
  ) -> None:
55
- """Start Fleet API with the VirtualClientEngine (VCE)."""
56
- # Register SuperNodes
57
- nodes_mapping = _register_nodes(
58
- num_nodes=num_supernodes, state_factory=state_factory
59
- )
231
+ """Start Fleet API with the Simulation Engine."""
232
+ if num_supernodes is not None and existing_nodes_mapping is not None:
233
+ raise ValueError(
234
+ "Both `num_supernodes` and `existing_nodes_mapping` are provided, "
235
+ "but only one is allowed."
236
+ )
237
+ if num_supernodes is None:
238
+ if state_factory is None or existing_nodes_mapping is None:
239
+ raise ValueError(
240
+ "If not passing an existing `state_factory` and associated "
241
+ "`existing_nodes_mapping` you must supply `num_supernodes` to indicate "
242
+ "how many nodes to insert into a new StateFactory that will be created."
243
+ )
244
+ if existing_nodes_mapping:
245
+ if state_factory is None:
246
+ raise ValueError(
247
+ "`existing_nodes_mapping` was passed, but no `state_factory` was "
248
+ "passed."
249
+ )
250
+ log(INFO, "Using exiting NodeToPartitionMapping and StateFactory.")
251
+ # Use mapping constructed externally. This also means nodes
252
+ # have previously being registered.
253
+ nodes_mapping = existing_nodes_mapping
254
+
255
+ if not state_factory:
256
+ log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
257
+ # Create an empty in-memory state factory
258
+ state_factory = StateFactory(":flwr-in-memory-state:")
259
+ log(INFO, "Created new %s.", state_factory.__class__.__name__)
260
+
261
+ if num_supernodes:
262
+ # Register SuperNodes
263
+ nodes_mapping = _register_nodes(
264
+ num_nodes=num_supernodes, state_factory=state_factory
265
+ )
60
266
 
61
267
  # Construct mapping of NodeStates
62
268
  node_states: Dict[int, NodeState] = {}
@@ -69,7 +275,6 @@ def start_vce(
69
275
 
70
276
  try:
71
277
  backend_type = supported_backends[backend_name]
72
- _ = backend_type(backend_config, work_dir=working_dir)
73
278
  except KeyError as ex:
74
279
  log(
75
280
  ERROR,
@@ -83,10 +288,25 @@ def start_vce(
83
288
 
84
289
  raise ex
85
290
 
291
+ def backend_fn() -> Backend:
292
+ """Instantiate a Backend."""
293
+ return backend_type(backend_config, work_dir=working_dir)
294
+
86
295
  log(INFO, "client_app_module_name = %s", client_app_module_name)
87
296
 
88
297
  def _load() -> ClientApp:
89
298
  app: ClientApp = load_client_app(client_app_module_name)
90
299
  return app
91
300
 
92
- # start backend
301
+ app_fn = _load
302
+
303
+ asyncio.run(
304
+ run(
305
+ app_fn,
306
+ backend_fn,
307
+ nodes_mapping,
308
+ state_factory,
309
+ node_states,
310
+ f_stop,
311
+ )
312
+ )
flwr/server/typing.py CHANGED
@@ -22,3 +22,4 @@ from flwr.common import Context
22
22
  from .driver import Driver
23
23
 
24
24
  ServerAppCallable = Callable[[Driver, Context], None]
25
+ Workflow = Callable[[Driver, Context], None]
@@ -0,0 +1,22 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Workflows."""
16
+
17
+
18
+ from .default_workflows import DefaultWorkflow
19
+
20
+ __all__ = [
21
+ "DefaultWorkflow",
22
+ ]