flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (42)
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
flwr/proto/task_pb2.pyi CHANGED
@@ -3,6 +3,7 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.error_pb2
6
7
  import flwr.proto.node_pb2
7
8
  import flwr.proto.recordset_pb2
8
9
  import google.protobuf.descriptor
@@ -23,6 +24,7 @@ class Task(google.protobuf.message.Message):
23
24
  ANCESTRY_FIELD_NUMBER: builtins.int
24
25
  TASK_TYPE_FIELD_NUMBER: builtins.int
25
26
  RECORDSET_FIELD_NUMBER: builtins.int
27
+ ERROR_FIELD_NUMBER: builtins.int
26
28
  @property
27
29
  def producer(self) -> flwr.proto.node_pb2.Node: ...
28
30
  @property
@@ -35,6 +37,8 @@ class Task(google.protobuf.message.Message):
35
37
  task_type: typing.Text
36
38
  @property
37
39
  def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
40
+ @property
41
+ def error(self) -> flwr.proto.error_pb2.Error: ...
38
42
  def __init__(self,
39
43
  *,
40
44
  producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
@@ -45,9 +49,10 @@ class Task(google.protobuf.message.Message):
45
49
  ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
46
50
  task_type: typing.Text = ...,
47
51
  recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
52
+ error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
48
53
  ) -> None: ...
49
- def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
50
- def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
54
+ def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
55
+ def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
51
56
  global___Task = Task
52
57
 
53
58
  class TaskIns(google.protobuf.message.Message):
flwr/server/__init__.py CHANGED
@@ -16,12 +16,14 @@
16
16
 
17
17
 
18
18
  from . import strategy
19
+ from . import workflow as workflow
19
20
  from .app import run_driver_api as run_driver_api
20
21
  from .app import run_fleet_api as run_fleet_api
21
22
  from .app import run_superlink as run_superlink
22
23
  from .app import start_server as start_server
23
24
  from .client_manager import ClientManager as ClientManager
24
25
  from .client_manager import SimpleClientManager as SimpleClientManager
26
+ from .compat import LegacyContext as LegacyContext
25
27
  from .compat import start_driver as start_driver
26
28
  from .driver import Driver as Driver
27
29
  from .history import History as History
@@ -34,6 +36,7 @@ __all__ = [
34
36
  "ClientManager",
35
37
  "Driver",
36
38
  "History",
39
+ "LegacyContext",
37
40
  "run_driver_api",
38
41
  "run_fleet_api",
39
42
  "run_server_app",
@@ -45,4 +48,5 @@ __all__ = [
45
48
  "start_driver",
46
49
  "start_server",
47
50
  "strategy",
51
+ "workflow",
48
52
  ]
flwr/server/app.py CHANGED
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  """Flower server app."""
16
16
 
17
-
18
17
  import argparse
18
+ import asyncio
19
19
  import importlib.util
20
20
  import sys
21
21
  import threading
@@ -36,9 +36,6 @@ from flwr.common.constant import (
36
36
  )
37
37
  from flwr.common.exit_handlers import register_exit_handlers
38
38
  from flwr.common.logger import log
39
- from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
40
- add_DriverServicer_to_server,
41
- )
42
39
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
43
40
  add_FleetServicer_to_server,
44
41
  )
@@ -48,7 +45,7 @@ from .history import History
48
45
  from .server import Server, init_defaults, run_fl
49
46
  from .server_config import ServerConfig
50
47
  from .strategy import Strategy
51
- from .superlink.driver.driver_servicer import DriverServicer
48
+ from .superlink.driver.driver_grpc import run_driver_api_grpc
52
49
  from .superlink.fleet.grpc_bidi.grpc_server import (
53
50
  generic_create_grpc_server,
54
51
  start_grpc_server,
@@ -204,7 +201,7 @@ def run_driver_api() -> None:
204
201
  state_factory = StateFactory(args.database)
205
202
 
206
203
  # Start server
207
- grpc_server: grpc.Server = _run_driver_api_grpc(
204
+ grpc_server: grpc.Server = run_driver_api_grpc(
208
205
  address=address,
209
206
  state_factory=state_factory,
210
207
  certificates=certificates,
@@ -313,7 +310,7 @@ def run_superlink() -> None:
313
310
  state_factory = StateFactory(args.database)
314
311
 
315
312
  # Start Driver API
316
- driver_server: grpc.Server = _run_driver_api_grpc(
313
+ driver_server: grpc.Server = run_driver_api_grpc(
317
314
  address=address,
318
315
  state_factory=state_factory,
319
316
  certificates=certificates,
@@ -362,6 +359,7 @@ def run_superlink() -> None:
362
359
  )
363
360
  grpc_servers.append(fleet_server)
364
361
  elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
362
+ f_stop = asyncio.Event() # Does nothing
365
363
  _run_fleet_api_vce(
366
364
  num_supernodes=args.num_supernodes,
367
365
  client_app_module_name=args.client_app,
@@ -369,6 +367,7 @@ def run_superlink() -> None:
369
367
  backend_config_json_stream=args.backend_config,
370
368
  working_dir=args.dir,
371
369
  state_factory=state_factory,
370
+ f_stop=f_stop,
372
371
  )
373
372
  else:
374
373
  raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
@@ -412,30 +411,6 @@ def _try_obtain_certificates(
412
411
  return certificates
413
412
 
414
413
 
415
- def _run_driver_api_grpc(
416
- address: str,
417
- state_factory: StateFactory,
418
- certificates: Optional[Tuple[bytes, bytes, bytes]],
419
- ) -> grpc.Server:
420
- """Run Driver API (gRPC, request-response)."""
421
- # Create Driver API gRPC server
422
- driver_servicer: grpc.Server = DriverServicer(
423
- state_factory=state_factory,
424
- )
425
- driver_add_servicer_to_server_fn = add_DriverServicer_to_server
426
- driver_grpc_server = generic_create_grpc_server(
427
- servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn),
428
- server_address=address,
429
- max_message_length=GRPC_MAX_MESSAGE_LENGTH,
430
- certificates=certificates,
431
- )
432
-
433
- log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
434
- driver_grpc_server.start()
435
-
436
- return driver_grpc_server
437
-
438
-
439
414
  def _run_fleet_api_grpc_rere(
440
415
  address: str,
441
416
  state_factory: StateFactory,
@@ -468,6 +443,7 @@ def _run_fleet_api_vce(
468
443
  backend_config_json_stream: str,
469
444
  working_dir: str,
470
445
  state_factory: StateFactory,
446
+ f_stop: asyncio.Event,
471
447
  ) -> None:
472
448
  log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)")
473
449
 
@@ -478,6 +454,7 @@ def _run_fleet_api_vce(
478
454
  backend_config_json_stream=backend_config_json_stream,
479
455
  state_factory=state_factory,
480
456
  working_dir=working_dir,
457
+ f_stop=f_stop,
481
458
  )
482
459
 
483
460
 
flwr/server/client_proxy.py CHANGED
@@ -47,6 +47,7 @@ class ClientProxy(ABC):
47
47
  self,
48
48
  ins: GetPropertiesIns,
49
49
  timeout: Optional[float],
50
+ group_id: Optional[int],
50
51
  ) -> GetPropertiesRes:
51
52
  """Return the client's properties."""
52
53
 
@@ -55,6 +56,7 @@ class ClientProxy(ABC):
55
56
  self,
56
57
  ins: GetParametersIns,
57
58
  timeout: Optional[float],
59
+ group_id: Optional[int],
58
60
  ) -> GetParametersRes:
59
61
  """Return the current local model parameters."""
60
62
 
@@ -63,6 +65,7 @@ class ClientProxy(ABC):
63
65
  self,
64
66
  ins: FitIns,
65
67
  timeout: Optional[float],
68
+ group_id: Optional[int],
66
69
  ) -> FitRes:
67
70
  """Refine the provided parameters using the locally held dataset."""
68
71
 
@@ -71,6 +74,7 @@ class ClientProxy(ABC):
71
74
  self,
72
75
  ins: EvaluateIns,
73
76
  timeout: Optional[float],
77
+ group_id: Optional[int],
74
78
  ) -> EvaluateRes:
75
79
  """Evaluate the provided parameters using the locally held dataset."""
76
80
 
@@ -79,5 +83,6 @@ class ClientProxy(ABC):
79
83
  self,
80
84
  ins: ReconnectIns,
81
85
  timeout: Optional[float],
86
+ group_id: Optional[int],
82
87
  ) -> DisconnectRes:
83
88
  """Disconnect and (optionally) reconnect later."""
flwr/server/compat/__init__.py CHANGED
@@ -16,7 +16,9 @@
16
16
 
17
17
 
18
18
  from .app import start_driver as start_driver
19
+ from .legacy_context import LegacyContext as LegacyContext
19
20
 
20
21
  __all__ = [
22
+ "LegacyContext",
21
23
  "start_driver",
22
24
  ]
flwr/server/compat/app.py CHANGED
@@ -16,16 +16,13 @@
16
16
 
17
17
 
18
18
  import sys
19
- import threading
20
- import time
21
19
  from logging import INFO
22
20
  from pathlib import Path
23
- from typing import Dict, Optional, Union
21
+ from typing import Optional, Union
24
22
 
25
23
  from flwr.common import EventType, event
26
24
  from flwr.common.address import parse_address
27
25
  from flwr.common.logger import log, warn_deprecated_feature
28
- from flwr.proto import driver_pb2 # pylint: disable=E0611
29
26
  from flwr.server.client_manager import ClientManager
30
27
  from flwr.server.history import History
31
28
  from flwr.server.server import Server, init_defaults, run_fl
@@ -33,8 +30,7 @@ from flwr.server.server_config import ServerConfig
33
30
  from flwr.server.strategy import Strategy
34
31
 
35
32
  from ..driver import Driver
36
- from ..driver.grpc_driver import GrpcDriver
37
- from .driver_client_proxy import DriverClientProxy
33
+ from .app_utils import start_update_client_manager_thread
38
34
 
39
35
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
40
36
 
@@ -104,11 +100,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
104
100
  """
105
101
  event(EventType.START_DRIVER_ENTER)
106
102
 
107
- if driver:
108
- # pylint: disable=protected-access
109
- grpc_driver, _ = driver._get_grpc_driver_and_run_id()
110
- # pylint: enable=protected-access
111
- else:
103
+ if driver is None:
112
104
  # Not passing a `Driver` object is deprecated
113
105
  warn_deprecated_feature("start_driver")
114
106
 
@@ -122,12 +114,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
122
114
  # Create the Driver
123
115
  if isinstance(root_certificates, str):
124
116
  root_certificates = Path(root_certificates).read_bytes()
125
- grpc_driver = GrpcDriver(
117
+ driver = Driver(
126
118
  driver_service_address=address, root_certificates=root_certificates
127
119
  )
128
- grpc_driver.connect()
129
-
130
- lock = threading.Lock()
131
120
 
132
121
  # Initialize the Driver API server and config
133
122
  initialized_server, initialized_config = init_defaults(
@@ -142,18 +131,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
142
131
  initialized_config,
143
132
  )
144
133
 
145
- f_stop = threading.Event()
146
134
  # Start the thread updating nodes
147
- thread = threading.Thread(
148
- target=update_client_manager,
149
- args=(
150
- grpc_driver,
151
- initialized_server.client_manager(),
152
- lock,
153
- f_stop,
154
- ),
135
+ thread, f_stop = start_update_client_manager_thread(
136
+ driver, initialized_server.client_manager()
155
137
  )
156
- thread.start()
157
138
 
158
139
  # Start training
159
140
  hist = run_fl(
@@ -164,72 +145,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
164
145
  f_stop.set()
165
146
 
166
147
  # Stop the Driver API server and the thread
167
- with lock:
168
- if driver:
169
- del driver
170
- else:
171
- grpc_driver.disconnect()
148
+ del driver
172
149
 
173
150
  thread.join()
174
151
 
175
152
  event(EventType.START_SERVER_LEAVE)
176
153
 
177
154
  return hist
178
-
179
-
180
- def update_client_manager(
181
- driver: GrpcDriver,
182
- client_manager: ClientManager,
183
- lock: threading.Lock,
184
- f_stop: threading.Event,
185
- ) -> None:
186
- """Update the nodes list in the client manager.
187
-
188
- This function periodically communicates with the associated driver to get all
189
- node_ids. Each node_id is then converted into a `DriverClientProxy` instance
190
- and stored in the `registered_nodes` dictionary with node_id as key.
191
-
192
- New nodes will be added to the ClientManager via `client_manager.register()`,
193
- and dead nodes will be removed from the ClientManager via
194
- `client_manager.unregister()`.
195
- """
196
- # Request for run_id
197
- run_id = driver.create_run(
198
- driver_pb2.CreateRunRequest() # pylint: disable=E1101
199
- ).run_id
200
-
201
- # Loop until the driver is disconnected
202
- registered_nodes: Dict[int, DriverClientProxy] = {}
203
- while not f_stop.is_set():
204
- with lock:
205
- # End the while loop if the driver is disconnected
206
- if driver.stub is None:
207
- break
208
- get_nodes_res = driver.get_nodes(
209
- req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
210
- )
211
- all_node_ids = {node.node_id for node in get_nodes_res.nodes}
212
- dead_nodes = set(registered_nodes).difference(all_node_ids)
213
- new_nodes = all_node_ids.difference(registered_nodes)
214
-
215
- # Unregister dead nodes
216
- for node_id in dead_nodes:
217
- client_proxy = registered_nodes[node_id]
218
- client_manager.unregister(client_proxy)
219
- del registered_nodes[node_id]
220
-
221
- # Register new nodes
222
- for node_id in new_nodes:
223
- client_proxy = DriverClientProxy(
224
- node_id=node_id,
225
- driver=driver,
226
- anonymous=False,
227
- run_id=run_id,
228
- )
229
- if client_manager.register(client_proxy):
230
- registered_nodes[node_id] = client_proxy
231
- else:
232
- raise RuntimeError("Could not register node.")
233
-
234
- # Sleep for 3 seconds
235
- time.sleep(3)
flwr/server/compat/app_utils.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility functions for the `start_driver`."""
16
+
17
+
18
+ import threading
19
+ import time
20
+ from typing import Dict, Tuple
21
+
22
+ from ..client_manager import ClientManager
23
+ from ..compat.driver_client_proxy import DriverClientProxy
24
+ from ..driver import Driver
25
+
26
+
27
+ def start_update_client_manager_thread(
28
+ driver: Driver,
29
+ client_manager: ClientManager,
30
+ ) -> Tuple[threading.Thread, threading.Event]:
31
+ """Periodically update the nodes list in the client manager in a thread.
32
+
33
+ This function starts a thread that periodically uses the associated driver to
34
+ get all node_ids. Each node_id is then converted into a `DriverClientProxy`
35
+ instance and stored in the `registered_nodes` dictionary with node_id as key.
36
+
37
+ New nodes will be added to the ClientManager via `client_manager.register()`,
38
+ and dead nodes will be removed from the ClientManager via
39
+ `client_manager.unregister()`.
40
+
41
+ Parameters
42
+ ----------
43
+ driver : Driver
44
+ The Driver object to use.
45
+ client_manager : ClientManager
46
+ The ClientManager object to be updated.
47
+
48
+ Returns
49
+ -------
50
+ threading.Thread
51
+ A thread that updates the ClientManager and handles the stop event.
52
+ threading.Event
53
+ An event that, when set, signals the thread to stop.
54
+ """
55
+ f_stop = threading.Event()
56
+ thread = threading.Thread(
57
+ target=_update_client_manager,
58
+ args=(
59
+ driver,
60
+ client_manager,
61
+ f_stop,
62
+ ),
63
+ )
64
+ thread.start()
65
+
66
+ return thread, f_stop
67
+
68
+
69
+ def _update_client_manager(
70
+ driver: Driver,
71
+ client_manager: ClientManager,
72
+ f_stop: threading.Event,
73
+ ) -> None:
74
+ """Update the nodes list in the client manager."""
75
+ # Loop until the driver is disconnected
76
+ registered_nodes: Dict[int, DriverClientProxy] = {}
77
+ while not f_stop.is_set():
78
+ all_node_ids = set(driver.get_node_ids())
79
+ dead_nodes = set(registered_nodes).difference(all_node_ids)
80
+ new_nodes = all_node_ids.difference(registered_nodes)
81
+
82
+ # Unregister dead nodes
83
+ for node_id in dead_nodes:
84
+ client_proxy = registered_nodes[node_id]
85
+ client_manager.unregister(client_proxy)
86
+ del registered_nodes[node_id]
87
+
88
+ # Register new nodes
89
+ for node_id in new_nodes:
90
+ client_proxy = DriverClientProxy(
91
+ node_id=node_id,
92
+ driver=driver.grpc_driver, # type: ignore
93
+ anonymous=False,
94
+ run_id=driver.run_id, # type: ignore
95
+ )
96
+ if client_manager.register(client_proxy):
97
+ registered_nodes[node_id] = client_proxy
98
+ else:
99
+ raise RuntimeError("Could not register node.")
100
+
101
+ # Sleep for 3 seconds
102
+ time.sleep(3)
flwr/server/compat/driver_client_proxy.py CHANGED
@@ -47,57 +47,68 @@ class DriverClientProxy(ClientProxy):
47
47
  self.anonymous = anonymous
48
48
 
49
49
  def get_properties(
50
- self, ins: common.GetPropertiesIns, timeout: Optional[float]
50
+ self,
51
+ ins: common.GetPropertiesIns,
52
+ timeout: Optional[float],
53
+ group_id: Optional[int],
51
54
  ) -> common.GetPropertiesRes:
52
55
  """Return client's properties."""
53
56
  # Ins to RecordSet
54
57
  out_recordset = compat.getpropertiesins_to_recordset(ins)
55
58
  # Fetch response
56
59
  in_recordset = self._send_receive_recordset(
57
- out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout
60
+ out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout, group_id
58
61
  )
59
62
  # RecordSet to Res
60
63
  return compat.recordset_to_getpropertiesres(in_recordset)
61
64
 
62
65
  def get_parameters(
63
- self, ins: common.GetParametersIns, timeout: Optional[float]
66
+ self,
67
+ ins: common.GetParametersIns,
68
+ timeout: Optional[float],
69
+ group_id: Optional[int],
64
70
  ) -> common.GetParametersRes:
65
71
  """Return the current local model parameters."""
66
72
  # Ins to RecordSet
67
73
  out_recordset = compat.getparametersins_to_recordset(ins)
68
74
  # Fetch response
69
75
  in_recordset = self._send_receive_recordset(
70
- out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout
76
+ out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout, group_id
71
77
  )
72
78
  # RecordSet to Res
73
79
  return compat.recordset_to_getparametersres(in_recordset, False)
74
80
 
75
- def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
81
+ def fit(
82
+ self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
83
+ ) -> common.FitRes:
76
84
  """Train model parameters on the locally held dataset."""
77
85
  # Ins to RecordSet
78
86
  out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
79
87
  # Fetch response
80
88
  in_recordset = self._send_receive_recordset(
81
- out_recordset, MESSAGE_TYPE_FIT, timeout
89
+ out_recordset, MESSAGE_TYPE_FIT, timeout, group_id
82
90
  )
83
91
  # RecordSet to Res
84
92
  return compat.recordset_to_fitres(in_recordset, keep_input=False)
85
93
 
86
94
  def evaluate(
87
- self, ins: common.EvaluateIns, timeout: Optional[float]
95
+ self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
88
96
  ) -> common.EvaluateRes:
89
97
  """Evaluate model parameters on the locally held dataset."""
90
98
  # Ins to RecordSet
91
99
  out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
92
100
  # Fetch response
93
101
  in_recordset = self._send_receive_recordset(
94
- out_recordset, MESSAGE_TYPE_EVALUATE, timeout
102
+ out_recordset, MESSAGE_TYPE_EVALUATE, timeout, group_id
95
103
  )
96
104
  # RecordSet to Res
97
105
  return compat.recordset_to_evaluateres(in_recordset)
98
106
 
99
107
  def reconnect(
100
- self, ins: common.ReconnectIns, timeout: Optional[float]
108
+ self,
109
+ ins: common.ReconnectIns,
110
+ timeout: Optional[float],
111
+ group_id: Optional[int],
101
112
  ) -> common.DisconnectRes:
102
113
  """Disconnect and (optionally) reconnect later."""
103
114
  return common.DisconnectRes(reason="") # Nothing to do here (yet)
@@ -107,10 +118,11 @@ class DriverClientProxy(ClientProxy):
107
118
  recordset: RecordSet,
108
119
  task_type: str,
109
120
  timeout: Optional[float],
121
+ group_id: Optional[int],
110
122
  ) -> RecordSet:
111
123
  task_ins = task_pb2.TaskIns( # pylint: disable=E1101
112
124
  task_id="",
113
- group_id="",
125
+ group_id=str(group_id) if group_id is not None else "",
114
126
  run_id=self.run_id,
115
127
  task=task_pb2.Task( # pylint: disable=E1101
116
128
  producer=node_pb2.Node( # pylint: disable=E1101
flwr/server/compat/legacy_context.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Legacy Context."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ from flwr.common import Context, RecordSet
22
+
23
+ from ..client_manager import ClientManager, SimpleClientManager
24
+ from ..history import History
25
+ from ..server_config import ServerConfig
26
+ from ..strategy import FedAvg, Strategy
27
+
28
+
29
+ @dataclass
30
+ class LegacyContext(Context):
31
+ """Legacy Context."""
32
+
33
+ config: ServerConfig
34
+ strategy: Strategy
35
+ client_manager: ClientManager
36
+ history: History
37
+
38
+ def __init__(
39
+ self,
40
+ state: RecordSet,
41
+ config: Optional[ServerConfig] = None,
42
+ strategy: Optional[Strategy] = None,
43
+ client_manager: Optional[ClientManager] = None,
44
+ ) -> None:
45
+ if config is None:
46
+ config = ServerConfig()
47
+ if strategy is None:
48
+ strategy = FedAvg()
49
+ if client_manager is None:
50
+ client_manager = SimpleClientManager()
51
+ self.config = config
52
+ self.strategy = strategy
53
+ self.client_manager = client_manager
54
+ self.history = History()
55
+ super().__init__(state)
flwr/server/run_serverapp.py CHANGED
@@ -107,7 +107,7 @@ def run_server_app() -> None:
107
107
  run(server_app_attr, driver, server_app_dir)
108
108
 
109
109
  # Clean up
110
- del driver
110
+ driver.__del__() # pylint: disable=unnecessary-dunder-call
111
111
 
112
112
  event(EventType.RUN_SERVER_APP_LEAVE)
113
113