flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237)
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,27 +15,40 @@
15
15
  """In-memory State implementation."""
16
16
 
17
17
 
18
- import os
19
18
  import threading
20
- from datetime import datetime, timedelta
19
+ import time
21
20
  from logging import ERROR
22
- from typing import Dict, List, Optional, Set
21
+ from typing import Dict, List, Optional, Set, Tuple
23
22
  from uuid import UUID, uuid4
24
23
 
25
24
  from flwr.common import log, now
25
+ from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
26
+ from flwr.common.typing import Run, UserConfig
26
27
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
27
28
  from flwr.server.superlink.state.state import State
28
29
  from flwr.server.utils import validate_task_ins_or_res
29
30
 
31
+ from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
30
32
 
31
- class InMemoryState(State):
33
+
34
+ class InMemoryState(State): # pylint: disable=R0902,R0904
32
35
  """In-memory State implementation."""
33
36
 
34
37
  def __init__(self) -> None:
35
- self.node_ids: Set[int] = set()
36
- self.run_ids: Set[int] = set()
38
+
39
+ # Map node_id to (online_until, ping_interval)
40
+ self.node_ids: Dict[int, Tuple[float, float]] = {}
41
+ self.public_key_to_node_id: Dict[bytes, int] = {}
42
+
43
+ # Map run_id to (fab_id, fab_version)
44
+ self.run_ids: Dict[int, Run] = {}
37
45
  self.task_ins_store: Dict[UUID, TaskIns] = {}
38
46
  self.task_res_store: Dict[UUID, TaskRes] = {}
47
+
48
+ self.client_public_keys: Set[bytes] = set()
49
+ self.server_public_key: Optional[bytes] = None
50
+ self.server_private_key: Optional[bytes] = None
51
+
39
52
  self.lock = threading.Lock()
40
53
 
41
54
  def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
@@ -50,15 +63,11 @@ class InMemoryState(State):
50
63
  log(ERROR, "`run_id` is invalid")
51
64
  return None
52
65
 
53
- # Create task_id, created_at and ttl
66
+ # Create task_id
54
67
  task_id = uuid4()
55
- created_at: datetime = now()
56
- ttl: datetime = created_at + timedelta(hours=24)
57
68
 
58
69
  # Store TaskIns
59
70
  task_ins.task_id = str(task_id)
60
- task_ins.task.created_at = created_at.isoformat()
61
- task_ins.task.ttl = ttl.isoformat()
62
71
  with self.lock:
63
72
  self.task_ins_store[task_id] = task_ins
64
73
 
@@ -113,15 +122,11 @@ class InMemoryState(State):
113
122
  log(ERROR, "`run_id` is invalid")
114
123
  return None
115
124
 
116
- # Create task_id, created_at and ttl
125
+ # Create task_id
117
126
  task_id = uuid4()
118
- created_at: datetime = now()
119
- ttl: datetime = created_at + timedelta(hours=24)
120
127
 
121
128
  # Store TaskRes
122
129
  task_res.task_id = str(task_id)
123
- task_res.task.created_at = created_at.isoformat()
124
- task_res.task.ttl = ttl.isoformat()
125
130
  with self.lock:
126
131
  self.task_res_store[task_id] = task_res
127
132
 
@@ -136,14 +141,31 @@ class InMemoryState(State):
136
141
  with self.lock:
137
142
  # Find TaskRes that were not delivered yet
138
143
  task_res_list: List[TaskRes] = []
144
+ replied_task_ids: Set[UUID] = set()
139
145
  for _, task_res in self.task_res_store.items():
140
- if (
141
- UUID(task_res.task.ancestry[0]) in task_ids
142
- and task_res.task.delivered_at == ""
143
- ):
146
+ reply_to = UUID(task_res.task.ancestry[0])
147
+ if reply_to in task_ids and task_res.task.delivered_at == "":
144
148
  task_res_list.append(task_res)
149
+ replied_task_ids.add(reply_to)
150
+ if limit and len(task_res_list) == limit:
151
+ break
152
+
153
+ # Check if the node is offline
154
+ for task_id in task_ids - replied_task_ids:
145
155
  if limit and len(task_res_list) == limit:
146
156
  break
157
+ task_ins = self.task_ins_store.get(task_id)
158
+ if task_ins is None:
159
+ continue
160
+ node_id = task_ins.task.consumer.node_id
161
+ online_until, _ = self.node_ids[node_id]
162
+ # Generate a TaskRes containing an error reply if the node is offline.
163
+ if online_until < time.time():
164
+ err_taskres = make_node_unavailable_taskres(
165
+ ref_taskins=task_ins,
166
+ )
167
+ self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
168
+ task_res_list.append(err_taskres)
147
169
 
148
170
  # Mark all of them as delivered
149
171
  delivered_at = now().isoformat()
@@ -189,22 +211,47 @@ class InMemoryState(State):
189
211
  """
190
212
  return len(self.task_res_store)
191
213
 
192
- def create_node(self) -> int:
214
+ def create_node(
215
+ self, ping_interval: float, public_key: Optional[bytes] = None
216
+ ) -> int:
193
217
  """Create, store in state, and return `node_id`."""
194
218
  # Sample a random int64 as node_id
195
- node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
219
+ node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
196
220
 
197
- if node_id not in self.node_ids:
198
- self.node_ids.add(node_id)
221
+ with self.lock:
222
+ if node_id in self.node_ids:
223
+ log(ERROR, "Unexpected node registration failure.")
224
+ return 0
225
+
226
+ if public_key is not None:
227
+ if (
228
+ public_key in self.public_key_to_node_id
229
+ or node_id in self.public_key_to_node_id.values()
230
+ ):
231
+ log(ERROR, "Unexpected node registration failure.")
232
+ return 0
233
+
234
+ self.public_key_to_node_id[public_key] = node_id
235
+
236
+ self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
199
237
  return node_id
200
- log(ERROR, "Unexpected node registration failure.")
201
- return 0
202
238
 
203
- def delete_node(self, node_id: int) -> None:
239
+ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
204
240
  """Delete a client node."""
205
- if node_id not in self.node_ids:
206
- raise ValueError(f"Node {node_id} not found")
207
- self.node_ids.remove(node_id)
241
+ with self.lock:
242
+ if node_id not in self.node_ids:
243
+ raise ValueError(f"Node {node_id} not found")
244
+
245
+ if public_key is not None:
246
+ if (
247
+ public_key not in self.public_key_to_node_id
248
+ or node_id not in self.public_key_to_node_id.values()
249
+ ):
250
+ raise ValueError("Public key or node_id not found")
251
+
252
+ del self.public_key_to_node_id[public_key]
253
+
254
+ del self.node_ids[node_id]
208
255
 
209
256
  def get_nodes(self, run_id: int) -> Set[int]:
210
257
  """Return all available client nodes.
@@ -214,17 +261,87 @@ class InMemoryState(State):
214
261
  If the provided `run_id` does not exist or has no matching nodes,
215
262
  an empty `Set` MUST be returned.
216
263
  """
217
- if run_id not in self.run_ids:
218
- return set()
219
- return self.node_ids
220
-
221
- def create_run(self) -> int:
222
- """Create one run."""
264
+ with self.lock:
265
+ if run_id not in self.run_ids:
266
+ return set()
267
+ current_time = time.time()
268
+ return {
269
+ node_id
270
+ for node_id, (online_until, _) in self.node_ids.items()
271
+ if online_until > current_time
272
+ }
273
+
274
+ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
275
+ """Retrieve stored `node_id` filtered by `client_public_keys`."""
276
+ return self.public_key_to_node_id.get(client_public_key)
277
+
278
+ def create_run(
279
+ self,
280
+ fab_id: str,
281
+ fab_version: str,
282
+ override_config: UserConfig,
283
+ ) -> int:
284
+ """Create a new run for the specified `fab_id` and `fab_version`."""
223
285
  # Sample a random int64 as run_id
224
- run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
225
-
226
- if run_id not in self.run_ids:
227
- self.run_ids.add(run_id)
228
- return run_id
286
+ with self.lock:
287
+ run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
288
+
289
+ if run_id not in self.run_ids:
290
+ self.run_ids[run_id] = Run(
291
+ run_id=run_id,
292
+ fab_id=fab_id,
293
+ fab_version=fab_version,
294
+ override_config=override_config,
295
+ )
296
+ return run_id
229
297
  log(ERROR, "Unexpected run creation failure.")
230
298
  return 0
299
+
300
+ def store_server_private_public_key(
301
+ self, private_key: bytes, public_key: bytes
302
+ ) -> None:
303
+ """Store `server_private_key` and `server_public_key` in state."""
304
+ with self.lock:
305
+ if self.server_private_key is None and self.server_public_key is None:
306
+ self.server_private_key = private_key
307
+ self.server_public_key = public_key
308
+ else:
309
+ raise RuntimeError("Server private and public key already set")
310
+
311
+ def get_server_private_key(self) -> Optional[bytes]:
312
+ """Retrieve `server_private_key` in urlsafe bytes."""
313
+ return self.server_private_key
314
+
315
+ def get_server_public_key(self) -> Optional[bytes]:
316
+ """Retrieve `server_public_key` in urlsafe bytes."""
317
+ return self.server_public_key
318
+
319
+ def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
320
+ """Store a set of `client_public_keys` in state."""
321
+ with self.lock:
322
+ self.client_public_keys = public_keys
323
+
324
+ def store_client_public_key(self, public_key: bytes) -> None:
325
+ """Store a `client_public_key` in state."""
326
+ with self.lock:
327
+ self.client_public_keys.add(public_key)
328
+
329
+ def get_client_public_keys(self) -> Set[bytes]:
330
+ """Retrieve all currently stored `client_public_keys` as a set."""
331
+ return self.client_public_keys
332
+
333
+ def get_run(self, run_id: int) -> Optional[Run]:
334
+ """Retrieve information about the run with the specified `run_id`."""
335
+ with self.lock:
336
+ if run_id not in self.run_ids:
337
+ log(ERROR, "`run_id` is invalid")
338
+ return None
339
+ return self.run_ids[run_id]
340
+
341
+ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
342
+ """Acknowledge a ping received from a node, serving as a heartbeat."""
343
+ with self.lock:
344
+ if node_id in self.node_ids:
345
+ self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
346
+ return True
347
+ return False