flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237)
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/server/driver/grpc_driver.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,20 +12,25 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service client."""
15
+ """Flower gRPC Driver."""
16
16
 
17
-
18
- from logging import DEBUG, ERROR, WARNING
19
- from typing import Optional
17
+ import time
18
+ import warnings
19
+ from logging import DEBUG, WARNING
20
+ from typing import Iterable, List, Optional, cast
20
21
 
21
22
  import grpc
22
23
 
23
- from flwr.common import EventType, event
24
+ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
24
25
  from flwr.common.grpc import create_channel
25
26
  from flwr.common.logger import log
27
+ from flwr.common.serde import (
28
+ message_from_taskres,
29
+ message_to_taskins,
30
+ user_config_from_proto,
31
+ )
32
+ from flwr.common.typing import Run
26
33
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
27
- CreateRunRequest,
28
- CreateRunResponse,
29
34
  GetNodesRequest,
30
35
  GetNodesResponse,
31
36
  PullTaskResRequest,
@@ -34,96 +39,241 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
34
39
  PushTaskInsResponse,
35
40
  )
36
41
  from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
42
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
43
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
44
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
45
+
46
+ from .driver import Driver
37
47
 
38
48
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
39
49
 
40
50
  ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
41
51
  [Driver] Error: Not connected.
42
52
 
43
- Call `connect()` on the `GrpcDriver` instance before calling any of the other
44
- `GrpcDriver` methods.
53
+ Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
54
+ `GrpcDriverStub` methods.
45
55
  """
46
56
 
47
57
 
48
- class GrpcDriver:
49
- """`GrpcDriver` provides access to the gRPC Driver API/service."""
58
+ class GrpcDriver(Driver):
59
+ """`GrpcDriver` provides an interface to the Driver API.
60
+
61
+ Parameters
62
+ ----------
63
+ run_id : int
64
+ The identifier of the run.
65
+ driver_service_address : str (default: "[::]:9091")
66
+ The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
67
+ root_certificates : Optional[bytes] (default: None)
68
+ The PEM-encoded root certificates as a byte string.
69
+ If provided, a secure connection using the certificates will be
70
+ established to an SSL-enabled Flower server.
71
+ """
50
72
 
51
- def __init__(
73
+ def __init__( # pylint: disable=too-many-arguments
52
74
  self,
75
+ run_id: int,
53
76
  driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
54
77
  root_certificates: Optional[bytes] = None,
55
78
  ) -> None:
56
- self.driver_service_address = driver_service_address
57
- self.root_certificates = root_certificates
58
- self.channel: Optional[grpc.Channel] = None
59
- self.stub: Optional[DriverStub] = None
79
+ self._run_id = run_id
80
+ self._addr = driver_service_address
81
+ self._cert = root_certificates
82
+ self._run: Optional[Run] = None
83
+ self._grpc_stub: Optional[DriverStub] = None
84
+ self._channel: Optional[grpc.Channel] = None
85
+ self.node = Node(node_id=0, anonymous=True)
60
86
 
61
- def connect(self) -> None:
62
- """Connect to the Driver API."""
87
+ @property
88
+ def _is_connected(self) -> bool:
89
+ """Check if connected to the Driver API server."""
90
+ return self._channel is not None
91
+
92
+ def _connect(self) -> None:
93
+ """Connect to the Driver API.
94
+
95
+ This will not call GetRun.
96
+ """
63
97
  event(EventType.DRIVER_CONNECT)
64
- if self.channel is not None or self.stub is not None:
98
+ if self._is_connected:
65
99
  log(WARNING, "Already connected")
66
100
  return
67
- self.channel = create_channel(
68
- server_address=self.driver_service_address,
69
- insecure=(self.root_certificates is None),
70
- root_certificates=self.root_certificates,
101
+ self._channel = create_channel(
102
+ server_address=self._addr,
103
+ insecure=(self._cert is None),
104
+ root_certificates=self._cert,
71
105
  )
72
- self.stub = DriverStub(self.channel)
73
- log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
106
+ self._grpc_stub = DriverStub(self._channel)
107
+ log(DEBUG, "[Driver] Connected to %s", self._addr)
74
108
 
75
- def disconnect(self) -> None:
109
+ def _disconnect(self) -> None:
76
110
  """Disconnect from the Driver API."""
77
111
  event(EventType.DRIVER_DISCONNECT)
78
- if self.channel is None or self.stub is None:
112
+ if not self._is_connected:
79
113
  log(DEBUG, "Already disconnected")
80
114
  return
81
- channel = self.channel
82
- self.channel = None
83
- self.stub = None
115
+ channel: grpc.Channel = self._channel
116
+ self._channel = None
117
+ self._grpc_stub = None
84
118
  channel.close()
85
119
  log(DEBUG, "[Driver] Disconnected")
86
120
 
87
- def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
88
- """Request for run ID."""
89
- # Check if channel is open
90
- if self.stub is None:
91
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
92
- raise ConnectionError("`GrpcDriver` instance not connected")
93
-
94
- # Call Driver API
95
- res: CreateRunResponse = self.stub.CreateRun(request=req)
96
- return res
97
-
98
- def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
99
- """Get client IDs."""
100
- # Check if channel is open
101
- if self.stub is None:
102
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
103
- raise ConnectionError("`GrpcDriver` instance not connected")
104
-
105
- # Call gRPC Driver API
106
- res: GetNodesResponse = self.stub.GetNodes(request=req)
107
- return res
108
-
109
- def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
110
- """Schedule tasks."""
111
- # Check if channel is open
112
- if self.stub is None:
113
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
114
- raise ConnectionError("`GrpcDriver` instance not connected")
115
-
116
- # Call gRPC Driver API
117
- res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
118
- return res
119
-
120
- def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
121
- """Get task results."""
122
- # Check if channel is open
123
- if self.stub is None:
124
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
125
- raise ConnectionError("`GrpcDriver` instance not connected")
126
-
127
- # Call Driver API
128
- res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
129
- return res
121
+ def _init_run(self) -> None:
122
+ # Check if is initialized
123
+ if self._run is not None:
124
+ return
125
+ # Get the run info
126
+ req = GetRunRequest(run_id=self._run_id)
127
+ res: GetRunResponse = self._stub.GetRun(req)
128
+ if not res.HasField("run"):
129
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
130
+ self._run = Run(
131
+ run_id=res.run.run_id,
132
+ fab_id=res.run.fab_id,
133
+ fab_version=res.run.fab_version,
134
+ override_config=user_config_from_proto(res.run.override_config),
135
+ )
136
+
137
+ @property
138
+ def run(self) -> Run:
139
+ """Run information."""
140
+ self._init_run()
141
+ return Run(**vars(self._run))
142
+
143
+ @property
144
+ def _stub(self) -> DriverStub:
145
+ """Driver stub."""
146
+ if not self._is_connected:
147
+ self._connect()
148
+ return cast(DriverStub, self._grpc_stub)
149
+
150
+ def _check_message(self, message: Message) -> None:
151
+ # Check if the message is valid
152
+ if not (
153
+ # Assume self._run being initialized
154
+ message.metadata.run_id == self._run_id
155
+ and message.metadata.src_node_id == self.node.node_id
156
+ and message.metadata.message_id == ""
157
+ and message.metadata.reply_to_message == ""
158
+ and message.metadata.ttl > 0
159
+ ):
160
+ raise ValueError(f"Invalid message: {message}")
161
+
162
+ def create_message( # pylint: disable=too-many-arguments
163
+ self,
164
+ content: RecordSet,
165
+ message_type: str,
166
+ dst_node_id: int,
167
+ group_id: str,
168
+ ttl: Optional[float] = None,
169
+ ) -> Message:
170
+ """Create a new message with specified parameters.
171
+
172
+ This method constructs a new `Message` with given content and metadata.
173
+ The `run_id` and `src_node_id` will be set automatically.
174
+ """
175
+ self._init_run()
176
+ if ttl:
177
+ warnings.warn(
178
+ "A custom TTL was set, but note that the SuperLink does not enforce "
179
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
180
+ "version of Flower.",
181
+ stacklevel=2,
182
+ )
183
+
184
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
185
+ metadata = Metadata(
186
+ run_id=self._run_id,
187
+ message_id="", # Will be set by the server
188
+ src_node_id=self.node.node_id,
189
+ dst_node_id=dst_node_id,
190
+ reply_to_message="",
191
+ group_id=group_id,
192
+ ttl=ttl_,
193
+ message_type=message_type,
194
+ )
195
+ return Message(metadata=metadata, content=content)
196
+
197
+ def get_node_ids(self) -> List[int]:
198
+ """Get node IDs."""
199
+ self._init_run()
200
+ # Call GrpcDriverStub method
201
+ res: GetNodesResponse = self._stub.GetNodes(
202
+ GetNodesRequest(run_id=self._run_id)
203
+ )
204
+ return [node.node_id for node in res.nodes]
205
+
206
+ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
207
+ """Push messages to specified node IDs.
208
+
209
+ This method takes an iterable of messages and sends each message
210
+ to the node specified in `dst_node_id`.
211
+ """
212
+ self._init_run()
213
+ # Construct TaskIns
214
+ task_ins_list: List[TaskIns] = []
215
+ for msg in messages:
216
+ # Check message
217
+ self._check_message(msg)
218
+ # Convert Message to TaskIns
219
+ taskins = message_to_taskins(msg)
220
+ # Add to list
221
+ task_ins_list.append(taskins)
222
+ # Call GrpcDriverStub method
223
+ res: PushTaskInsResponse = self._stub.PushTaskIns(
224
+ PushTaskInsRequest(task_ins_list=task_ins_list)
225
+ )
226
+ return list(res.task_ids)
227
+
228
+ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
229
+ """Pull messages based on message IDs.
230
+
231
+ This method is used to collect messages from the SuperLink that correspond to a
232
+ set of given message IDs.
233
+ """
234
+ self._init_run()
235
+ # Pull TaskRes
236
+ res: PullTaskResResponse = self._stub.PullTaskRes(
237
+ PullTaskResRequest(node=self.node, task_ids=message_ids)
238
+ )
239
+ # Convert TaskRes to Message
240
+ msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
241
+ return msgs
242
+
243
+ def send_and_receive(
244
+ self,
245
+ messages: Iterable[Message],
246
+ *,
247
+ timeout: Optional[float] = None,
248
+ ) -> Iterable[Message]:
249
+ """Push messages to specified node IDs and pull the reply messages.
250
+
251
+ This method sends a list of messages to their destination node IDs and then
252
+ waits for the replies. It continues to pull replies until either all replies are
253
+ received or the specified timeout duration is exceeded.
254
+ """
255
+ # Push messages
256
+ msg_ids = set(self.push_messages(messages))
257
+
258
+ # Pull messages
259
+ end_time = time.time() + (timeout if timeout is not None else 0.0)
260
+ ret: List[Message] = []
261
+ while timeout is None or time.time() < end_time:
262
+ res_msgs = self.pull_messages(msg_ids)
263
+ ret.extend(res_msgs)
264
+ msg_ids.difference_update(
265
+ {msg.metadata.reply_to_message for msg in res_msgs}
266
+ )
267
+ if len(msg_ids) == 0:
268
+ break
269
+ # Sleep
270
+ time.sleep(3)
271
+ return ret
272
+
273
+ def close(self) -> None:
274
+ """Disconnect from the SuperLink if connected."""
275
+ # Check if `connect` was called before
276
+ if not self._is_connected:
277
+ return
278
+ # Disconnect
279
+ self._disconnect()
flwr/server/driver/inmemory_driver.py ADDED
@@ -0,0 +1,183 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower in-memory Driver."""
16
+
17
+
18
+ import time
19
+ import warnings
20
+ from typing import Iterable, List, Optional, cast
21
+ from uuid import UUID
22
+
23
+ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
24
+ from flwr.common.serde import message_from_taskres, message_to_taskins
25
+ from flwr.common.typing import Run
26
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
27
+ from flwr.server.superlink.state import StateFactory
28
+
29
+ from .driver import Driver
30
+
31
+
32
+ class InMemoryDriver(Driver):
33
+ """`InMemoryDriver` class provides an interface to the Driver API.
34
+
35
+ Parameters
36
+ ----------
37
+ run_id : int
38
+ The identifier of the run.
39
+ state_factory : StateFactory
40
+ A StateFactory embedding a state that this driver can interface with.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ run_id: int,
46
+ state_factory: StateFactory,
47
+ ) -> None:
48
+ self._run_id = run_id
49
+ self._run: Optional[Run] = None
50
+ self.state = state_factory.state()
51
+ self.node = Node(node_id=0, anonymous=True)
52
+
53
+ def _check_message(self, message: Message) -> None:
54
+ self._init_run()
55
+ # Check if the message is valid
56
+ if not (
57
+ message.metadata.run_id == cast(Run, self._run).run_id
58
+ and message.metadata.src_node_id == self.node.node_id
59
+ and message.metadata.message_id == ""
60
+ and message.metadata.reply_to_message == ""
61
+ and message.metadata.ttl > 0
62
+ ):
63
+ raise ValueError(f"Invalid message: {message}")
64
+
65
+ def _init_run(self) -> None:
66
+ """Initialize the run."""
67
+ if self._run is not None:
68
+ return
69
+ run = self.state.get_run(self._run_id)
70
+ if run is None:
71
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
72
+ self._run = run
73
+
74
+ @property
75
+ def run(self) -> Run:
76
+ """Run ID."""
77
+ self._init_run()
78
+ return Run(**vars(cast(Run, self._run)))
79
+
80
+ def create_message( # pylint: disable=too-many-arguments
81
+ self,
82
+ content: RecordSet,
83
+ message_type: str,
84
+ dst_node_id: int,
85
+ group_id: str,
86
+ ttl: Optional[float] = None,
87
+ ) -> Message:
88
+ """Create a new message with specified parameters.
89
+
90
+ This method constructs a new `Message` with given content and metadata.
91
+ The `run_id` and `src_node_id` will be set automatically.
92
+ """
93
+ self._init_run()
94
+ if ttl:
95
+ warnings.warn(
96
+ "A custom TTL was set, but note that the SuperLink does not enforce "
97
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
98
+ "version of Flower.",
99
+ stacklevel=2,
100
+ )
101
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
102
+
103
+ metadata = Metadata(
104
+ run_id=cast(Run, self._run).run_id,
105
+ message_id="", # Will be set by the server
106
+ src_node_id=self.node.node_id,
107
+ dst_node_id=dst_node_id,
108
+ reply_to_message="",
109
+ group_id=group_id,
110
+ ttl=ttl_,
111
+ message_type=message_type,
112
+ )
113
+ return Message(metadata=metadata, content=content)
114
+
115
+ def get_node_ids(self) -> List[int]:
116
+ """Get node IDs."""
117
+ self._init_run()
118
+ return list(self.state.get_nodes(cast(Run, self._run).run_id))
119
+
120
+ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
121
+ """Push messages to specified node IDs.
122
+
123
+ This method takes an iterable of messages and sends each message
124
+ to the node specified in `dst_node_id`.
125
+ """
126
+ task_ids: List[str] = []
127
+ for msg in messages:
128
+ # Check message
129
+ self._check_message(msg)
130
+ # Convert Message to TaskIns
131
+ taskins = message_to_taskins(msg)
132
+ # Store in state
133
+ taskins.task.pushed_at = time.time()
134
+ task_id = self.state.store_task_ins(taskins)
135
+ if task_id:
136
+ task_ids.append(str(task_id))
137
+
138
+ return task_ids
139
+
140
+ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
141
+ """Pull messages based on message IDs.
142
+
143
+ This method is used to collect messages from the SuperLink that correspond to a
144
+ set of given message IDs.
145
+ """
146
+ msg_ids = {UUID(msg_id) for msg_id in message_ids}
147
+ # Pull TaskRes
148
+ task_res_list = self.state.get_task_res(task_ids=msg_ids, limit=len(msg_ids))
149
+ # Delete tasks in state
150
+ self.state.delete_tasks(msg_ids)
151
+ # Convert TaskRes to Message
152
+ msgs = [message_from_taskres(taskres) for taskres in task_res_list]
153
+ return msgs
154
+
155
+ def send_and_receive(
156
+ self,
157
+ messages: Iterable[Message],
158
+ *,
159
+ timeout: Optional[float] = None,
160
+ ) -> Iterable[Message]:
161
+ """Push messages to specified node IDs and pull the reply messages.
162
+
163
+ This method sends a list of messages to their destination node IDs and then
164
+ waits for the replies. It continues to pull replies until either all replies are
165
+ received or the specified timeout duration is exceeded.
166
+ """
167
+ # Push messages
168
+ msg_ids = set(self.push_messages(messages))
169
+
170
+ # Pull messages
171
+ end_time = time.time() + (timeout if timeout is not None else 0.0)
172
+ ret: List[Message] = []
173
+ while timeout is None or time.time() < end_time:
174
+ res_msgs = self.pull_messages(msg_ids)
175
+ ret.extend(res_msgs)
176
+ msg_ids.difference_update(
177
+ {msg.metadata.reply_to_message for msg in res_msgs}
178
+ )
179
+ if len(msg_ids) == 0:
180
+ break
181
+ # Sleep
182
+ time.sleep(3)
183
+ return ret
flwr/server/history.py CHANGED
@@ -91,32 +91,32 @@ class History:
91
91
  """
92
92
  rep = ""
93
93
  if self.losses_distributed:
94
- rep += "History (loss, distributed):\n" + pprint.pformat(
95
- reduce(
96
- lambda a, b: a + b,
97
- [
98
- f"\tround {server_round}: {loss}\n"
99
- for server_round, loss in self.losses_distributed
100
- ],
101
- )
94
+ rep += "History (loss, distributed):\n" + reduce(
95
+ lambda a, b: a + b,
96
+ [
97
+ f"\tround {server_round}: {loss}\n"
98
+ for server_round, loss in self.losses_distributed
99
+ ],
102
100
  )
103
101
  if self.losses_centralized:
104
- rep += "History (loss, centralized):\n" + pprint.pformat(
105
- reduce(
106
- lambda a, b: a + b,
107
- [
108
- f"\tround {server_round}: {loss}\n"
109
- for server_round, loss in self.losses_centralized
110
- ],
111
- )
102
+ rep += "History (loss, centralized):\n" + reduce(
103
+ lambda a, b: a + b,
104
+ [
105
+ f"\tround {server_round}: {loss}\n"
106
+ for server_round, loss in self.losses_centralized
107
+ ],
112
108
  )
113
109
  if self.metrics_distributed_fit:
114
- rep += "History (metrics, distributed, fit):\n" + pprint.pformat(
115
- self.metrics_distributed_fit
110
+ rep += (
111
+ "History (metrics, distributed, fit):\n"
112
+ + pprint.pformat(self.metrics_distributed_fit)
113
+ + "\n"
116
114
  )
117
115
  if self.metrics_distributed:
118
- rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat(
119
- self.metrics_distributed
116
+ rep += (
117
+ "History (metrics, distributed, evaluate):\n"
118
+ + pprint.pformat(self.metrics_distributed)
119
+ + "\n"
120
120
  )
121
121
  if self.metrics_centralized:
122
122
  rep += "History (metrics, centralized):\n" + pprint.pformat(