flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237)
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/client/mod/secure_aggregation/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/client/mod/secure_aggregation/secaggplus_mod.py CHANGED
@@ -187,7 +187,7 @@ def secaggplus_mod(
187
187
 
188
188
  # Return message
189
189
  out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
190
- return msg.create_reply(out_content, ttl="")
190
+ return msg.create_reply(out_content)
191
191
 
192
192
 
193
193
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
flwr/client/mod/utils.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/client/node_state.py CHANGED
@@ -15,27 +15,72 @@
15
15
  """Node state."""
16
16
 
17
17
 
18
- from typing import Any, Dict
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ from typing import Dict, Optional
19
21
 
20
22
  from flwr.common import Context, RecordSet
23
+ from flwr.common.config import get_fused_config, get_fused_config_from_dir
24
+ from flwr.common.typing import Run, UserConfig
25
+
26
+
27
+ @dataclass()
28
+ class RunInfo:
29
+ """Contains the Context and initial run_config of a Run."""
30
+
31
+ context: Context
32
+ initial_run_config: UserConfig
21
33
 
22
34
 
23
35
  class NodeState:
24
36
  """State of a node where client nodes execute runs."""
25
37
 
26
- def __init__(self) -> None:
27
- self._meta: Dict[str, Any] = {} # holds metadata about the node
28
- self.run_contexts: Dict[int, Context] = {}
38
+ def __init__(
39
+ self,
40
+ node_id: int,
41
+ node_config: UserConfig,
42
+ ) -> None:
43
+ self.node_id = node_id
44
+ self.node_config = node_config
45
+ self.run_infos: Dict[int, RunInfo] = {}
29
46
 
30
- def register_context(self, run_id: int) -> None:
47
+ def register_context(
48
+ self,
49
+ run_id: int,
50
+ run: Optional[Run] = None,
51
+ flwr_path: Optional[Path] = None,
52
+ app_dir: Optional[str] = None,
53
+ ) -> None:
31
54
  """Register new run context for this node."""
32
- if run_id not in self.run_contexts:
33
- self.run_contexts[run_id] = Context(state=RecordSet())
55
+ if run_id not in self.run_infos:
56
+ initial_run_config = {}
57
+ if app_dir:
58
+ # Load from app directory
59
+ app_path = Path(app_dir)
60
+ if app_path.is_dir():
61
+ override_config = run.override_config if run else {}
62
+ initial_run_config = get_fused_config_from_dir(
63
+ app_path, override_config
64
+ )
65
+ else:
66
+ raise ValueError("The specified `app_dir` must be a directory.")
67
+ else:
68
+ # Load from .fab
69
+ initial_run_config = get_fused_config(run, flwr_path) if run else {}
70
+ self.run_infos[run_id] = RunInfo(
71
+ initial_run_config=initial_run_config,
72
+ context=Context(
73
+ node_id=self.node_id,
74
+ node_config=self.node_config,
75
+ state=RecordSet(),
76
+ run_config=initial_run_config.copy(),
77
+ ),
78
+ )
34
79
 
35
80
  def retrieve_context(self, run_id: int) -> Context:
36
81
  """Get run context given a run_id."""
37
- if run_id in self.run_contexts:
38
- return self.run_contexts[run_id]
82
+ if run_id in self.run_infos:
83
+ return self.run_infos[run_id].context
39
84
 
40
85
  raise RuntimeError(
41
86
  f"Context for run_id={run_id} doesn't exist."
@@ -45,4 +90,9 @@ class NodeState:
45
90
 
46
91
  def update_context(self, run_id: int, context: Context) -> None:
47
92
  """Update run context."""
48
- self.run_contexts[run_id] = context
93
+ if context.run_config != self.run_infos[run_id].initial_run_config:
94
+ raise ValueError(
95
+ "The `run_config` field of the `Context` object cannot be "
96
+ f"modified (run_id: {run_id})."
97
+ )
98
+ self.run_infos[run_id].context = context
flwr/client/node_state_tests.py CHANGED
@@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
41
41
  expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
42
42
 
43
43
  # NodeState
44
- node_state = NodeState()
44
+ node_state = NodeState(node_id=0, node_config={})
45
45
 
46
46
  for task in tasks:
47
47
  run_id = task.run_id
@@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
59
59
  node_state.update_context(run_id=run_id, context=updated_state)
60
60
 
61
61
  # Verify values
62
- for run_id, context in node_state.run_contexts.items():
62
+ for run_id, run_info in node_state.run_infos.items():
63
63
  assert (
64
- context.state.configs_records["counter"]["count"] == expected_values[run_id]
64
+ run_info.context.state.configs_records["counter"]["count"]
65
+ == expected_values[run_id]
65
66
  )
flwr/client/rest_client/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/client/rest_client/connection.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,30 +15,51 @@
15
15
  """Contextmanager for a REST request-response channel to the Flower server."""
16
16
 
17
17
 
18
+ import random
18
19
  import sys
20
+ import threading
19
21
  from contextlib import contextmanager
20
22
  from copy import copy
21
23
  from logging import ERROR, INFO, WARN
22
- from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast
24
+ from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union
23
25
 
26
+ from cryptography.hazmat.primitives.asymmetric import ec
27
+ from google.protobuf.message import Message as GrpcMessage
28
+
29
+ from flwr.client.heartbeat import start_ping_loop
24
30
  from flwr.client.message_handler.message_handler import validate_out_message
25
31
  from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
26
32
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
- from flwr.common.constant import MISSING_EXTRA_REST
33
+ from flwr.common.constant import (
34
+ MISSING_EXTRA_REST,
35
+ PING_BASE_MULTIPLIER,
36
+ PING_CALL_TIMEOUT,
37
+ PING_DEFAULT_INTERVAL,
38
+ PING_RANDOM_RANGE,
39
+ )
28
40
  from flwr.common.logger import log
29
41
  from flwr.common.message import Message, Metadata
30
42
  from flwr.common.retry_invoker import RetryInvoker
31
- from flwr.common.serde import message_from_taskins, message_to_taskres
43
+ from flwr.common.serde import (
44
+ message_from_taskins,
45
+ message_to_taskres,
46
+ user_config_from_proto,
47
+ )
48
+ from flwr.common.typing import Fab, Run
32
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
50
  CreateNodeRequest,
34
51
  CreateNodeResponse,
35
52
  DeleteNodeRequest,
53
+ DeleteNodeResponse,
54
+ PingRequest,
55
+ PingResponse,
36
56
  PullTaskInsRequest,
37
57
  PullTaskInsResponse,
38
58
  PushTaskResRequest,
39
59
  PushTaskResResponse,
40
60
  )
41
61
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
62
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
42
63
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
43
64
 
44
65
  try:
@@ -47,19 +68,18 @@ except ModuleNotFoundError:
47
68
  sys.exit(MISSING_EXTRA_REST)
48
69
 
49
70
 
50
- KEY_NODE = "node"
51
- KEY_METADATA = "in_message_metadata"
52
-
53
-
54
71
  PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
55
72
  PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
56
73
  PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
57
74
  PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
75
+ PATH_PING: str = "api/v0/fleet/ping"
76
+ PATH_GET_RUN: str = "/api/v0/fleet/get-run"
77
+
78
+ T = TypeVar("T", bound=GrpcMessage)
58
79
 
59
80
 
60
81
  @contextmanager
61
- # pylint: disable-next=too-many-statements
62
- def http_request_response(
82
+ def http_request_response( # pylint: disable=,R0913, R0914, R0915
63
83
  server_address: str,
64
84
  insecure: bool, # pylint: disable=unused-argument
65
85
  retry_invoker: RetryInvoker,
@@ -67,12 +87,17 @@ def http_request_response(
67
87
  root_certificates: Optional[
68
88
  Union[bytes, str]
69
89
  ] = None, # pylint: disable=unused-argument
90
+ authentication_keys: Optional[ # pylint: disable=unused-argument
91
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
92
+ ] = None,
70
93
  ) -> Iterator[
71
94
  Tuple[
72
95
  Callable[[], Optional[Message]],
73
96
  Callable[[Message], None],
97
+ Optional[Callable[[], Optional[int]]],
74
98
  Optional[Callable[[], None]],
75
- Optional[Callable[[], None]],
99
+ Optional[Callable[[int], Run]],
100
+ Optional[Callable[[str], Fab]],
76
101
  ]
77
102
  ]:
78
103
  """Primitives for request/response-based interaction with a server.
@@ -98,10 +123,16 @@ def http_request_response(
98
123
  Path of the root certificate. If provided, a secure
99
124
  connection using the certificates will be established to an SSL-enabled
100
125
  Flower server. Bytes won't work for the REST API.
126
+ authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
127
+ Client authentication is not supported for this transport type.
101
128
 
102
129
  Returns
103
130
  -------
104
- receive, send : Callable, Callable
131
+ receive : Callable
132
+ send : Callable
133
+ create_node : Optional[Callable]
134
+ delete_node : Optional[Callable]
135
+ get_run : Optional[Callable]
105
136
  """
106
137
  log(
107
138
  WARN,
@@ -126,144 +157,146 @@ def http_request_response(
126
157
  "For the REST API, the root certificates "
127
158
  "must be provided as a string path to the client.",
128
159
  )
160
+ if authentication_keys is not None:
161
+ log(ERROR, "Client authentication is not supported for this transport type.")
129
162
 
130
- # Necessary state to validate messages to be sent
131
- state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None}
132
-
133
- # Enable create_node and delete_node to store node
134
- node_store: Dict[str, Optional[Node]] = {KEY_NODE: None}
163
+ # Shared variables for inner functions
164
+ metadata: Optional[Metadata] = None
165
+ node: Optional[Node] = None
166
+ ping_thread: Optional[threading.Thread] = None
167
+ ping_stop_event = threading.Event()
135
168
 
136
169
  ###########################################################################
137
- # receive/send functions
170
+ # ping/create_node/delete_node/receive/send/get_run functions
138
171
  ###########################################################################
139
172
 
140
- def create_node() -> None:
141
- """Set create_node."""
142
- create_node_req_proto = CreateNodeRequest()
143
- create_node_req_bytes: bytes = create_node_req_proto.SerializeToString()
144
-
145
- res = retry_invoker.invoke(
146
- requests.post,
147
- url=f"{base_url}/{PATH_CREATE_NODE}",
148
- headers={
149
- "Accept": "application/protobuf",
150
- "Content-Type": "application/protobuf",
151
- },
152
- data=create_node_req_bytes,
153
- verify=verify,
154
- timeout=None,
155
- )
173
+ def _request(
174
+ req: GrpcMessage, res_type: Type[T], api_path: str, retry: bool = True
175
+ ) -> Optional[T]:
176
+ # Serialize the request
177
+ req_bytes = req.SerializeToString()
178
+
179
+ # Send the request
180
+ def post() -> requests.Response:
181
+ return requests.post(
182
+ f"{base_url}/{api_path}",
183
+ data=req_bytes,
184
+ headers={
185
+ "Accept": "application/protobuf",
186
+ "Content-Type": "application/protobuf",
187
+ },
188
+ verify=verify,
189
+ timeout=None,
190
+ )
191
+
192
+ if retry:
193
+ res: requests.Response = retry_invoker.invoke(post)
194
+ else:
195
+ res = post()
156
196
 
157
197
  # Check status code and headers
158
198
  if res.status_code != 200:
159
- return
199
+ return None
160
200
  if "content-type" not in res.headers:
161
201
  log(
162
202
  WARN,
163
203
  "[Node] POST /%s: missing header `Content-Type`",
164
- PATH_PULL_TASK_INS,
204
+ api_path,
165
205
  )
166
- return
206
+ return None
167
207
  if res.headers["content-type"] != "application/protobuf":
168
208
  log(
169
209
  WARN,
170
210
  "[Node] POST /%s: header `Content-Type` has wrong value",
171
- PATH_PULL_TASK_INS,
211
+ api_path,
172
212
  )
173
- return
213
+ return None
174
214
 
175
215
  # Deserialize ProtoBuf from bytes
176
- create_node_response_proto = CreateNodeResponse()
177
- create_node_response_proto.ParseFromString(res.content)
178
- # pylint: disable-next=no-member
179
- node_store[KEY_NODE] = create_node_response_proto.node
216
+ grpc_res = res_type()
217
+ grpc_res.ParseFromString(res.content)
218
+ return grpc_res
219
+
220
+ def ping() -> None:
221
+ # Get Node
222
+ if node is None:
223
+ log(ERROR, "Node instance missing")
224
+ return
225
+
226
+ # Construct the ping request
227
+ req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)
228
+
229
+ # Send the request
230
+ res = _request(req, PingResponse, PATH_PING, retry=False)
231
+ if res is None:
232
+ return
233
+
234
+ # Check if success
235
+ if not res.success:
236
+ raise RuntimeError("Ping failed unexpectedly.")
237
+
238
+ # Wait
239
+ rd = random.uniform(*PING_RANDOM_RANGE)
240
+ next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
241
+ next_interval *= PING_BASE_MULTIPLIER + rd
242
+ if not ping_stop_event.is_set():
243
+ ping_stop_event.wait(next_interval)
244
+
245
+ def create_node() -> Optional[int]:
246
+ """Set create_node."""
247
+ req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
248
+
249
+ # Send the request
250
+ res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
251
+ if res is None:
252
+ return None
253
+
254
+ # Remember the node and the ping-loop thread
255
+ nonlocal node, ping_thread
256
+ node = res.node
257
+ ping_thread = start_ping_loop(ping, ping_stop_event)
258
+ return node.node_id
180
259
 
181
260
  def delete_node() -> None:
182
261
  """Set delete_node."""
183
- if node_store[KEY_NODE] is None:
262
+ nonlocal node
263
+ if node is None:
184
264
  log(ERROR, "Node instance missing")
185
265
  return
186
- node: Node = cast(Node, node_store[KEY_NODE])
187
- delete_node_req_proto = DeleteNodeRequest(node=node)
188
- delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString()
189
- res = retry_invoker.invoke(
190
- requests.post,
191
- url=f"{base_url}/{PATH_DELETE_NODE}",
192
- headers={
193
- "Accept": "application/protobuf",
194
- "Content-Type": "application/protobuf",
195
- },
196
- data=delete_node_req_req_bytes,
197
- verify=verify,
198
- timeout=None,
199
- )
200
266
 
201
- # Check status code and headers
202
- if res.status_code != 200:
203
- return
204
- if "content-type" not in res.headers:
205
- log(
206
- WARN,
207
- "[Node] POST /%s: missing header `Content-Type`",
208
- PATH_PULL_TASK_INS,
209
- )
267
+ # Stop the ping-loop thread
268
+ ping_stop_event.set()
269
+ if ping_thread is not None:
270
+ ping_thread.join()
271
+
272
+ # Send DeleteNode request
273
+ req = DeleteNodeRequest(node=node)
274
+
275
+ # Send the request
276
+ res = _request(req, DeleteNodeResponse, PATH_CREATE_NODE)
277
+ if res is None:
210
278
  return
211
- if res.headers["content-type"] != "application/protobuf":
212
- log(
213
- WARN,
214
- "[Node] POST /%s: header `Content-Type` has wrong value",
215
- PATH_PULL_TASK_INS,
216
- )
279
+
280
+ # Cleanup
281
+ node = None
217
282
 
218
283
  def receive() -> Optional[Message]:
219
284
  """Receive next task from server."""
220
285
  # Get Node
221
- if node_store[KEY_NODE] is None:
286
+ if node is None:
222
287
  log(ERROR, "Node instance missing")
223
288
  return None
224
- node: Node = cast(Node, node_store[KEY_NODE])
225
289
 
226
290
  # Request instructions (task) from server
227
- pull_task_ins_req_proto = PullTaskInsRequest(node=node)
228
- pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString()
229
-
230
- # Request instructions (task) from server
231
- res = retry_invoker.invoke(
232
- requests.post,
233
- url=f"{base_url}/{PATH_PULL_TASK_INS}",
234
- headers={
235
- "Accept": "application/protobuf",
236
- "Content-Type": "application/protobuf",
237
- },
238
- data=pull_task_ins_req_bytes,
239
- verify=verify,
240
- timeout=None,
241
- )
291
+ req = PullTaskInsRequest(node=node)
242
292
 
243
- # Check status code and headers
244
- if res.status_code != 200:
245
- return None
246
- if "content-type" not in res.headers:
247
- log(
248
- WARN,
249
- "[Node] POST /%s: missing header `Content-Type`",
250
- PATH_PULL_TASK_INS,
251
- )
252
- return None
253
- if res.headers["content-type"] != "application/protobuf":
254
- log(
255
- WARN,
256
- "[Node] POST /%s: header `Content-Type` has wrong value",
257
- PATH_PULL_TASK_INS,
258
- )
293
+ # Send the request
294
+ res = _request(req, PullTaskInsResponse, PATH_PULL_TASK_INS)
295
+ if res is None:
259
296
  return None
260
297
 
261
- # Deserialize ProtoBuf from bytes
262
- pull_task_ins_response_proto = PullTaskInsResponse()
263
- pull_task_ins_response_proto.ParseFromString(res.content)
264
-
265
298
  # Get the current TaskIns
266
- task_ins: Optional[TaskIns] = get_task_ins(pull_task_ins_response_proto)
299
+ task_ins: Optional[TaskIns] = get_task_ins(res)
267
300
 
268
301
  # Discard the current TaskIns if not valid
269
302
  if task_ins is not None and not (
@@ -273,86 +306,73 @@ def http_request_response(
273
306
  task_ins = None
274
307
 
275
308
  # Return the Message if available
309
+ nonlocal metadata
276
310
  message = None
277
- state[KEY_METADATA] = None
278
311
  if task_ins is not None:
279
312
  message = message_from_taskins(task_ins)
280
- state[KEY_METADATA] = copy(message.metadata)
313
+ metadata = copy(message.metadata)
281
314
  log(INFO, "[Node] POST /%s: success", PATH_PULL_TASK_INS)
282
315
  return message
283
316
 
284
317
  def send(message: Message) -> None:
285
318
  """Send task result back to server."""
286
319
  # Get Node
287
- if node_store[KEY_NODE] is None:
320
+ if node is None:
288
321
  log(ERROR, "Node instance missing")
289
322
  return
290
323
 
291
324
  # Get incoming message
292
- in_metadata = state[KEY_METADATA]
293
- if in_metadata is None:
325
+ nonlocal metadata
326
+ if metadata is None:
294
327
  log(ERROR, "No current message")
295
328
  return
296
329
 
297
330
  # Validate out message
298
- if not validate_out_message(message, in_metadata):
331
+ if not validate_out_message(message, metadata):
299
332
  log(ERROR, "Invalid out message")
300
333
  return
334
+ metadata = None
301
335
 
302
336
  # Construct TaskRes
303
337
  task_res = message_to_taskres(message)
304
338
 
305
339
  # Serialize ProtoBuf to bytes
306
- push_task_res_request_proto = PushTaskResRequest(task_res_list=[task_res])
307
- push_task_res_request_bytes: bytes = (
308
- push_task_res_request_proto.SerializeToString()
309
- )
340
+ req = PushTaskResRequest(task_res_list=[task_res])
310
341
 
311
- # Send ClientMessage to server
312
- res = retry_invoker.invoke(
313
- requests.post,
314
- url=f"{base_url}/{PATH_PUSH_TASK_RES}",
315
- headers={
316
- "Accept": "application/protobuf",
317
- "Content-Type": "application/protobuf",
318
- },
319
- data=push_task_res_request_bytes,
320
- verify=verify,
321
- timeout=None,
322
- )
323
-
324
- state[KEY_METADATA] = None
325
-
326
- # Check status code and headers
327
- if res.status_code != 200:
328
- return
329
- if "content-type" not in res.headers:
330
- log(
331
- WARN,
332
- "[Node] POST /%s: missing header `Content-Type`",
333
- PATH_PUSH_TASK_RES,
334
- )
335
- return
336
- if res.headers["content-type"] != "application/protobuf":
337
- log(
338
- WARN,
339
- "[Node] POST /%s: header `Content-Type` has wrong value",
340
- PATH_PUSH_TASK_RES,
341
- )
342
+ # Send the request
343
+ res = _request(req, PushTaskResResponse, PATH_PUSH_TASK_RES)
344
+ if res is None:
342
345
  return
343
346
 
344
- # Deserialize ProtoBuf from bytes
345
- push_task_res_response_proto = PushTaskResResponse()
346
- push_task_res_response_proto.ParseFromString(res.content)
347
347
  log(
348
348
  INFO,
349
349
  "[Node] POST /%s: success, created result %s",
350
350
  PATH_PUSH_TASK_RES,
351
- push_task_res_response_proto.results, # pylint: disable=no-member
351
+ res.results, # pylint: disable=no-member
352
352
  )
353
353
 
354
+ def get_run(run_id: int) -> Run:
355
+ # Construct the request
356
+ req = GetRunRequest(run_id=run_id)
357
+
358
+ # Send the request
359
+ res = _request(req, GetRunResponse, PATH_GET_RUN)
360
+ if res is None:
361
+ return Run(run_id, "", "", {})
362
+
363
+ return Run(
364
+ run_id,
365
+ res.run.fab_id,
366
+ res.run.fab_version,
367
+ user_config_from_proto(res.run.override_config),
368
+ )
369
+
370
+ def get_fab(fab_hash: str) -> Fab:
371
+ # Call FleetAPI
372
+ raise NotImplementedError
373
+
354
374
  try:
355
375
  # Yield methods
356
- yield (receive, send, create_node, delete_node)
376
+ yield (receive, send, create_node, delete_node, get_run, get_fab)
357
377
  except Exception as exc: # pylint: disable=broad-except
358
378
  log(ERROR, exc)
flwr/client/supernode/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower SuperNode."""
16
+
17
+
18
+ from .app import flwr_clientapp as flwr_clientapp
19
+ from .app import run_client_app as run_client_app
20
+ from .app import run_supernode as run_supernode
21
+
22
+ __all__ = [
23
+ "flwr_clientapp",
24
+ "run_client_app",
25
+ "run_supernode",
26
+ ]