flwr-nightly 1.10.0.dev20240624__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Files changed (95)
  1. flwr/cli/build.py +18 -4
  2. flwr/cli/config_utils.py +36 -14
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +31 -20
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  25. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  32. flwr/cli/run/run.py +135 -51
  33. flwr/client/__init__.py +2 -0
  34. flwr/client/app.py +63 -26
  35. flwr/client/client_app.py +49 -4
  36. flwr/client/grpc_adapter_client/connection.py +3 -2
  37. flwr/client/grpc_client/connection.py +3 -2
  38. flwr/client/grpc_rere_client/connection.py +17 -6
  39. flwr/client/message_handler/message_handler.py +3 -4
  40. flwr/client/node_state.py +60 -10
  41. flwr/client/node_state_tests.py +4 -3
  42. flwr/client/rest_client/connection.py +19 -8
  43. flwr/client/supernode/app.py +60 -21
  44. flwr/client/typing.py +1 -0
  45. flwr/common/config.py +87 -2
  46. flwr/common/constant.py +6 -0
  47. flwr/common/context.py +26 -1
  48. flwr/common/logger.py +38 -0
  49. flwr/common/message.py +0 -17
  50. flwr/common/serde.py +45 -0
  51. flwr/common/telemetry.py +17 -0
  52. flwr/common/typing.py +5 -0
  53. flwr/proto/common_pb2.py +36 -0
  54. flwr/proto/common_pb2.pyi +121 -0
  55. flwr/proto/common_pb2_grpc.py +4 -0
  56. flwr/proto/common_pb2_grpc.pyi +4 -0
  57. flwr/proto/driver_pb2.py +24 -19
  58. flwr/proto/driver_pb2.pyi +21 -1
  59. flwr/proto/exec_pb2.py +16 -11
  60. flwr/proto/exec_pb2.pyi +22 -1
  61. flwr/proto/run_pb2.py +12 -7
  62. flwr/proto/run_pb2.pyi +22 -1
  63. flwr/proto/task_pb2.py +7 -8
  64. flwr/server/__init__.py +2 -0
  65. flwr/server/compat/legacy_context.py +5 -4
  66. flwr/server/driver/grpc_driver.py +82 -140
  67. flwr/server/run_serverapp.py +40 -15
  68. flwr/server/server_app.py +56 -10
  69. flwr/server/serverapp_components.py +52 -0
  70. flwr/server/superlink/driver/driver_servicer.py +18 -3
  71. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  72. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  73. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  74. flwr/server/superlink/fleet/vce/vce_api.py +149 -122
  75. flwr/server/superlink/state/in_memory_state.py +15 -7
  76. flwr/server/superlink/state/sqlite_state.py +27 -12
  77. flwr/server/superlink/state/state.py +7 -2
  78. flwr/server/superlink/state/utils.py +6 -0
  79. flwr/server/typing.py +2 -0
  80. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  81. flwr/simulation/app.py +52 -36
  82. flwr/simulation/ray_transport/ray_actor.py +15 -19
  83. flwr/simulation/ray_transport/ray_client_proxy.py +33 -13
  84. flwr/simulation/run_simulation.py +237 -66
  85. flwr/superexec/app.py +14 -7
  86. flwr/superexec/deployment.py +186 -0
  87. flwr/superexec/exec_grpc.py +5 -1
  88. flwr/superexec/exec_servicer.py +4 -1
  89. flwr/superexec/executor.py +18 -0
  90. flwr/superexec/simulation.py +151 -0
  91. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
  92. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +95 -88
  93. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
  94. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
  95. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/server/superlink/fleet/vce/backend/backend.py

@@ -33,8 +33,8 @@ class Backend(ABC):
         """Construct a backend."""
 
     @abstractmethod
-    async def build(self) -> None:
-        """Build backend asynchronously.
+    def build(self) -> None:
+        """Build backend.
 
         Different components need to be in place before workers in a backend are ready
         to accept jobs. When this method finishes executing, the backend should be fully
@@ -54,11 +54,11 @@ class Backend(ABC):
         """Report whether a backend worker is idle and can therefore run a ClientApp."""
 
     @abstractmethod
-    async def terminate(self) -> None:
+    def terminate(self) -> None:
         """Terminate backend."""
 
     @abstractmethod
-    async def process_message(
+    def process_message(
         self,
         app: Callable[[], ClientApp],
         message: Message,
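
The hunks above drop `async` from the `Backend` ABC, so concrete backends now expose a plain synchronous interface. Below is a minimal sketch of what a synchronous subclass could look like. Only `build`, `terminate`, and `process_message` appear in the diff; the `num_workers` and `is_worker_idle` members and the `ClientApp` call signature are assumptions about the rest of the interface, so treat this as illustrative rather than a drop-in backend.

from typing import Callable, Tuple

from flwr.client.client_app import ClientApp
from flwr.common.context import Context
from flwr.common.message import Message
from flwr.server.superlink.fleet.vce.backend.backend import Backend


class InProcessBackend(Backend):
    """Hypothetical backend that runs every ClientApp call in the current thread."""

    @property
    def num_workers(self) -> int:
        # Assumed member: one synchronous worker
        return 1

    def is_worker_idle(self) -> bool:
        # Assumed member: the single worker is idle between calls
        return True

    def build(self) -> None:
        """Nothing to construct for an in-process backend."""

    def process_message(
        self,
        app: Callable[[], ClientApp],
        message: Message,
        context: Context,
    ) -> Tuple[Message, Context]:
        # Load the ClientApp and invoke it synchronously
        client_app = app()
        out_message = client_app(message=message, context=context)
        return out_message, context

    def terminate(self) -> None:
        """Nothing to tear down."""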
flwr/server/superlink/fleet/vce/backend/raybackend.py

@@ -21,6 +21,7 @@ from typing import Callable, Dict, List, Tuple, Union
 import ray
 
 from flwr.client.client_app import ClientApp
+from flwr.common.constant import PARTITION_ID_KEY
 from flwr.common.context import Context
 from flwr.common.logger import log
 from flwr.common.message import Message
@@ -153,12 +154,12 @@ class RayBackend(Backend):
         """Report whether the pool has idle actors."""
         return self.pool.is_actor_available()
 
-    async def build(self) -> None:
+    def build(self) -> None:
         """Build pool of Ray actors that this backend will submit jobs to."""
-        await self.pool.add_actors_to_pool(self.pool.actors_capacity)
+        self.pool.add_actors_to_pool(self.pool.actors_capacity)
         log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
 
-    async def process_message(
+    def process_message(
         self,
         app: Callable[[], ClientApp],
         message: Message,
@@ -168,21 +169,20 @@ class RayBackend(Backend):
 
         Return output message and updated context.
         """
-        partition_id = message.metadata.partition_id
+        partition_id = context.node_config[PARTITION_ID_KEY]
 
         try:
             # Submit a task to the pool
-            future = await self.pool.submit(
+            future = self.pool.submit(
                 lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
                 (app, message, str(partition_id), context),
             )
 
-            await future
             # Fetch result
             (
                 out_mssg,
                 updated_context,
-            ) = await self.pool.fetch_result_and_return_actor_to_pool(future)
+            ) = self.pool.fetch_result_and_return_actor_to_pool(future)
 
             return out_mssg, updated_context
 
@@ -193,11 +193,11 @@ class RayBackend(Backend):
                 self.__class__.__name__,
             )
             # add actor back into pool
-            await self.pool.add_actor_back_to_pool(future)
+            self.pool.add_actor_back_to_pool(future)
             raise ex
 
-    async def terminate(self) -> None:
+    def terminate(self) -> None:
         """Terminate all actors in actor pool."""
-        await self.pool.terminate_all_actors()
+        self.pool.terminate_all_actors()
         ray.shutdown()
         log(DEBUG, "Terminated %s", self.__class__.__name__)
flwr/server/superlink/fleet/vce/vce_api.py

@@ -14,24 +14,33 @@
 # ==============================================================================
 """Fleet Simulation Engine API."""
 
-import asyncio
+
 import json
-import sys
+import threading
 import time
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from logging import DEBUG, ERROR, INFO, WARN
 from pathlib import Path
-from typing import Callable, Dict, List, Optional
+from queue import Empty, Queue
+from time import sleep
+from typing import Callable, Dict, Optional
 
 from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
 from flwr.client.node_state import NodeState
-from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
+from flwr.client.supernode.app import _get_load_client_app_fn
+from flwr.common.constant import (
+    NUM_PARTITIONS_KEY,
+    PARTITION_ID_KEY,
+    PING_MAX_INTERVAL,
+    ErrorCode,
+)
 from flwr.common.logger import log
 from flwr.common.message import Error
-from flwr.common.object_ref import load_app
 from flwr.common.serde import message_from_taskins, message_to_taskres
-from flwr.proto.task_pb2 import TaskIns  # pylint: disable=E0611
-from flwr.server.superlink.state import StateFactory
+from flwr.common.typing import Run
+from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
+from flwr.server.superlink.state import State, StateFactory
 
 from .backend import Backend, error_messages_backends, supported_backends
 
@@ -51,34 +60,57 @@ def _register_nodes(
     return nodes_mapping
 
 
+def _register_node_states(
+    nodes_mapping: NodeToPartitionMapping,
+    run: Run,
+    app_dir: Optional[str] = None,
+) -> Dict[int, NodeState]:
+    """Create NodeState objects and pre-register the context for the run."""
+    node_states: Dict[int, NodeState] = {}
+    num_partitions = len(set(nodes_mapping.values()))
+    for node_id, partition_id in nodes_mapping.items():
+        node_states[node_id] = NodeState(
+            node_id=node_id,
+            node_config={
+                PARTITION_ID_KEY: str(partition_id),
+                NUM_PARTITIONS_KEY: str(num_partitions),
+            },
+        )
+
+        # Pre-register Context objects
+        node_states[node_id].register_context(
+            run_id=run.run_id, run=run, app_dir=app_dir
+        )
+
+    return node_states
+
+
 # pylint: disable=too-many-arguments,too-many-locals
-async def worker(
+def worker(
     app_fn: Callable[[], ClientApp],
-    queue: "asyncio.Queue[TaskIns]",
+    taskins_queue: "Queue[TaskIns]",
+    taskres_queue: "Queue[TaskRes]",
     node_states: Dict[int, NodeState],
-    state_factory: StateFactory,
-    nodes_mapping: NodeToPartitionMapping,
     backend: Backend,
+    f_stop: threading.Event,
 ) -> None:
     """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
-    state = state_factory.state()
-    while True:
+    while not f_stop.is_set():
         out_mssg = None
         try:
-            task_ins: TaskIns = await queue.get()
+            # Fetch from queue with timeout. We use a timeout so
+            # the stopping event can be evaluated even when the queue is empty.
+            task_ins: TaskIns = taskins_queue.get(timeout=1.0)
             node_id = task_ins.task.consumer.node_id
 
-            # Register and retrieve runstate
-            node_states[node_id].register_context(run_id=task_ins.run_id)
+            # Retrieve context
             context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
 
             # Convert TaskIns to Message
             message = message_from_taskins(task_ins)
-            # Set partition_id
-            message.metadata.partition_id = nodes_mapping[node_id]
 
             # Let backend process message
-            out_mssg, updated_context = await backend.process_message(
+            out_mssg, updated_context = backend.process_message(
                 app_fn, message, context
             )
 
@@ -86,11 +118,9 @@ async def worker(
             node_states[node_id].update_context(
                 task_ins.run_id, context=updated_context
             )
-
-        except asyncio.CancelledError as e:
-            log(DEBUG, "Terminating async worker: %s", e)
-            break
-
+        except Empty:
+            # An exception raised if queue.get times out
+            pass
         # Exceptions aren't raised but reported as an error message
         except Exception as ex:  # pylint: disable=broad-exception-caught
             log(ERROR, ex)
@@ -114,67 +144,48 @@ async def worker(
             task_res = message_to_taskres(out_mssg)
             # Store TaskRes in state
             task_res.task.pushed_at = time.time()
-            state.store_task_res(task_res)
+            taskres_queue.put(task_res)
 
 
-async def add_taskins_to_queue(
-    queue: "asyncio.Queue[TaskIns]",
-    state_factory: StateFactory,
+def add_taskins_to_queue(
+    state: State,
+    queue: "Queue[TaskIns]",
     nodes_mapping: NodeToPartitionMapping,
-    backend: Backend,
-    consumers: List["asyncio.Task[None]"],
-    f_stop: asyncio.Event,
+    f_stop: threading.Event,
 ) -> None:
-    """Retrieve TaskIns and add it to the queue."""
-    state = state_factory.state()
-    num_initial_consumers = len(consumers)
+    """Put TaskIns in a queue from State."""
     while not f_stop.is_set():
         for node_id in nodes_mapping.keys():
-            task_ins = state.get_task_ins(node_id=node_id, limit=1)
-            if task_ins:
-                await queue.put(task_ins[0])
-
-        # Count consumers that are running
-        num_active = sum(not (cc.done()) for cc in consumers)
-
-        # Alert if number of consumers decreased by half
-        if num_active < num_initial_consumers // 2:
-            log(
-                WARN,
-                "Number of active workers has more than halved: (%i/%i active)",
-                num_active,
-                num_initial_consumers,
-            )
+            task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
+            for task_ins in task_ins_list:
+                queue.put(task_ins)
+        sleep(0.1)
 
-        # Break if consumers died
-        if num_active == 0:
-            raise RuntimeError("All workers have died. Ending Simulation.")
 
-        # Log some stats
-        log(
-            DEBUG,
-            "Simulation Engine stats: "
-            "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
-            num_active,
-            num_initial_consumers,
-            backend.__class__.__name__,
-            backend.num_workers,
-            queue.qsize(),
-        )
-        await asyncio.sleep(1.0)
-    log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
+def put_taskres_into_state(
+    state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
+) -> None:
+    """Put TaskRes into State from a queue."""
+    while not f_stop.is_set():
+        try:
+            taskres = queue.get(timeout=1.0)
+            state.store_task_res(taskres)
+        except Empty:
+            # queue is empty when timeout was triggered
+            pass
 
 
-async def run(
+def run_api(
     app_fn: Callable[[], ClientApp],
     backend_fn: Callable[[], Backend],
     nodes_mapping: NodeToPartitionMapping,
     state_factory: StateFactory,
     node_states: Dict[int, NodeState],
-    f_stop: asyncio.Event,
+    f_stop: threading.Event,
 ) -> None:
-    """Run the VCE async."""
-    queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)
+    """Run the VCE."""
+    taskins_queue: "Queue[TaskIns]" = Queue()
+    taskres_queue: "Queue[TaskRes]" = Queue()
 
     try:
 
@@ -182,29 +193,48 @@ async def run(
         backend = backend_fn()
 
         # Build backend
-        await backend.build()
+        backend.build()
 
         # Add workers (they submit Messages to Backend)
-        worker_tasks = [
-            asyncio.create_task(
-                worker(
-                    app_fn, queue, node_states, state_factory, nodes_mapping, backend
-                )
-            )
-            for _ in range(backend.num_workers)
-        ]
-        # Create producer (adds TaskIns into Queue)
-        producer = asyncio.create_task(
-            add_taskins_to_queue(
-                queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
-            )
+        state = state_factory.state()
+
+        extractor_th = threading.Thread(
+            target=add_taskins_to_queue,
+            args=(
+                state,
+                taskins_queue,
+                nodes_mapping,
+                f_stop,
+            ),
+        )
+        extractor_th.start()
+
+        injector_th = threading.Thread(
+            target=put_taskres_into_state,
+            args=(
+                state,
+                taskres_queue,
+                f_stop,
+            ),
         )
+        injector_th.start()
+
+        with ThreadPoolExecutor() as executor:
+            _ = [
+                executor.submit(
+                    worker,
+                    app_fn,
+                    taskins_queue,
+                    taskres_queue,
+                    node_states,
+                    backend,
+                    f_stop,
                )
+                for _ in range(backend.num_workers)
+            ]
 
-        # Wait for producer to finish
-        # The producer runs forever until f_stop is set or until
-        # all worker (consumer) coroutines are completed. Workers
-        # also run forever and only end if an exception is raised.
-        await asyncio.gather(producer)
+        extractor_th.join()
+        injector_th.join()
 
     except Exception as ex:
 
@@ -219,18 +249,9 @@ async def run(
         raise RuntimeError("Simulation Engine crashed.") from ex
 
     finally:
-        # Produced task terminated, now cancel worker tasks
-        for w_t in worker_tasks:
-            _ = w_t.cancel()
-
-        while not all(w_t.done() for w_t in worker_tasks):
-            log(DEBUG, "Terminating async workers...")
-            await asyncio.sleep(0.5)
-
-        await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
 
         # Terminate backend
-        await backend.terminate()
+        backend.terminate()
 
 
 # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
@@ -239,7 +260,10 @@ def start_vce(
     backend_name: str,
     backend_config_json_stream: str,
     app_dir: str,
-    f_stop: asyncio.Event,
+    is_app: bool,
+    f_stop: threading.Event,
+    run: Run,
+    flwr_dir: Optional[str] = None,
     client_app: Optional[ClientApp] = None,
     client_app_attr: Optional[str] = None,
     num_supernodes: Optional[int] = None,
@@ -290,9 +314,9 @@ def start_vce(
     )
 
     # Construct mapping of NodeStates
-    node_states: Dict[int, NodeState] = {}
-    for node_id in nodes_mapping:
-        node_states[node_id] = NodeState()
+    node_states = _register_node_states(
+        nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
+    )
 
     # Load backend config
     log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
@@ -321,16 +345,12 @@ def start_vce(
     def _load() -> ClientApp:
 
        if client_app_attr:
-
-            if app_dir is not None:
-                sys.path.insert(0, app_dir)
-
-            app: ClientApp = load_app(client_app_attr, LoadClientAppError, app_dir)
-
-            if not isinstance(app, ClientApp):
-                raise LoadClientAppError(
-                    f"Attribute {client_app_attr} is not of type {ClientApp}",
-                ) from None
+            app = _get_load_client_app_fn(
+                default_app_ref=client_app_attr,
+                dir_arg=app_dir,
+                flwr_dir_arg=flwr_dir,
+                multi_app=True,
+            )(run.fab_id, run.fab_version)
 
         if client_app:
             app = client_app
@@ -340,18 +360,25 @@ def start_vce(
 
     try:
         # Test if ClientApp can be loaded
-        _ = app_fn()
+        client_app = app_fn()
+
+        # Cache `ClientApp`
+        if client_app_attr:
+            # Now wrap the loaded ClientApp in a dummy function
+            # this prevent unnecesary low-level loading of ClientApp
+            def _load_client_app() -> ClientApp:
+                return client_app
+
+            app_fn = _load_client_app
 
         # Run main simulation loop
-        asyncio.run(
-            run(
-                app_fn,
-                backend_fn,
-                nodes_mapping,
-                state_factory,
-                node_states,
-                f_stop,
-            )
+        run_api(
+            app_fn,
+            backend_fn,
+            nodes_mapping,
+            state_factory,
+            node_states,
+            f_stop,
         )
     except LoadClientAppError as loadapp_ex:
         f_stop_delay = 10
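
The vce_api.py rewrite replaces the asyncio producer/consumer loop with ordinary threads: one thread pulls TaskIns from State into a queue, a ThreadPoolExecutor full of workers processes them, and a second thread pushes the resulting TaskRes back into State, all coordinated by a shared threading.Event. The standalone sketch below mirrors that structure with generic integer payloads; none of the names here are flwr APIs.

import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from time import sleep

f_stop = threading.Event()
task_queue: "Queue[int]" = Queue()
result_queue: "Queue[int]" = Queue()
results = []


def producer() -> None:
    """Feed work into the task queue, then signal shutdown (role of add_taskins_to_queue)."""
    for item in range(10):
        task_queue.put(item)
    sleep(0.5)    # give workers time to drain the queue
    f_stop.set()  # tell workers and the injector to exit


def worker() -> None:
    """Consume tasks; the get() timeout lets the stop event be re-checked periodically."""
    while not f_stop.is_set():
        try:
            item = task_queue.get(timeout=1.0)
            result_queue.put(item * item)
        except Empty:
            pass  # timed out, loop back and re-check f_stop


def injector() -> None:
    """Drain results into shared state (role of put_taskres_into_state)."""
    while not f_stop.is_set():
        try:
            results.append(result_queue.get(timeout=1.0))
        except Empty:
            pass


extractor_th = threading.Thread(target=producer)
injector_th = threading.Thread(target=injector)
extractor_th.start()
injector_th.start()

with ThreadPoolExecutor() as executor:
    _ = [executor.submit(worker) for _ in range(4)]

extractor_th.join()
injector_th.join()
print(sorted(results))  # typically the squares of 0..9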
flwr/server/superlink/state/in_memory_state.py

@@ -15,7 +15,6 @@
 """In-memory State implementation."""
 
 
-import os
 import threading
 import time
 from logging import ERROR
@@ -23,12 +22,13 @@ from typing import Dict, List, Optional, Set, Tuple
 from uuid import UUID, uuid4
 
 from flwr.common import log, now
-from flwr.common.typing import Run
+from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.superlink.state.state import State
 from flwr.server.utils import validate_task_ins_or_res
 
-from .utils import make_node_unavailable_taskres
+from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
 
 
 class InMemoryState(State):  # pylint: disable=R0902,R0904
@@ -216,7 +216,7 @@ class InMemoryState(State):  # pylint: disable=R0902,R0904
     ) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
-        node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
+        node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
 
         with self.lock:
             if node_id in self.node_ids:
@@ -275,15 +275,23 @@ class InMemoryState(State):  # pylint: disable=R0902,R0904
         """Retrieve stored `node_id` filtered by `client_public_keys`."""
         return self.public_key_to_node_id.get(client_public_key)
 
-    def create_run(self, fab_id: str, fab_version: str) -> int:
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
         with self.lock:
-            run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
+            run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
 
             if run_id not in self.run_ids:
                 self.run_ids[run_id] = Run(
-                    run_id=run_id, fab_id=fab_id, fab_version=fab_version
+                    run_id=run_id,
+                    fab_id=fab_id,
+                    fab_version=fab_version,
+                    override_config=override_config,
                 )
                 return run_id
         log(ERROR, "Unexpected run creation failure.")
flwr/server/superlink/state/sqlite_state.py

@@ -15,7 +15,7 @@
 """SQLite based implemenation of server state."""
 
 
-import os
+import json
 import re
 import sqlite3
 import time
@@ -24,14 +24,15 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
 from uuid import UUID, uuid4
 
 from flwr.common import log, now
-from flwr.common.typing import Run
+from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
 from flwr.proto.recordset_pb2 import RecordSet  # pylint: disable=E0611
 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes  # pylint: disable=E0611
 from flwr.server.utils.validator import validate_task_ins_or_res
 
 from .state import State
-from .utils import make_node_unavailable_taskres
+from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
 
 
 SQL_CREATE_TABLE_NODE = """
@@ -61,9 +62,10 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
 
 SQL_CREATE_TABLE_RUN = """
 CREATE TABLE IF NOT EXISTS run(
-    run_id INTEGER UNIQUE,
-    fab_id TEXT,
-    fab_version TEXT
+    run_id INTEGER UNIQUE,
+    fab_id TEXT,
+    fab_version TEXT,
+    override_config TEXT
 );
 """
 
@@ -541,7 +543,7 @@ class SqliteState(State):  # pylint: disable=R0904
     ) -> int:
         """Create, store in state, and return `node_id`."""
         # Sample a random int64 as node_id
-        node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
+        node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
 
         query = "SELECT node_id FROM node WHERE public_key = :public_key;"
         row = self.query(query, {"public_key": public_key})
@@ -613,17 +615,27 @@ class SqliteState(State):  # pylint: disable=R0904
             return node_id
         return None
 
-    def create_run(self, fab_id: str, fab_version: str) -> int:
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
         # Sample a random int64 as run_id
-        run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
+        run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
 
         # Check conflicts
         query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
         # If run_id does not exist
         if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
-            query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
-            self.query(query, (run_id, fab_id, fab_version))
+            query = (
+                "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
+                "VALUES (?, ?, ?, ?);"
+            )
+            self.query(
+                query, (run_id, fab_id, fab_version, json.dumps(override_config))
+            )
             return run_id
         log(ERROR, "Unexpected run creation failure.")
         return 0
@@ -687,7 +699,10 @@ class SqliteState(State):  # pylint: disable=R0904
         try:
             row = self.query(query, (run_id,))[0]
             return Run(
-                run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
+                run_id=run_id,
+                fab_id=row["fab_id"],
+                fab_version=row["fab_version"],
+                override_config=json.loads(row["override_config"]),
             )
         except sqlite3.IntegrityError:
             log(ERROR, "`run_id` does not exist.")
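
The SqliteState changes persist the new `override_config` dict as a JSON-encoded TEXT column and decode it again when the run is read back. Below is a standalone sketch of that round-trip against an in-memory SQLite database; the table and column names mirror the diff, while the run values are made up for illustration.

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE IF NOT EXISTS run("
    "    run_id          INTEGER UNIQUE,"
    "    fab_id          TEXT,"
    "    fab_version     TEXT,"
    "    override_config TEXT"
    ");"
)

override_config = {"num-server-rounds": 5, "lr": 0.01}  # illustrative values
conn.execute(
    "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
    "VALUES (?, ?, ?, ?);",
    (1234, "flwrlabs/myapp", "1.0.0", json.dumps(override_config)),
)

row = conn.execute("SELECT * FROM run WHERE run_id = ?;", (1234,)).fetchone()
print(json.loads(row[3]))  # {'num-server-rounds': 5, 'lr': 0.01}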
flwr/server/superlink/state/state.py

@@ -19,7 +19,7 @@ import abc
 from typing import List, Optional, Set
 from uuid import UUID
 
-from flwr.common.typing import Run
+from flwr.common.typing import Run, UserConfig
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
 
 
@@ -157,7 +157,12 @@ class State(abc.ABC):  # pylint: disable=R0904
         """Retrieve stored `node_id` filtered by `client_public_keys`."""
 
     @abc.abstractmethod
-    def create_run(self, fab_id: str, fab_version: str) -> int:
+    def create_run(
+        self,
+        fab_id: str,
+        fab_version: str,
+        override_config: UserConfig,
+    ) -> int:
         """Create a new run for the specified `fab_id` and `fab_version`."""
 
     @abc.abstractmethod
flwr/server/superlink/state/utils.py

@@ -17,6 +17,7 @@
 
 import time
 from logging import ERROR
+from os import urandom
 from uuid import uuid4
 
 from flwr.common import log
@@ -31,6 +32,11 @@ NODE_UNAVAILABLE_ERROR_REASON = (
 )
 
 
+def generate_rand_int_from_bytes(num_bytes: int) -> int:
+    """Generate a random `num_bytes` integer."""
+    return int.from_bytes(urandom(num_bytes), "little", signed=True)
+
+
 def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
     """Generate a TaskRes with a node unavailable error from a TaskIns."""
     current_time = time.time()
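
For reference, the new helper simply reinterprets `num_bytes` of `os.urandom` output as a signed little-endian integer, replacing the inlined `int.from_bytes(os.urandom(8), ...)` calls in both State implementations. The snippet below reproduces it and checks the range for 8 bytes; 8 is used only because that is what the replaced code drew, while the actual NODE_ID_NUM_BYTES and RUN_ID_NUM_BYTES values live in flwr.common.constant and are not shown in this diff.

from os import urandom


def generate_rand_int_from_bytes(num_bytes: int) -> int:
    """Same definition as the helper added to utils.py above."""
    return int.from_bytes(urandom(num_bytes), "little", signed=True)


value = generate_rand_int_from_bytes(8)
# A signed little-endian 8-byte value falls in the int64 range
assert -(2**63) <= value <= 2**63 - 1
print(value)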