flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42):
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
flwr/proto/task_pb2.pyi CHANGED
@@ -3,6 +3,7 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.error_pb2
6
7
  import flwr.proto.node_pb2
7
8
  import flwr.proto.recordset_pb2
8
9
  import google.protobuf.descriptor
@@ -23,6 +24,7 @@ class Task(google.protobuf.message.Message):
23
24
  ANCESTRY_FIELD_NUMBER: builtins.int
24
25
  TASK_TYPE_FIELD_NUMBER: builtins.int
25
26
  RECORDSET_FIELD_NUMBER: builtins.int
27
+ ERROR_FIELD_NUMBER: builtins.int
26
28
  @property
27
29
  def producer(self) -> flwr.proto.node_pb2.Node: ...
28
30
  @property
@@ -35,6 +37,8 @@ class Task(google.protobuf.message.Message):
35
37
  task_type: typing.Text
36
38
  @property
37
39
  def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
40
+ @property
41
+ def error(self) -> flwr.proto.error_pb2.Error: ...
38
42
  def __init__(self,
39
43
  *,
40
44
  producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
@@ -45,9 +49,10 @@ class Task(google.protobuf.message.Message):
45
49
  ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
46
50
  task_type: typing.Text = ...,
47
51
  recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
52
+ error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
48
53
  ) -> None: ...
49
- def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
50
- def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
54
+ def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
55
+ def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
51
56
  global___Task = Task
52
57
 
53
58
  class TaskIns(google.protobuf.message.Message):
flwr/server/__init__.py CHANGED
@@ -16,12 +16,14 @@
16
16
 
17
17
 
18
18
  from . import strategy
19
+ from . import workflow as workflow
19
20
  from .app import run_driver_api as run_driver_api
20
21
  from .app import run_fleet_api as run_fleet_api
21
22
  from .app import run_superlink as run_superlink
22
23
  from .app import start_server as start_server
23
24
  from .client_manager import ClientManager as ClientManager
24
25
  from .client_manager import SimpleClientManager as SimpleClientManager
26
+ from .compat import LegacyContext as LegacyContext
25
27
  from .compat import start_driver as start_driver
26
28
  from .driver import Driver as Driver
27
29
  from .history import History as History
@@ -34,6 +36,7 @@ __all__ = [
34
36
  "ClientManager",
35
37
  "Driver",
36
38
  "History",
39
+ "LegacyContext",
37
40
  "run_driver_api",
38
41
  "run_fleet_api",
39
42
  "run_server_app",
@@ -45,4 +48,5 @@ __all__ = [
45
48
  "start_driver",
46
49
  "start_server",
47
50
  "strategy",
51
+ "workflow",
48
52
  ]
flwr/server/app.py CHANGED
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  """Flower server app."""
16
16
 
17
-
18
17
  import argparse
18
+ import asyncio
19
19
  import importlib.util
20
20
  import sys
21
21
  import threading
@@ -36,9 +36,6 @@ from flwr.common.constant import (
36
36
  )
37
37
  from flwr.common.exit_handlers import register_exit_handlers
38
38
  from flwr.common.logger import log
39
- from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
40
- add_DriverServicer_to_server,
41
- )
42
39
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
43
40
  add_FleetServicer_to_server,
44
41
  )
@@ -48,7 +45,7 @@ from .history import History
48
45
  from .server import Server, init_defaults, run_fl
49
46
  from .server_config import ServerConfig
50
47
  from .strategy import Strategy
51
- from .superlink.driver.driver_servicer import DriverServicer
48
+ from .superlink.driver.driver_grpc import run_driver_api_grpc
52
49
  from .superlink.fleet.grpc_bidi.grpc_server import (
53
50
  generic_create_grpc_server,
54
51
  start_grpc_server,
@@ -204,7 +201,7 @@ def run_driver_api() -> None:
204
201
  state_factory = StateFactory(args.database)
205
202
 
206
203
  # Start server
207
- grpc_server: grpc.Server = _run_driver_api_grpc(
204
+ grpc_server: grpc.Server = run_driver_api_grpc(
208
205
  address=address,
209
206
  state_factory=state_factory,
210
207
  certificates=certificates,
@@ -313,7 +310,7 @@ def run_superlink() -> None:
313
310
  state_factory = StateFactory(args.database)
314
311
 
315
312
  # Start Driver API
316
- driver_server: grpc.Server = _run_driver_api_grpc(
313
+ driver_server: grpc.Server = run_driver_api_grpc(
317
314
  address=address,
318
315
  state_factory=state_factory,
319
316
  certificates=certificates,
@@ -362,6 +359,7 @@ def run_superlink() -> None:
362
359
  )
363
360
  grpc_servers.append(fleet_server)
364
361
  elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
362
+ f_stop = asyncio.Event() # Does nothing
365
363
  _run_fleet_api_vce(
366
364
  num_supernodes=args.num_supernodes,
367
365
  client_app_module_name=args.client_app,
@@ -369,6 +367,7 @@ def run_superlink() -> None:
369
367
  backend_config_json_stream=args.backend_config,
370
368
  working_dir=args.dir,
371
369
  state_factory=state_factory,
370
+ f_stop=f_stop,
372
371
  )
373
372
  else:
374
373
  raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
@@ -412,30 +411,6 @@ def _try_obtain_certificates(
412
411
  return certificates
413
412
 
414
413
 
415
- def _run_driver_api_grpc(
416
- address: str,
417
- state_factory: StateFactory,
418
- certificates: Optional[Tuple[bytes, bytes, bytes]],
419
- ) -> grpc.Server:
420
- """Run Driver API (gRPC, request-response)."""
421
- # Create Driver API gRPC server
422
- driver_servicer: grpc.Server = DriverServicer(
423
- state_factory=state_factory,
424
- )
425
- driver_add_servicer_to_server_fn = add_DriverServicer_to_server
426
- driver_grpc_server = generic_create_grpc_server(
427
- servicer_and_add_fn=(driver_servicer, driver_add_servicer_to_server_fn),
428
- server_address=address,
429
- max_message_length=GRPC_MAX_MESSAGE_LENGTH,
430
- certificates=certificates,
431
- )
432
-
433
- log(INFO, "Flower ECE: Starting Driver API (gRPC-rere) on %s", address)
434
- driver_grpc_server.start()
435
-
436
- return driver_grpc_server
437
-
438
-
439
414
  def _run_fleet_api_grpc_rere(
440
415
  address: str,
441
416
  state_factory: StateFactory,
@@ -468,6 +443,7 @@ def _run_fleet_api_vce(
468
443
  backend_config_json_stream: str,
469
444
  working_dir: str,
470
445
  state_factory: StateFactory,
446
+ f_stop: asyncio.Event,
471
447
  ) -> None:
472
448
  log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)")
473
449
 
@@ -478,6 +454,7 @@ def _run_fleet_api_vce(
478
454
  backend_config_json_stream=backend_config_json_stream,
479
455
  state_factory=state_factory,
480
456
  working_dir=working_dir,
457
+ f_stop=f_stop,
481
458
  )
482
459
 
483
460
 
@@ -47,6 +47,7 @@ class ClientProxy(ABC):
47
47
  self,
48
48
  ins: GetPropertiesIns,
49
49
  timeout: Optional[float],
50
+ group_id: Optional[int],
50
51
  ) -> GetPropertiesRes:
51
52
  """Return the client's properties."""
52
53
 
@@ -55,6 +56,7 @@ class ClientProxy(ABC):
55
56
  self,
56
57
  ins: GetParametersIns,
57
58
  timeout: Optional[float],
59
+ group_id: Optional[int],
58
60
  ) -> GetParametersRes:
59
61
  """Return the current local model parameters."""
60
62
 
@@ -63,6 +65,7 @@ class ClientProxy(ABC):
63
65
  self,
64
66
  ins: FitIns,
65
67
  timeout: Optional[float],
68
+ group_id: Optional[int],
66
69
  ) -> FitRes:
67
70
  """Refine the provided parameters using the locally held dataset."""
68
71
 
@@ -71,6 +74,7 @@ class ClientProxy(ABC):
71
74
  self,
72
75
  ins: EvaluateIns,
73
76
  timeout: Optional[float],
77
+ group_id: Optional[int],
74
78
  ) -> EvaluateRes:
75
79
  """Evaluate the provided parameters using the locally held dataset."""
76
80
 
@@ -79,5 +83,6 @@ class ClientProxy(ABC):
79
83
  self,
80
84
  ins: ReconnectIns,
81
85
  timeout: Optional[float],
86
+ group_id: Optional[int],
82
87
  ) -> DisconnectRes:
83
88
  """Disconnect and (optionally) reconnect later."""
@@ -16,7 +16,9 @@
16
16
 
17
17
 
18
18
  from .app import start_driver as start_driver
19
+ from .legacy_context import LegacyContext as LegacyContext
19
20
 
20
21
  __all__ = [
22
+ "LegacyContext",
21
23
  "start_driver",
22
24
  ]
flwr/server/compat/app.py CHANGED
@@ -16,16 +16,13 @@
16
16
 
17
17
 
18
18
  import sys
19
- import threading
20
- import time
21
19
  from logging import INFO
22
20
  from pathlib import Path
23
- from typing import Dict, Optional, Union
21
+ from typing import Optional, Union
24
22
 
25
23
  from flwr.common import EventType, event
26
24
  from flwr.common.address import parse_address
27
25
  from flwr.common.logger import log, warn_deprecated_feature
28
- from flwr.proto import driver_pb2 # pylint: disable=E0611
29
26
  from flwr.server.client_manager import ClientManager
30
27
  from flwr.server.history import History
31
28
  from flwr.server.server import Server, init_defaults, run_fl
@@ -33,8 +30,7 @@ from flwr.server.server_config import ServerConfig
33
30
  from flwr.server.strategy import Strategy
34
31
 
35
32
  from ..driver import Driver
36
- from ..driver.grpc_driver import GrpcDriver
37
- from .driver_client_proxy import DriverClientProxy
33
+ from .app_utils import start_update_client_manager_thread
38
34
 
39
35
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
40
36
 
@@ -104,11 +100,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
104
100
  """
105
101
  event(EventType.START_DRIVER_ENTER)
106
102
 
107
- if driver:
108
- # pylint: disable=protected-access
109
- grpc_driver, _ = driver._get_grpc_driver_and_run_id()
110
- # pylint: enable=protected-access
111
- else:
103
+ if driver is None:
112
104
  # Not passing a `Driver` object is deprecated
113
105
  warn_deprecated_feature("start_driver")
114
106
 
@@ -122,12 +114,9 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
122
114
  # Create the Driver
123
115
  if isinstance(root_certificates, str):
124
116
  root_certificates = Path(root_certificates).read_bytes()
125
- grpc_driver = GrpcDriver(
117
+ driver = Driver(
126
118
  driver_service_address=address, root_certificates=root_certificates
127
119
  )
128
- grpc_driver.connect()
129
-
130
- lock = threading.Lock()
131
120
 
132
121
  # Initialize the Driver API server and config
133
122
  initialized_server, initialized_config = init_defaults(
@@ -142,18 +131,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
142
131
  initialized_config,
143
132
  )
144
133
 
145
- f_stop = threading.Event()
146
134
  # Start the thread updating nodes
147
- thread = threading.Thread(
148
- target=update_client_manager,
149
- args=(
150
- grpc_driver,
151
- initialized_server.client_manager(),
152
- lock,
153
- f_stop,
154
- ),
135
+ thread, f_stop = start_update_client_manager_thread(
136
+ driver, initialized_server.client_manager()
155
137
  )
156
- thread.start()
157
138
 
158
139
  # Start training
159
140
  hist = run_fl(
@@ -164,72 +145,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
164
145
  f_stop.set()
165
146
 
166
147
  # Stop the Driver API server and the thread
167
- with lock:
168
- if driver:
169
- del driver
170
- else:
171
- grpc_driver.disconnect()
148
+ del driver
172
149
 
173
150
  thread.join()
174
151
 
175
152
  event(EventType.START_SERVER_LEAVE)
176
153
 
177
154
  return hist
178
-
179
-
180
- def update_client_manager(
181
- driver: GrpcDriver,
182
- client_manager: ClientManager,
183
- lock: threading.Lock,
184
- f_stop: threading.Event,
185
- ) -> None:
186
- """Update the nodes list in the client manager.
187
-
188
- This function periodically communicates with the associated driver to get all
189
- node_ids. Each node_id is then converted into a `DriverClientProxy` instance
190
- and stored in the `registered_nodes` dictionary with node_id as key.
191
-
192
- New nodes will be added to the ClientManager via `client_manager.register()`,
193
- and dead nodes will be removed from the ClientManager via
194
- `client_manager.unregister()`.
195
- """
196
- # Request for run_id
197
- run_id = driver.create_run(
198
- driver_pb2.CreateRunRequest() # pylint: disable=E1101
199
- ).run_id
200
-
201
- # Loop until the driver is disconnected
202
- registered_nodes: Dict[int, DriverClientProxy] = {}
203
- while not f_stop.is_set():
204
- with lock:
205
- # End the while loop if the driver is disconnected
206
- if driver.stub is None:
207
- break
208
- get_nodes_res = driver.get_nodes(
209
- req=driver_pb2.GetNodesRequest(run_id=run_id) # pylint: disable=E1101
210
- )
211
- all_node_ids = {node.node_id for node in get_nodes_res.nodes}
212
- dead_nodes = set(registered_nodes).difference(all_node_ids)
213
- new_nodes = all_node_ids.difference(registered_nodes)
214
-
215
- # Unregister dead nodes
216
- for node_id in dead_nodes:
217
- client_proxy = registered_nodes[node_id]
218
- client_manager.unregister(client_proxy)
219
- del registered_nodes[node_id]
220
-
221
- # Register new nodes
222
- for node_id in new_nodes:
223
- client_proxy = DriverClientProxy(
224
- node_id=node_id,
225
- driver=driver,
226
- anonymous=False,
227
- run_id=run_id,
228
- )
229
- if client_manager.register(client_proxy):
230
- registered_nodes[node_id] = client_proxy
231
- else:
232
- raise RuntimeError("Could not register node.")
233
-
234
- # Sleep for 3 seconds
235
- time.sleep(3)
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility functions for the `start_driver`."""
16
+
17
+
18
+ import threading
19
+ import time
20
+ from typing import Dict, Tuple
21
+
22
+ from ..client_manager import ClientManager
23
+ from ..compat.driver_client_proxy import DriverClientProxy
24
+ from ..driver import Driver
25
+
26
+
27
def start_update_client_manager_thread(
    driver: Driver,
    client_manager: ClientManager,
) -> Tuple[threading.Thread, threading.Event]:
    """Launch a background thread that keeps a ClientManager in sync.

    The spawned thread repeatedly asks ``driver`` for the current node IDs.
    Each node ID it has not seen before is wrapped in a ``DriverClientProxy``
    and added through ``client_manager.register()``; each previously
    registered node that is no longer reported is removed through
    ``client_manager.unregister()``.

    Parameters
    ----------
    driver : Driver
        The Driver used to query the available node IDs.
    client_manager : ClientManager
        The ClientManager that is kept up to date.

    Returns
    -------
    threading.Thread
        The worker thread, already started.
    threading.Event
        The stop signal. Setting it makes the worker leave its polling loop;
        callers should then ``join()`` the returned thread.
    """
    stop_event = threading.Event()
    worker = threading.Thread(
        target=_update_client_manager,
        args=(driver, client_manager, stop_event),
    )
    worker.start()
    return worker, stop_event
67
+
68
+
69
def _update_client_manager(
    driver: Driver,
    client_manager: ClientManager,
    f_stop: threading.Event,
) -> None:
    """Update the nodes list in the client manager until `f_stop` is set.

    On every iteration the set of node IDs reported by ``driver.get_node_ids()``
    is compared against the nodes registered so far:

    - nodes that disappeared are removed via ``client_manager.unregister()``;
    - newly reported nodes are wrapped in a ``DriverClientProxy`` and added via
      ``client_manager.register()``.

    Between iterations the function waits up to 3 seconds, but returns
    promptly once ``f_stop`` is set.

    Parameters
    ----------
    driver : Driver
        The Driver used to query node IDs. Its ``grpc_driver`` and ``run_id``
        attributes are handed to every created ``DriverClientProxy``.
    client_manager : ClientManager
        The ClientManager to keep in sync.
    f_stop : threading.Event
        Stop signal; the loop exits once it is set.

    Raises
    ------
    RuntimeError
        If ``client_manager.register()`` rejects a newly created proxy. When
        this function runs as a thread target, the exception ends only that
        worker thread.
    """
    # Proxies created by this function, keyed by node_id
    registered_nodes: Dict[int, DriverClientProxy] = {}
    # Loop until the stop event is set
    while not f_stop.is_set():
        all_node_ids = set(driver.get_node_ids())
        dead_nodes = set(registered_nodes).difference(all_node_ids)
        new_nodes = all_node_ids.difference(registered_nodes)

        # Unregister dead nodes
        for node_id in dead_nodes:
            client_proxy = registered_nodes[node_id]
            client_manager.unregister(client_proxy)
            del registered_nodes[node_id]

        # Register new nodes
        for node_id in new_nodes:
            client_proxy = DriverClientProxy(
                node_id=node_id,
                driver=driver.grpc_driver,  # type: ignore
                anonymous=False,
                run_id=driver.run_id,  # type: ignore
            )
            if client_manager.register(client_proxy):
                registered_nodes[node_id] = client_proxy
            else:
                raise RuntimeError("Could not register node.")

        # Wait up to 3 seconds before polling again. Unlike `time.sleep(3)`,
        # `Event.wait` returns as soon as `f_stop` is set, so a caller that
        # sets the event and then joins this thread is not blocked for up to
        # a full polling period.
        f_stop.wait(3)
@@ -47,57 +47,68 @@ class DriverClientProxy(ClientProxy):
47
47
  self.anonymous = anonymous
48
48
 
49
49
  def get_properties(
50
- self, ins: common.GetPropertiesIns, timeout: Optional[float]
50
+ self,
51
+ ins: common.GetPropertiesIns,
52
+ timeout: Optional[float],
53
+ group_id: Optional[int],
51
54
  ) -> common.GetPropertiesRes:
52
55
  """Return client's properties."""
53
56
  # Ins to RecordSet
54
57
  out_recordset = compat.getpropertiesins_to_recordset(ins)
55
58
  # Fetch response
56
59
  in_recordset = self._send_receive_recordset(
57
- out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout
60
+ out_recordset, MESSAGE_TYPE_GET_PROPERTIES, timeout, group_id
58
61
  )
59
62
  # RecordSet to Res
60
63
  return compat.recordset_to_getpropertiesres(in_recordset)
61
64
 
62
65
  def get_parameters(
63
- self, ins: common.GetParametersIns, timeout: Optional[float]
66
+ self,
67
+ ins: common.GetParametersIns,
68
+ timeout: Optional[float],
69
+ group_id: Optional[int],
64
70
  ) -> common.GetParametersRes:
65
71
  """Return the current local model parameters."""
66
72
  # Ins to RecordSet
67
73
  out_recordset = compat.getparametersins_to_recordset(ins)
68
74
  # Fetch response
69
75
  in_recordset = self._send_receive_recordset(
70
- out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout
76
+ out_recordset, MESSAGE_TYPE_GET_PARAMETERS, timeout, group_id
71
77
  )
72
78
  # RecordSet to Res
73
79
  return compat.recordset_to_getparametersres(in_recordset, False)
74
80
 
75
- def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
81
+ def fit(
82
+ self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
83
+ ) -> common.FitRes:
76
84
  """Train model parameters on the locally held dataset."""
77
85
  # Ins to RecordSet
78
86
  out_recordset = compat.fitins_to_recordset(ins, keep_input=True)
79
87
  # Fetch response
80
88
  in_recordset = self._send_receive_recordset(
81
- out_recordset, MESSAGE_TYPE_FIT, timeout
89
+ out_recordset, MESSAGE_TYPE_FIT, timeout, group_id
82
90
  )
83
91
  # RecordSet to Res
84
92
  return compat.recordset_to_fitres(in_recordset, keep_input=False)
85
93
 
86
94
  def evaluate(
87
- self, ins: common.EvaluateIns, timeout: Optional[float]
95
+ self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
88
96
  ) -> common.EvaluateRes:
89
97
  """Evaluate model parameters on the locally held dataset."""
90
98
  # Ins to RecordSet
91
99
  out_recordset = compat.evaluateins_to_recordset(ins, keep_input=True)
92
100
  # Fetch response
93
101
  in_recordset = self._send_receive_recordset(
94
- out_recordset, MESSAGE_TYPE_EVALUATE, timeout
102
+ out_recordset, MESSAGE_TYPE_EVALUATE, timeout, group_id
95
103
  )
96
104
  # RecordSet to Res
97
105
  return compat.recordset_to_evaluateres(in_recordset)
98
106
 
99
107
  def reconnect(
100
- self, ins: common.ReconnectIns, timeout: Optional[float]
108
+ self,
109
+ ins: common.ReconnectIns,
110
+ timeout: Optional[float],
111
+ group_id: Optional[int],
101
112
  ) -> common.DisconnectRes:
102
113
  """Disconnect and (optionally) reconnect later."""
103
114
  return common.DisconnectRes(reason="") # Nothing to do here (yet)
@@ -107,10 +118,11 @@ class DriverClientProxy(ClientProxy):
107
118
  recordset: RecordSet,
108
119
  task_type: str,
109
120
  timeout: Optional[float],
121
+ group_id: Optional[int],
110
122
  ) -> RecordSet:
111
123
  task_ins = task_pb2.TaskIns( # pylint: disable=E1101
112
124
  task_id="",
113
- group_id="",
125
+ group_id=str(group_id) if group_id is not None else "",
114
126
  run_id=self.run_id,
115
127
  task=task_pb2.Task( # pylint: disable=E1101
116
128
  producer=node_pb2.Node( # pylint: disable=E1101
@@ -0,0 +1,55 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Legacy Context."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ from flwr.common import Context, RecordSet
22
+
23
+ from ..client_manager import ClientManager, SimpleClientManager
24
+ from ..history import History
25
+ from ..server_config import ServerConfig
26
+ from ..strategy import FedAvg, Strategy
27
+
28
+
29
@dataclass
class LegacyContext(Context):
    """Context that additionally carries the legacy server components.

    Besides the ``state`` passed on to ``Context.__init__``, it holds a
    ``ServerConfig``, a ``Strategy``, a ``ClientManager`` and a ``History``.
    Components not supplied to the constructor fall back to ``ServerConfig()``,
    ``FedAvg()`` and ``SimpleClientManager()`` respectively. A fresh
    ``History()`` is always created.
    """

    config: ServerConfig
    strategy: Strategy
    client_manager: ClientManager
    history: History

    # Because the class body defines ``__init__`` explicitly, ``@dataclass``
    # keeps this hand-written constructor instead of generating one.
    def __init__(
        self,
        state: RecordSet,
        config: Optional[ServerConfig] = None,
        strategy: Optional[Strategy] = None,
        client_manager: Optional[ClientManager] = None,
    ) -> None:
        # Substitute a default for every component the caller left out.
        self.config = ServerConfig() if config is None else config
        self.strategy = FedAvg() if strategy is None else strategy
        self.client_manager = (
            SimpleClientManager() if client_manager is None else client_manager
        )
        self.history = History()
        super().__init__(state)
@@ -107,7 +107,7 @@ def run_server_app() -> None:
107
107
  run(server_app_attr, driver, server_app_dir)
108
108
 
109
109
  # Clean up
110
- del driver
110
+ driver.__del__() # pylint: disable=unnecessary-dunder-call
111
111
 
112
112
  event(EventType.RUN_SERVER_APP_LEAVE)
113
113