flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237) hide show
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -31,13 +31,21 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
31
31
  if not tasks_ins_res.HasField("task"):
32
32
  validation_errors.append("`task` does not set field `task`")
33
33
 
34
- # Created/delivered/TTL
35
- if tasks_ins_res.task.created_at != "":
36
- validation_errors.append("`created_at` must be an empty str")
34
+ # Created/delivered/TTL/Pushed
35
+ if (
36
+ tasks_ins_res.task.created_at < 1711497600.0
37
+ ): # unix timestamp of 27 March 2024 00h:00m:00s UTC
38
+ validation_errors.append(
39
+ "`created_at` must be a float that records the unix timestamp "
40
+ "in seconds when the message was created."
41
+ )
37
42
  if tasks_ins_res.task.delivered_at != "":
38
43
  validation_errors.append("`delivered_at` must be an empty str")
39
- if tasks_ins_res.task.ttl != "":
40
- validation_errors.append("`ttl` must be an empty str")
44
+ if tasks_ins_res.task.ttl <= 0:
45
+ validation_errors.append("`ttl` must be higher than zero")
46
+ if tasks_ins_res.task.pushed_at < 1711497600.0:
47
+ # unix timestamp of 27 March 2024 00h:00m:00s UTC
48
+ validation_errors.append("`pushed_at` is not a recent timestamp")
41
49
 
42
50
  # TaskIns specific
43
51
  if isinstance(tasks_ins_res, TaskIns):
@@ -66,8 +74,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
66
74
  # Content check
67
75
  if tasks_ins_res.task.task_type == "":
68
76
  validation_errors.append("`task_type` MUST be set")
69
- if not tasks_ins_res.task.HasField("recordset"):
70
- validation_errors.append("`recordset` MUST be set")
77
+ if not (
78
+ tasks_ins_res.task.HasField("recordset")
79
+ ^ tasks_ins_res.task.HasField("error")
80
+ ):
81
+ validation_errors.append("Either `recordset` or `error` MUST be set")
71
82
 
72
83
  # Ancestors
73
84
  if len(tasks_ins_res.task.ancestry) != 0:
@@ -106,8 +117,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str
106
117
  # Content check
107
118
  if tasks_ins_res.task.task_type == "":
108
119
  validation_errors.append("`task_type` MUST be set")
109
- if not tasks_ins_res.task.HasField("recordset"):
110
- validation_errors.append("`recordset` MUST be set")
120
+ if not (
121
+ tasks_ins_res.task.HasField("recordset")
122
+ ^ tasks_ins_res.task.HasField("error")
123
+ ):
124
+ validation_errors.append("Either `recordset` or `error` MUST be set")
111
125
 
112
126
  # Ancestors
113
127
  if len(tasks_ins_res.task.ancestry) == 0:
@@ -17,13 +17,23 @@
17
17
 
18
18
  import io
19
19
  import timeit
20
- from logging import INFO
21
- from typing import Optional, cast
20
+ from logging import INFO, WARN
21
+ from typing import List, Optional, Tuple, Union, cast
22
22
 
23
23
  import flwr.common.recordset_compat as compat
24
- from flwr.common import ConfigsRecord, Context, GetParametersIns, log
24
+ from flwr.common import (
25
+ Code,
26
+ ConfigsRecord,
27
+ Context,
28
+ EvaluateRes,
29
+ FitRes,
30
+ GetParametersIns,
31
+ ParametersRecord,
32
+ log,
33
+ )
25
34
  from flwr.common.constant import MessageType, MessageTypeLegacy
26
35
 
36
+ from ..client_proxy import ClientProxy
27
37
  from ..compat.app_utils import start_update_client_manager_thread
28
38
  from ..compat.legacy_context import LegacyContext
29
39
  from ..driver import Driver
@@ -88,7 +98,12 @@ class DefaultWorkflow:
88
98
  hist = context.history
89
99
  log(INFO, "")
90
100
  log(INFO, "[SUMMARY]")
91
- log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
101
+ log(
102
+ INFO,
103
+ "Run finished %s round(s) in %.2fs",
104
+ context.config.num_rounds,
105
+ elapsed,
106
+ )
92
107
  for idx, line in enumerate(io.StringIO(str(hist))):
93
108
  if idx == 0:
94
109
  log(INFO, "%s", line.strip("\n"))
@@ -127,13 +142,27 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
127
142
  message_type=MessageTypeLegacy.GET_PARAMETERS,
128
143
  dst_node_id=random_client.node_id,
129
144
  group_id="0",
130
- ttl="",
131
145
  )
132
146
  ]
133
147
  )
134
- log(INFO, "Received initial parameters from one random client")
135
148
  msg = list(messages)[0]
136
- paramsrecord = next(iter(msg.content.parameters_records.values()))
149
+
150
+ if (
151
+ msg.has_content()
152
+ and compat._extract_status_from_recordset( # pylint: disable=W0212
153
+ "getparametersres", msg.content
154
+ ).code
155
+ == Code.OK
156
+ ):
157
+ log(INFO, "Received initial parameters from one random client")
158
+ paramsrecord = next(iter(msg.content.parameters_records.values()))
159
+ else:
160
+ log(
161
+ WARN,
162
+ "Failed to receive initial parameters from the client."
163
+ " Empty initial parameters will be used.",
164
+ )
165
+ paramsrecord = ParametersRecord()
137
166
 
138
167
  context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
139
168
 
@@ -226,7 +255,6 @@ def default_fit_workflow( # pylint: disable=R0914
226
255
  message_type=MessageType.TRAIN,
227
256
  dst_node_id=proxy.node_id,
228
257
  group_id=str(current_round),
229
- ttl="",
230
258
  )
231
259
  for proxy, fitins in client_instructions
232
260
  ]
@@ -246,14 +274,20 @@ def default_fit_workflow( # pylint: disable=R0914
246
274
  )
247
275
 
248
276
  # Aggregate training results
249
- results = [
250
- (
251
- node_id_to_proxy[msg.metadata.src_node_id],
252
- compat.recordset_to_fitres(msg.content, False),
253
- )
254
- for msg in messages
255
- ]
256
- aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
277
+ results: List[Tuple[ClientProxy, FitRes]] = []
278
+ failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
279
+ for msg in messages:
280
+ if msg.has_content():
281
+ proxy = node_id_to_proxy[msg.metadata.src_node_id]
282
+ fitres = compat.recordset_to_fitres(msg.content, False)
283
+ if fitres.status.code == Code.OK:
284
+ results.append((proxy, fitres))
285
+ else:
286
+ failures.append((proxy, fitres))
287
+ else:
288
+ failures.append(Exception(msg.error))
289
+
290
+ aggregated_result = context.strategy.aggregate_fit(current_round, results, failures)
257
291
  parameters_aggregated, metrics_aggregated = aggregated_result
258
292
 
259
293
  # Update the parameters and write history
@@ -267,6 +301,7 @@ def default_fit_workflow( # pylint: disable=R0914
267
301
  )
268
302
 
269
303
 
304
+ # pylint: disable-next=R0914
270
305
  def default_evaluate_workflow(driver: Driver, context: Context) -> None:
271
306
  """Execute the default workflow for a single evaluate round."""
272
307
  if not isinstance(context, LegacyContext):
@@ -306,7 +341,6 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
306
341
  message_type=MessageType.EVALUATE,
307
342
  dst_node_id=proxy.node_id,
308
343
  group_id=str(current_round),
309
- ttl="",
310
344
  )
311
345
  for proxy, evalins in client_instructions
312
346
  ]
@@ -326,14 +360,22 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
326
360
  )
327
361
 
328
362
  # Aggregate the evaluation results
329
- results = [
330
- (
331
- node_id_to_proxy[msg.metadata.src_node_id],
332
- compat.recordset_to_evaluateres(msg.content),
333
- )
334
- for msg in messages
335
- ]
336
- aggregated_result = context.strategy.aggregate_evaluate(current_round, results, [])
363
+ results: List[Tuple[ClientProxy, EvaluateRes]] = []
364
+ failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = []
365
+ for msg in messages:
366
+ if msg.has_content():
367
+ proxy = node_id_to_proxy[msg.metadata.src_node_id]
368
+ evalres = compat.recordset_to_evaluateres(msg.content)
369
+ if evalres.status.code == Code.OK:
370
+ results.append((proxy, evalres))
371
+ else:
372
+ failures.append((proxy, evalres))
373
+ else:
374
+ failures.append(Exception(msg.error))
375
+
376
+ aggregated_result = context.strategy.aggregate_evaluate(
377
+ current_round, results, failures
378
+ )
337
379
 
338
380
  loss_aggregated, metrics_aggregated = aggregated_result
339
381
 
@@ -81,6 +81,7 @@ class WorkflowState: # pylint: disable=R0902
81
81
  forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
82
82
  aggregate_ndarrays: NDArrays = field(default_factory=list)
83
83
  legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
84
+ failures: List[Exception] = field(default_factory=list)
84
85
 
85
86
 
86
87
  class SecAggPlusWorkflow:
@@ -373,7 +374,6 @@ class SecAggPlusWorkflow:
373
374
  message_type=MessageType.TRAIN,
374
375
  dst_node_id=nid,
375
376
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
376
- ttl="",
377
377
  )
378
378
 
379
379
  log(
@@ -395,6 +395,7 @@ class SecAggPlusWorkflow:
395
395
 
396
396
  for msg in msgs:
397
397
  if msg.has_error():
398
+ state.failures.append(Exception(msg.error))
398
399
  continue
399
400
  key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
400
401
  node_id = msg.metadata.src_node_id
@@ -421,7 +422,6 @@ class SecAggPlusWorkflow:
421
422
  message_type=MessageType.TRAIN,
422
423
  dst_node_id=nid,
423
424
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
424
- ttl="",
425
425
  )
426
426
 
427
427
  # Broadcast public keys to clients and receive secret key shares
@@ -453,6 +453,9 @@ class SecAggPlusWorkflow:
453
453
  nid: [] for nid in state.active_node_ids
454
454
  } # dest node ID -> list of src node IDs
455
455
  for msg in msgs:
456
+ if msg.has_error():
457
+ state.failures.append(Exception(msg.error))
458
+ continue
456
459
  node_id = msg.metadata.src_node_id
457
460
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
458
461
  dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
@@ -492,7 +495,6 @@ class SecAggPlusWorkflow:
492
495
  message_type=MessageType.TRAIN,
493
496
  dst_node_id=nid,
494
497
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
495
- ttl="",
496
498
  )
497
499
 
498
500
  log(
@@ -518,6 +520,9 @@ class SecAggPlusWorkflow:
518
520
  # Sum collected masked vectors and compute active/dead node IDs
519
521
  masked_vector = None
520
522
  for msg in msgs:
523
+ if msg.has_error():
524
+ state.failures.append(Exception(msg.error))
525
+ continue
521
526
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
522
527
  bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
523
528
  client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
@@ -531,6 +536,9 @@ class SecAggPlusWorkflow:
531
536
 
532
537
  # Backward compatibility with Strategy
533
538
  for msg in msgs:
539
+ if msg.has_error():
540
+ state.failures.append(Exception(msg.error))
541
+ continue
534
542
  fitres = compat.recordset_to_fitres(msg.content, True)
535
543
  proxy = state.nid_to_proxies[msg.metadata.src_node_id]
536
544
  state.legacy_results.append((proxy, fitres))
@@ -563,7 +571,6 @@ class SecAggPlusWorkflow:
563
571
  message_type=MessageType.TRAIN,
564
572
  dst_node_id=nid,
565
573
  group_id=str(current_round),
566
- ttl="",
567
574
  )
568
575
 
569
576
  log(
@@ -588,6 +595,9 @@ class SecAggPlusWorkflow:
588
595
  for nid in state.sampled_node_ids:
589
596
  collected_shares_dict[nid] = []
590
597
  for msg in msgs:
598
+ if msg.has_error():
599
+ state.failures.append(Exception(msg.error))
600
+ continue
591
601
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
592
602
  nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
593
603
  shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
@@ -656,9 +666,11 @@ class SecAggPlusWorkflow:
656
666
  INFO,
657
667
  "aggregate_fit: received %s results and %s failures",
658
668
  len(results),
659
- 0,
669
+ len(state.failures),
670
+ )
671
+ aggregated_result = context.strategy.aggregate_fit(
672
+ current_round, results, state.failures # type: ignore
660
673
  )
661
- aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
662
674
  parameters_aggregated, metrics_aggregated = aggregated_result
663
675
 
664
676
  # Update the parameters and write history
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2021 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
17
17
 
18
18
  import importlib
19
19
 
20
- from flwr.simulation.run_simulation import run_simulation, run_simulation_from_cli
20
+ from flwr.simulation.run_simulation import run_simulation
21
21
 
22
22
  is_ray_installed = importlib.util.find_spec("ray") is not None
23
23
 
@@ -28,7 +28,7 @@ else:
28
28
 
29
29
  To install the necessary dependencies, install `flwr` with the `simulation` extra:
30
30
 
31
- pip install -U flwr["simulation"]
31
+ pip install -U "flwr[simulation]"
32
32
  """
33
33
 
34
34
  def start_simulation(*args, **kwargs): # type: ignore
@@ -36,4 +36,7 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
36
  raise ImportError(RAY_IMPORT_ERROR)
37
37
 
38
38
 
39
- __all__ = ["start_simulation", "run_simulation_from_cli", "run_simulation"]
39
+ __all__ = [
40
+ "run_simulation",
41
+ "start_simulation",
42
+ ]
flwr/simulation/app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2021 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
15
15
  """Flower simulation app."""
16
16
 
17
17
 
18
+ import asyncio
19
+ import logging
18
20
  import sys
19
21
  import threading
20
22
  import traceback
@@ -25,14 +27,16 @@ from typing import Any, Dict, List, Optional, Type, Union
25
27
  import ray
26
28
  from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
27
29
 
28
- from flwr.client import ClientFn
30
+ from flwr.client import ClientFnExt
29
31
  from flwr.common import EventType, event
30
- from flwr.common.logger import log
32
+ from flwr.common.constant import NODE_ID_NUM_BYTES
33
+ from flwr.common.logger import log, set_logger_propagation, warn_unsupported_feature
31
34
  from flwr.server.client_manager import ClientManager
32
35
  from flwr.server.history import History
33
36
  from flwr.server.server import Server, init_defaults, run_fl
34
37
  from flwr.server.server_config import ServerConfig
35
38
  from flwr.server.strategy import Strategy
39
+ from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
36
40
  from flwr.simulation.ray_transport.ray_actor import (
37
41
  ClientAppActor,
38
42
  VirtualClientEngineActor,
@@ -49,7 +53,7 @@ Invalid Arguments in method:
49
53
  `start_simulation(
50
54
  *,
51
55
  client_fn: ClientFn,
52
- num_clients: Optional[int] = None,
56
+ num_clients: int,
53
57
  clients_ids: Optional[List[str]] = None,
54
58
  client_resources: Optional[Dict[str, float]] = None,
55
59
  server: Optional[Server] = None,
@@ -68,13 +72,29 @@ REASON:
68
72
 
69
73
  """
70
74
 
75
+ NodeToPartitionMapping = Dict[int, int]
76
+
77
+
78
+ def _create_node_id_to_partition_mapping(
79
+ num_clients: int,
80
+ ) -> NodeToPartitionMapping:
81
+ """Generate a node_id:partition_id mapping."""
82
+ nodes_mapping: NodeToPartitionMapping = {} # {node-id; partition-id}
83
+ for i in range(num_clients):
84
+ while True:
85
+ node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
86
+ if node_id not in nodes_mapping:
87
+ break
88
+ nodes_mapping[node_id] = i
89
+ return nodes_mapping
90
+
71
91
 
72
92
  # pylint: disable=too-many-arguments,too-many-statements,too-many-branches
73
93
  def start_simulation(
74
94
  *,
75
- client_fn: ClientFn,
76
- num_clients: Optional[int] = None,
77
- clients_ids: Optional[List[str]] = None,
95
+ client_fn: ClientFnExt,
96
+ num_clients: int,
97
+ clients_ids: Optional[List[str]] = None, # UNSUPPORTED, WILL BE REMOVED
78
98
  client_resources: Optional[Dict[str, float]] = None,
79
99
  server: Optional[Server] = None,
80
100
  config: Optional[ServerConfig] = None,
@@ -90,23 +110,24 @@ def start_simulation(
90
110
 
91
111
  Parameters
92
112
  ----------
93
- client_fn : ClientFn
94
- A function creating client instances. The function must take a single
95
- `str` argument called `cid`. It should return a single client instance
96
- of type Client. Note that the created client instances are ephemeral
97
- and will often be destroyed after a single method invocation. Since client
98
- instances are not long-lived, they should not attempt to carry state over
99
- method invocations. Any state required by the instance (model, dataset,
100
- hyperparameters, ...) should be (re-)created in either the call to `client_fn`
101
- or the call to any of the client methods (e.g., load evaluation data in the
102
- `evaluate` method itself).
103
- num_clients : Optional[int]
104
- The total number of clients in this simulation. This must be set if
105
- `clients_ids` is not set and vice-versa.
113
+ client_fn : ClientFnExt
114
+ A function creating `Client` instances. The function must have the signature
115
+ `client_fn(context: Context). It should return
116
+ a single client instance of type `Client`. Note that the created client
117
+ instances are ephemeral and will often be destroyed after a single method
118
+ invocation. Since client instances are not long-lived, they should not attempt
119
+ to carry state over method invocations. Any state required by the instance
120
+ (model, dataset, hyperparameters, ...) should be (re-)created in either the
121
+ call to `client_fn` or the call to any of the client methods (e.g., load
122
+ evaluation data in the `evaluate` method itself).
123
+ num_clients : int
124
+ The total number of clients in this simulation.
106
125
  clients_ids : Optional[List[str]]
126
+ UNSUPPORTED, WILL BE REMOVED. USE `num_clients` INSTEAD.
107
127
  List `client_id`s for each client. This is only required if
108
128
  `num_clients` is not set. Setting both `num_clients` and `clients_ids`
109
129
  with `len(clients_ids)` not equal to `num_clients` generates an error.
130
+ Using this argument will raise an error.
110
131
  client_resources : Optional[Dict[str, float]] (default: `{"num_cpus": 1, "num_gpus": 0.0}`)
111
132
  CPU and GPU resources for a single client. Supported keys
112
133
  are `num_cpus` and `num_gpus`. To understand the GPU utilization caused by
@@ -167,6 +188,26 @@ def start_simulation(
167
188
  {"num_clients": len(clients_ids) if clients_ids is not None else num_clients},
168
189
  )
169
190
 
191
+ if clients_ids is not None:
192
+ warn_unsupported_feature(
193
+ "Passing `clients_ids` to `start_simulation` is deprecated and not longer "
194
+ "used by `start_simulation`. Use `num_clients` exclusively instead."
195
+ )
196
+ log(ERROR, "`clients_ids` argument used.")
197
+ sys.exit()
198
+
199
+ # Set logger propagation
200
+ loop: Optional[asyncio.AbstractEventLoop] = None
201
+ try:
202
+ loop = asyncio.get_running_loop()
203
+ except RuntimeError:
204
+ loop = None
205
+ finally:
206
+ if loop and loop.is_running():
207
+ # Set logger propagation to False to prevent duplicated log output in Colab.
208
+ logger = logging.getLogger("flwr")
209
+ _ = set_logger_propagation(logger, False)
210
+
170
211
  # Initialize server and server config
171
212
  initialized_server, initialized_config = init_defaults(
172
213
  server=server,
@@ -181,20 +222,8 @@ def start_simulation(
181
222
  initialized_config,
182
223
  )
183
224
 
184
- # clients_ids takes precedence
185
- cids: List[str]
186
- if clients_ids is not None:
187
- if (num_clients is not None) and (len(clients_ids) != num_clients):
188
- log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
189
- sys.exit()
190
- else:
191
- cids = clients_ids
192
- else:
193
- if num_clients is None:
194
- log(ERROR, INVALID_ARGUMENTS_START_SIMULATION)
195
- sys.exit()
196
- else:
197
- cids = [str(x) for x in range(num_clients)]
225
+ # Create node-id to partition-id mapping
226
+ nodes_mapping = _create_node_id_to_partition_mapping(num_clients)
198
227
 
199
228
  # Default arguments for Ray initialization
200
229
  if not ray_init_args:
@@ -293,10 +322,12 @@ def start_simulation(
293
322
  )
294
323
 
295
324
  # Register one RayClientProxy object for each client with the ClientManager
296
- for cid in cids:
325
+ for node_id, partition_id in nodes_mapping.items():
297
326
  client_proxy = RayActorClientProxy(
298
327
  client_fn=client_fn,
299
- cid=cid,
328
+ node_id=node_id,
329
+ partition_id=partition_id,
330
+ num_partitions=num_clients,
300
331
  actor_pool=pool,
301
332
  )
302
333
  initialized_server.client_manager().register(client=client_proxy)
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2021 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,9 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray-based Flower Actor and ActorPool implementation."""
16
16
 
17
- import asyncio
18
17
  import threading
19
- import traceback
20
18
  from abc import ABC
21
19
  from logging import DEBUG, ERROR, WARNING
22
20
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
@@ -25,22 +23,13 @@ import ray
25
23
  from ray import ObjectRef
26
24
  from ray.util.actor_pool import ActorPool
27
25
 
28
- from flwr.client.client_app import ClientApp, LoadClientAppError
26
+ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
29
27
  from flwr.common import Context, Message
30
28
  from flwr.common.logger import log
31
29
 
32
30
  ClientAppFn = Callable[[], ClientApp]
33
31
 
34
32
 
35
- class ClientException(Exception):
36
- """Raised when client side logic crashes with an exception."""
37
-
38
- def __init__(self, message: str):
39
- div = ">" * 7
40
- self.message = "\n" + div + "A ClientException occurred." + message
41
- super().__init__(self.message)
42
-
43
-
44
33
  class VirtualClientEngineActor(ABC):
45
34
  """Abstract base class for VirtualClientEngine Actors."""
46
35
 
@@ -71,17 +60,7 @@ class VirtualClientEngineActor(ABC):
71
60
  raise load_ex
72
61
 
73
62
  except Exception as ex:
74
- client_trace = traceback.format_exc()
75
- mssg = (
76
- "\n\tSomething went wrong when running your client run."
77
- "\n\tClient "
78
- + cid
79
- + " crashed when the "
80
- + self.__class__.__name__
81
- + " was running its run."
82
- "\n\tException triggered on the client side: " + client_trace,
83
- )
84
- raise ClientException(str(mssg)) from ex
63
+ raise ClientAppException(str(ex)) from ex
85
64
 
86
65
  return cid, out_message, context
87
66
 
@@ -419,12 +398,6 @@ class VirtualClientEngineActorPool(ActorPool):
419
398
  return self._fetch_future_result(cid)
420
399
 
421
400
 
422
- def init_ray(*args: Any, **kwargs: Any) -> None:
423
- """Intialises Ray if not already initialised."""
424
- if not ray.is_initialized():
425
- ray.init(*args, **kwargs)
426
-
427
-
428
401
  class BasicActorPool:
429
402
  """A basic actor pool."""
430
403
 
@@ -437,9 +410,7 @@ class BasicActorPool:
437
410
  self.client_resources = client_resources
438
411
 
439
412
  # Queue of idle actors
440
- self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue(
441
- maxsize=1024
442
- )
413
+ self.pool: List[VirtualClientEngineActor] = []
443
414
  self.num_actors = 0
444
415
 
445
416
  # Resolve arguments to pass during actor init
@@ -453,38 +424,37 @@ class BasicActorPool:
453
424
  # Figure out how many actors can be created given the cluster resources
454
425
  # and the resources the user indicates each VirtualClient will need
455
426
  self.actors_capacity = pool_size_from_resources(client_resources)
456
- self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {}
427
+ self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {}
457
428
 
458
429
  def is_actor_available(self) -> bool:
459
430
  """Return true if there is an idle actor."""
460
- return self.pool.qsize() > 0
431
+ return len(self.pool) > 0
461
432
 
462
- async def add_actors_to_pool(self, num_actors: int) -> None:
433
+ def add_actors_to_pool(self, num_actors: int) -> None:
463
434
  """Add actors to the pool.
464
435
 
465
436
  This method may be executed also if new resources are added to your Ray cluster
466
437
  (e.g. you add a new node).
467
438
  """
468
439
  for _ in range(num_actors):
469
- await self.pool.put(self.create_actor_fn()) # type: ignore
440
+ self.pool.append(self.create_actor_fn()) # type: ignore
470
441
  self.num_actors += num_actors
471
442
 
472
- async def terminate_all_actors(self) -> None:
443
+ def terminate_all_actors(self) -> None:
473
444
  """Terminate actors in pool."""
474
445
  num_terminated = 0
475
- while self.pool.qsize():
476
- actor = await self.pool.get()
446
+ for actor in self.pool:
477
447
  actor.terminate.remote() # type: ignore
478
448
  num_terminated += 1
479
449
 
480
450
  log(DEBUG, "Terminated %i actors", num_terminated)
481
451
 
482
- async def submit(
452
+ def submit(
483
453
  self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
484
454
  ) -> Any:
485
455
  """On idle actor, submit job and return future."""
486
456
  # Remove idle actor from pool
487
- actor = await self.pool.get()
457
+ actor = self.pool.pop()
488
458
  # Submit job to actor
489
459
  app_fn, mssg, cid, context = job
490
460
  future = actor_fn(actor, app_fn, mssg, cid, context)
@@ -493,14 +463,18 @@ class BasicActorPool:
493
463
  self._future_to_actor[future] = actor
494
464
  return future
495
465
 
496
- async def fetch_result_and_return_actor_to_pool(
466
+ def add_actor_back_to_pool(self, future: Any) -> None:
467
+ """Add actor assigned to run future back into the pool."""
468
+ actor = self._future_to_actor.pop(future)
469
+ self.pool.append(actor)
470
+
471
+ def fetch_result_and_return_actor_to_pool(
497
472
  self, future: Any
498
473
  ) -> Tuple[Message, Context]:
499
474
  """Pull result given a future and add actor back to pool."""
500
- # Get actor that ran job
501
- actor = self._future_to_actor.pop(future)
502
- await self.pool.put(actor)
503
475
  # Retrieve result for object store
504
476
  # ray.get(future) blocks until the result is ready
505
- _, out_mssg, updated_context = await future
477
+ _, out_mssg, updated_context = ray.get(future)
478
+ # Get actor that ran job
479
+ self.add_actor_back_to_pool(future)
506
480
  return out_mssg, updated_context