flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's registry page for more details.

Files changed (237)
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,31 +15,57 @@
15
15
  """SQLite based implemenation of server state."""
16
16
 
17
17
 
18
- import os
18
+ import json
19
19
  import re
20
20
  import sqlite3
21
- from datetime import datetime, timedelta
21
+ import time
22
22
  from logging import DEBUG, ERROR
23
- from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
23
+ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
24
24
  from uuid import UUID, uuid4
25
25
 
26
26
  from flwr.common import log, now
27
+ from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
28
+ from flwr.common.typing import Run, UserConfig
27
29
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
28
30
  from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
29
31
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
30
32
  from flwr.server.utils.validator import validate_task_ins_or_res
31
33
 
32
34
  from .state import State
35
+ from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
33
36
 
34
37
  SQL_CREATE_TABLE_NODE = """
35
38
  CREATE TABLE IF NOT EXISTS node(
36
- node_id INTEGER UNIQUE
39
+ node_id INTEGER UNIQUE,
40
+ online_until REAL,
41
+ ping_interval REAL,
42
+ public_key BLOB
37
43
  );
38
44
  """
39
45
 
46
+ SQL_CREATE_TABLE_CREDENTIAL = """
47
+ CREATE TABLE IF NOT EXISTS credential(
48
+ private_key BLOB PRIMARY KEY,
49
+ public_key BLOB
50
+ );
51
+ """
52
+
53
+ SQL_CREATE_TABLE_PUBLIC_KEY = """
54
+ CREATE TABLE IF NOT EXISTS public_key(
55
+ public_key BLOB UNIQUE
56
+ );
57
+ """
58
+
59
+ SQL_CREATE_INDEX_ONLINE_UNTIL = """
60
+ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
61
+ """
62
+
40
63
  SQL_CREATE_TABLE_RUN = """
41
64
  CREATE TABLE IF NOT EXISTS run(
42
- run_id INTEGER UNIQUE
65
+ run_id INTEGER UNIQUE,
66
+ fab_id TEXT,
67
+ fab_version TEXT,
68
+ override_config TEXT
43
69
  );
44
70
  """
45
71
 
@@ -52,9 +78,10 @@ CREATE TABLE IF NOT EXISTS task_ins(
52
78
  producer_node_id INTEGER,
53
79
  consumer_anonymous BOOLEAN,
54
80
  consumer_node_id INTEGER,
55
- created_at TEXT,
81
+ created_at REAL,
56
82
  delivered_at TEXT,
57
- ttl TEXT,
83
+ pushed_at REAL,
84
+ ttl REAL,
58
85
  ancestry TEXT,
59
86
  task_type TEXT,
60
87
  recordset BLOB,
@@ -62,7 +89,6 @@ CREATE TABLE IF NOT EXISTS task_ins(
62
89
  );
63
90
  """
64
91
 
65
-
66
92
  SQL_CREATE_TABLE_TASK_RES = """
67
93
  CREATE TABLE IF NOT EXISTS task_res(
68
94
  task_id TEXT UNIQUE,
@@ -72,9 +98,10 @@ CREATE TABLE IF NOT EXISTS task_res(
72
98
  producer_node_id INTEGER,
73
99
  consumer_anonymous BOOLEAN,
74
100
  consumer_node_id INTEGER,
75
- created_at TEXT,
101
+ created_at REAL,
76
102
  delivered_at TEXT,
77
- ttl TEXT,
103
+ pushed_at REAL,
104
+ ttl REAL,
78
105
  ancestry TEXT,
79
106
  task_type TEXT,
80
107
  recordset BLOB,
@@ -82,10 +109,10 @@ CREATE TABLE IF NOT EXISTS task_res(
82
109
  );
83
110
  """
84
111
 
85
- DictOrTuple = Union[Tuple[Any], Dict[str, Any]]
112
+ DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]]
86
113
 
87
114
 
88
- class SqliteState(State):
115
+ class SqliteState(State): # pylint: disable=R0904
89
116
  """SQLite-based state implementation."""
90
117
 
91
118
  def __init__(
@@ -123,6 +150,9 @@ class SqliteState(State):
123
150
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
124
151
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
125
152
  cur.execute(SQL_CREATE_TABLE_NODE)
153
+ cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
154
+ cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
155
+ cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
126
156
  res = cur.execute("SELECT name FROM sqlite_schema;")
127
157
 
128
158
  return res.fetchall()
@@ -130,7 +160,7 @@ class SqliteState(State):
130
160
  def query(
131
161
  self,
132
162
  query: str,
133
- data: Optional[Union[List[DictOrTuple], DictOrTuple]] = None,
163
+ data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
134
164
  ) -> List[Dict[str, Any]]:
135
165
  """Execute a SQL query."""
136
166
  if self.conn is None:
@@ -185,15 +215,11 @@ class SqliteState(State):
185
215
  log(ERROR, errors)
186
216
  return None
187
217
 
188
- # Create task_id, created_at and ttl
218
+ # Create task_id
189
219
  task_id = uuid4()
190
- created_at: datetime = now()
191
- ttl: datetime = created_at + timedelta(hours=24)
192
220
 
193
221
  # Store TaskIns
194
222
  task_ins.task_id = str(task_id)
195
- task_ins.task.created_at = created_at.isoformat()
196
- task_ins.task.ttl = ttl.isoformat()
197
223
  data = (task_ins_to_dict(task_ins),)
198
224
  columns = ", ".join([f":{key}" for key in data[0]])
199
225
  query = f"INSERT INTO task_ins VALUES({columns});"
@@ -320,15 +346,11 @@ class SqliteState(State):
320
346
  log(ERROR, errors)
321
347
  return None
322
348
 
323
- # Create task_id, created_at and ttl
349
+ # Create task_id
324
350
  task_id = uuid4()
325
- created_at: datetime = now()
326
- ttl: datetime = created_at + timedelta(hours=24)
327
351
 
328
352
  # Store TaskIns
329
353
  task_res.task_id = str(task_id)
330
- task_res.task.created_at = created_at.isoformat()
331
- task_res.task.ttl = ttl.isoformat()
332
354
  data = (task_res_to_dict(task_res),)
333
355
  columns = ", ".join([f":{key}" for key in data[0]])
334
356
  query = f"INSERT INTO task_res VALUES({columns});"
@@ -343,6 +365,7 @@ class SqliteState(State):
343
365
 
344
366
  return task_id
345
367
 
368
+ # pylint: disable-next=R0914
346
369
  def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
347
370
  """Get TaskRes for task_ids.
348
371
 
@@ -373,7 +396,7 @@ class SqliteState(State):
373
396
  AND delivered_at = ""
374
397
  """
375
398
 
376
- data: Dict[str, Union[str, int]] = {}
399
+ data: Dict[str, Union[str, float, int]] = {}
377
400
 
378
401
  if limit is not None:
379
402
  query += " LIMIT :limit"
@@ -407,6 +430,54 @@ class SqliteState(State):
407
430
  rows = self.query(query, data)
408
431
 
409
432
  result = [dict_to_task_res(row) for row in rows]
433
+
434
+ # 1. Query: Fetch consumer_node_id of remaining task_ids
435
+ # Assume the ancestry field only contains one element
436
+ data.clear()
437
+ replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
438
+ remaining_task_ids = task_ids - replied_task_ids
439
+ placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
440
+ query = f"""
441
+ SELECT consumer_node_id
442
+ FROM task_ins
443
+ WHERE task_id IN ({placeholders});
444
+ """
445
+ for index, task_id in enumerate(remaining_task_ids):
446
+ data[f"id_{index}"] = str(task_id)
447
+ node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]
448
+
449
+ # 2. Query: Select offline nodes
450
+ placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
451
+ query = f"""
452
+ SELECT node_id
453
+ FROM node
454
+ WHERE node_id IN ({placeholders})
455
+ AND online_until < :time;
456
+ """
457
+ data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
458
+ data["time"] = time.time()
459
+ offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
460
+
461
+ # 3. Query: Select TaskIns for offline nodes
462
+ placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
463
+ query = f"""
464
+ SELECT *
465
+ FROM task_ins
466
+ WHERE consumer_node_id IN ({placeholders});
467
+ """
468
+ data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
469
+ task_ins_rows = self.query(query, data)
470
+
471
+ # Make TaskRes containing node unavailabe error
472
+ for row in task_ins_rows:
473
+ if limit and len(result) == limit:
474
+ break
475
+ task_ins = dict_to_task_ins(row)
476
+ err_taskres = make_node_unavailable_taskres(
477
+ ref_taskins=task_ins,
478
+ )
479
+ result.append(err_taskres)
480
+
410
481
  return result
411
482
 
412
483
  def num_task_ins(self) -> int:
@@ -467,23 +538,54 @@ class SqliteState(State):
467
538
 
468
539
  return None
469
540
 
470
- def create_node(self) -> int:
541
+ def create_node(
542
+ self, ping_interval: float, public_key: Optional[bytes] = None
543
+ ) -> int:
471
544
  """Create, store in state, and return `node_id`."""
472
545
  # Sample a random int64 as node_id
473
- node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
546
+ node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
547
+
548
+ query = "SELECT node_id FROM node WHERE public_key = :public_key;"
549
+ row = self.query(query, {"public_key": public_key})
550
+
551
+ if len(row) > 0:
552
+ log(ERROR, "Unexpected node registration failure.")
553
+ return 0
554
+
555
+ query = (
556
+ "INSERT INTO node "
557
+ "(node_id, online_until, ping_interval, public_key) "
558
+ "VALUES (?, ?, ?, ?)"
559
+ )
474
560
 
475
- query = "INSERT INTO node VALUES(:node_id);"
476
561
  try:
477
- self.query(query, {"node_id": node_id})
562
+ self.query(
563
+ query, (node_id, time.time() + ping_interval, ping_interval, public_key)
564
+ )
478
565
  except sqlite3.IntegrityError:
479
566
  log(ERROR, "Unexpected node registration failure.")
480
567
  return 0
481
568
  return node_id
482
569
 
483
- def delete_node(self, node_id: int) -> None:
570
+ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
484
571
  """Delete a client node."""
485
- query = "DELETE FROM node WHERE node_id = :node_id;"
486
- self.query(query, {"node_id": node_id})
572
+ query = "DELETE FROM node WHERE node_id = ?"
573
+ params = (node_id,)
574
+
575
+ if public_key is not None:
576
+ query += " AND public_key = ?"
577
+ params += (public_key,) # type: ignore
578
+
579
+ if self.conn is None:
580
+ raise AttributeError("State is not initialized.")
581
+
582
+ try:
583
+ with self.conn:
584
+ rows = self.conn.execute(query, params)
585
+ if rows.rowcount < 1:
586
+ raise ValueError("Public key or node_id not found")
587
+ except KeyError as exc:
588
+ log(ERROR, {"query": query, "data": params, "exception": exc})
487
589
 
488
590
  def get_nodes(self, run_id: int) -> Set[int]:
489
591
  """Retrieve all currently stored node IDs as a set.
@@ -499,26 +601,124 @@ class SqliteState(State):
499
601
  return set()
500
602
 
501
603
  # Get nodes
502
- query = "SELECT * FROM node;"
503
- rows = self.query(query)
604
+ query = "SELECT node_id FROM node WHERE online_until > ?;"
605
+ rows = self.query(query, (time.time(),))
504
606
  result: Set[int] = {row["node_id"] for row in rows}
505
607
  return result
506
608
 
507
- def create_run(self) -> int:
508
- """Create one run and store it in state."""
609
+ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
610
+ """Retrieve stored `node_id` filtered by `client_public_keys`."""
611
+ query = "SELECT node_id FROM node WHERE public_key = :public_key;"
612
+ row = self.query(query, {"public_key": client_public_key})
613
+ if len(row) > 0:
614
+ node_id: int = row[0]["node_id"]
615
+ return node_id
616
+ return None
617
+
618
+ def create_run(
619
+ self,
620
+ fab_id: str,
621
+ fab_version: str,
622
+ override_config: UserConfig,
623
+ ) -> int:
624
+ """Create a new run for the specified `fab_id` and `fab_version`."""
509
625
  # Sample a random int64 as run_id
510
- run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
626
+ run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
511
627
 
512
628
  # Check conflicts
513
629
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
514
630
  # If run_id does not exist
515
631
  if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
516
- query = "INSERT INTO run VALUES(:run_id);"
517
- self.query(query, {"run_id": run_id})
632
+ query = (
633
+ "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
634
+ "VALUES (?, ?, ?, ?);"
635
+ )
636
+ self.query(
637
+ query, (run_id, fab_id, fab_version, json.dumps(override_config))
638
+ )
518
639
  return run_id
519
640
  log(ERROR, "Unexpected run creation failure.")
520
641
  return 0
521
642
 
643
+ def store_server_private_public_key(
644
+ self, private_key: bytes, public_key: bytes
645
+ ) -> None:
646
+ """Store `server_private_key` and `server_public_key` in state."""
647
+ query = "SELECT COUNT(*) FROM credential"
648
+ count = self.query(query)[0]["COUNT(*)"]
649
+ if count < 1:
650
+ query = (
651
+ "INSERT OR REPLACE INTO credential (private_key, public_key) "
652
+ "VALUES (:private_key, :public_key)"
653
+ )
654
+ self.query(query, {"private_key": private_key, "public_key": public_key})
655
+ else:
656
+ raise RuntimeError("Server private and public key already set")
657
+
658
+ def get_server_private_key(self) -> Optional[bytes]:
659
+ """Retrieve `server_private_key` in urlsafe bytes."""
660
+ query = "SELECT private_key FROM credential"
661
+ rows = self.query(query)
662
+ try:
663
+ private_key: Optional[bytes] = rows[0]["private_key"]
664
+ except IndexError:
665
+ private_key = None
666
+ return private_key
667
+
668
+ def get_server_public_key(self) -> Optional[bytes]:
669
+ """Retrieve `server_public_key` in urlsafe bytes."""
670
+ query = "SELECT public_key FROM credential"
671
+ rows = self.query(query)
672
+ try:
673
+ public_key: Optional[bytes] = rows[0]["public_key"]
674
+ except IndexError:
675
+ public_key = None
676
+ return public_key
677
+
678
+ def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
679
+ """Store a set of `client_public_keys` in state."""
680
+ query = "INSERT INTO public_key (public_key) VALUES (?)"
681
+ data = [(key,) for key in public_keys]
682
+ self.query(query, data)
683
+
684
+ def store_client_public_key(self, public_key: bytes) -> None:
685
+ """Store a `client_public_key` in state."""
686
+ query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
687
+ self.query(query, {"public_key": public_key})
688
+
689
+ def get_client_public_keys(self) -> Set[bytes]:
690
+ """Retrieve all currently stored `client_public_keys` as a set."""
691
+ query = "SELECT public_key FROM public_key"
692
+ rows = self.query(query)
693
+ result: Set[bytes] = {row["public_key"] for row in rows}
694
+ return result
695
+
696
+ def get_run(self, run_id: int) -> Optional[Run]:
697
+ """Retrieve information about the run with the specified `run_id`."""
698
+ query = "SELECT * FROM run WHERE run_id = ?;"
699
+ try:
700
+ row = self.query(query, (run_id,))[0]
701
+ return Run(
702
+ run_id=run_id,
703
+ fab_id=row["fab_id"],
704
+ fab_version=row["fab_version"],
705
+ override_config=json.loads(row["override_config"]),
706
+ )
707
+ except sqlite3.IntegrityError:
708
+ log(ERROR, "`run_id` does not exist.")
709
+ return None
710
+
711
+ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
712
+ """Acknowledge a ping received from a node, serving as a heartbeat."""
713
+ # Update `online_until` and `ping_interval` for the given `node_id`
714
+ query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?;"
715
+ try:
716
+ self.query(query, (time.time() + ping_interval, ping_interval, node_id))
717
+ return True
718
+ except sqlite3.IntegrityError:
719
+ log(ERROR, "`node_id` does not exist.")
720
+ return False
721
+
522
722
 
523
723
  def dict_factory(
524
724
  cursor: sqlite3.Cursor,
@@ -544,6 +744,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]:
544
744
  "consumer_node_id": task_msg.task.consumer.node_id,
545
745
  "created_at": task_msg.task.created_at,
546
746
  "delivered_at": task_msg.task.delivered_at,
747
+ "pushed_at": task_msg.task.pushed_at,
547
748
  "ttl": task_msg.task.ttl,
548
749
  "ancestry": ",".join(task_msg.task.ancestry),
549
750
  "task_type": task_msg.task.task_type,
@@ -564,6 +765,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]:
564
765
  "consumer_node_id": task_msg.task.consumer.node_id,
565
766
  "created_at": task_msg.task.created_at,
566
767
  "delivered_at": task_msg.task.delivered_at,
768
+ "pushed_at": task_msg.task.pushed_at,
567
769
  "ttl": task_msg.task.ttl,
568
770
  "ancestry": ",".join(task_msg.task.ancestry),
569
771
  "task_type": task_msg.task.task_type,
@@ -592,6 +794,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns:
592
794
  ),
593
795
  created_at=task_dict["created_at"],
594
796
  delivered_at=task_dict["delivered_at"],
797
+ pushed_at=task_dict["pushed_at"],
595
798
  ttl=task_dict["ttl"],
596
799
  ancestry=task_dict["ancestry"].split(","),
597
800
  task_type=task_dict["task_type"],
@@ -621,6 +824,7 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes:
621
824
  ),
622
825
  created_at=task_dict["created_at"],
623
826
  delivered_at=task_dict["delivered_at"],
827
+ pushed_at=task_dict["pushed_at"],
624
828
  ttl=task_dict["ttl"],
625
829
  ancestry=task_dict["ancestry"].split(","),
626
830
  task_type=task_dict["task_type"],
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,10 +19,11 @@ import abc
19
19
  from typing import List, Optional, Set
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common.typing import Run, UserConfig
22
23
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
23
24
 
24
25
 
25
- class State(abc.ABC):
26
+ class State(abc.ABC): # pylint: disable=R0904
26
27
  """Abstract State."""
27
28
 
28
29
  @abc.abstractmethod
@@ -132,11 +133,13 @@ class State(abc.ABC):
132
133
  """Delete all delivered TaskIns/TaskRes pairs."""
133
134
 
134
135
  @abc.abstractmethod
135
- def create_node(self) -> int:
136
+ def create_node(
137
+ self, ping_interval: float, public_key: Optional[bytes] = None
138
+ ) -> int:
136
139
  """Create, store in state, and return `node_id`."""
137
140
 
138
141
  @abc.abstractmethod
139
- def delete_node(self, node_id: int) -> None:
142
+ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
140
143
  """Remove `node_id` from state."""
141
144
 
142
145
  @abc.abstractmethod
@@ -150,5 +153,77 @@ class State(abc.ABC):
150
153
  """
151
154
 
152
155
  @abc.abstractmethod
153
- def create_run(self) -> int:
154
- """Create one run."""
156
+ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
157
+ """Retrieve stored `node_id` filtered by `client_public_keys`."""
158
+
159
+ @abc.abstractmethod
160
+ def create_run(
161
+ self,
162
+ fab_id: str,
163
+ fab_version: str,
164
+ override_config: UserConfig,
165
+ ) -> int:
166
+ """Create a new run for the specified `fab_id` and `fab_version`."""
167
+
168
+ @abc.abstractmethod
169
+ def get_run(self, run_id: int) -> Optional[Run]:
170
+ """Retrieve information about the run with the specified `run_id`.
171
+
172
+ Parameters
173
+ ----------
174
+ run_id : int
175
+ The identifier of the run.
176
+
177
+ Returns
178
+ -------
179
+ Optional[Run]
180
+ A dataclass instance containing three elements if `run_id` is valid:
181
+ - `run_id`: The identifier of the run, same as the specified `run_id`.
182
+ - `fab_id`: The identifier of the FAB used in the specified run.
183
+ - `fab_version`: The version of the FAB used in the specified run.
184
+ """
185
+
186
+ @abc.abstractmethod
187
+ def store_server_private_public_key(
188
+ self, private_key: bytes, public_key: bytes
189
+ ) -> None:
190
+ """Store `server_private_key` and `server_public_key` in state."""
191
+
192
+ @abc.abstractmethod
193
+ def get_server_private_key(self) -> Optional[bytes]:
194
+ """Retrieve `server_private_key` in urlsafe bytes."""
195
+
196
+ @abc.abstractmethod
197
+ def get_server_public_key(self) -> Optional[bytes]:
198
+ """Retrieve `server_public_key` in urlsafe bytes."""
199
+
200
+ @abc.abstractmethod
201
+ def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
202
+ """Store a set of `client_public_keys` in state."""
203
+
204
+ @abc.abstractmethod
205
+ def store_client_public_key(self, public_key: bytes) -> None:
206
+ """Store a `client_public_key` in state."""
207
+
208
+ @abc.abstractmethod
209
+ def get_client_public_keys(self) -> Set[bytes]:
210
+ """Retrieve all currently stored `client_public_keys` as a set."""
211
+
212
+ @abc.abstractmethod
213
+ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
214
+ """Acknowledge a ping received from a node, serving as a heartbeat.
215
+
216
+ Parameters
217
+ ----------
218
+ node_id : int
219
+ The `node_id` from which the ping was received.
220
+ ping_interval : float
221
+ The interval (in seconds) from the current timestamp within which the next
222
+ ping from this node must be received. This acts as a hard deadline to ensure
223
+ an accurate assessment of the node's availability.
224
+
225
+ Returns
226
+ -------
227
+ is_acknowledged : bool
228
+ True if the ping is successfully acknowledged; otherwise, False.
229
+ """
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -26,7 +26,16 @@ from .state import State
26
26
 
27
27
 
28
28
  class StateFactory:
29
- """Factory class that creates State instances."""
29
+ """Factory class that creates State instances.
30
+
31
+ Parameters
32
+ ----------
33
+ database : str
34
+ A string representing the path to the database file that will be opened.
35
+ Note that passing ':memory:' will open a connection to a database that is
36
+ in RAM, instead of on disk. For more information on special in-memory
37
+ databases, please refer to https://sqlite.org/inmemorydb.html.
38
+ """
30
39
 
31
40
  def __init__(self, database: str) -> None:
32
41
  self.database = database
@@ -0,0 +1,62 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility functions for State."""
16
+
17
+
18
+ import time
19
+ from logging import ERROR
20
+ from os import urandom
21
+ from uuid import uuid4
22
+
23
+ from flwr.common import log
24
+ from flwr.common.constant import ErrorCode
25
+ from flwr.proto.error_pb2 import Error # pylint: disable=E0611
26
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
27
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
28
+
29
+ NODE_UNAVAILABLE_ERROR_REASON = (
30
+ "Error: Node Unavailable - The destination node is currently unavailable. "
31
+ "It exceeds the time limit specified in its last ping."
32
+ )
33
+
34
+
35
+ def generate_rand_int_from_bytes(num_bytes: int) -> int:
36
+ """Generate a random `num_bytes` integer."""
37
+ return int.from_bytes(urandom(num_bytes), "little", signed=True)
38
+
39
+
40
+ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
41
+ """Generate a TaskRes with a node unavailable error from a TaskIns."""
42
+ current_time = time.time()
43
+ ttl = ref_taskins.task.ttl - (current_time - ref_taskins.task.created_at)
44
+ if ttl < 0:
45
+ log(ERROR, "Creating TaskRes for TaskIns that exceeds its TTL.")
46
+ ttl = 0
47
+ return TaskRes(
48
+ task_id=str(uuid4()),
49
+ group_id=ref_taskins.group_id,
50
+ run_id=ref_taskins.run_id,
51
+ task=Task(
52
+ producer=Node(node_id=ref_taskins.task.consumer.node_id, anonymous=False),
53
+ consumer=Node(node_id=ref_taskins.task.producer.node_id, anonymous=False),
54
+ created_at=current_time,
55
+ ttl=ttl,
56
+ ancestry=[ref_taskins.task_id],
57
+ task_type=ref_taskins.task.task_type,
58
+ error=Error(
59
+ code=ErrorCode.NODE_UNAVAILABLE, reason=NODE_UNAVAILABLE_ERROR_REASON
60
+ ),
61
+ ),
62
+ )
flwr/server/typing.py CHANGED
@@ -20,6 +20,8 @@ from typing import Callable
20
20
  from flwr.common import Context
21
21
 
22
22
  from .driver import Driver
23
+ from .serverapp_components import ServerAppComponents
23
24
 
24
25
  ServerAppCallable = Callable[[Driver, Context], None]
25
26
  Workflow = Callable[[Driver, Context], None]
27
+ ServerFn = Callable[[Context], ServerAppComponents]
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2021 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2021 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.