flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; consult the package registry's advisory page for more details.

Files changed (99)
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +47 -27
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +32 -21
  5. flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
  13. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
  14. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  16. flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
  17. flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
  18. flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
  19. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  20. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
  21. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  22. flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
  23. flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
  24. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
  25. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
  26. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  27. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  28. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  29. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  30. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  34. flwr/cli/run/run.py +133 -54
  35. flwr/client/app.py +56 -24
  36. flwr/client/client_app.py +28 -8
  37. flwr/client/grpc_adapter_client/connection.py +3 -2
  38. flwr/client/grpc_client/connection.py +3 -2
  39. flwr/client/grpc_rere_client/connection.py +17 -6
  40. flwr/client/message_handler/message_handler.py +1 -1
  41. flwr/client/node_state.py +59 -12
  42. flwr/client/node_state_tests.py +4 -3
  43. flwr/client/rest_client/connection.py +19 -8
  44. flwr/client/supernode/app.py +39 -39
  45. flwr/client/typing.py +2 -2
  46. flwr/common/config.py +92 -2
  47. flwr/common/constant.py +3 -0
  48. flwr/common/context.py +24 -9
  49. flwr/common/logger.py +25 -0
  50. flwr/common/object_ref.py +84 -21
  51. flwr/common/serde.py +45 -0
  52. flwr/common/telemetry.py +17 -0
  53. flwr/common/typing.py +5 -0
  54. flwr/proto/common_pb2.py +36 -0
  55. flwr/proto/common_pb2.pyi +121 -0
  56. flwr/proto/common_pb2_grpc.py +4 -0
  57. flwr/proto/common_pb2_grpc.pyi +4 -0
  58. flwr/proto/driver_pb2.py +24 -19
  59. flwr/proto/driver_pb2.pyi +21 -1
  60. flwr/proto/exec_pb2.py +20 -11
  61. flwr/proto/exec_pb2.pyi +41 -1
  62. flwr/proto/run_pb2.py +12 -7
  63. flwr/proto/run_pb2.pyi +22 -1
  64. flwr/proto/task_pb2.py +7 -8
  65. flwr/server/__init__.py +2 -0
  66. flwr/server/compat/legacy_context.py +5 -4
  67. flwr/server/driver/grpc_driver.py +82 -140
  68. flwr/server/run_serverapp.py +40 -18
  69. flwr/server/server_app.py +56 -10
  70. flwr/server/serverapp_components.py +52 -0
  71. flwr/server/superlink/driver/driver_servicer.py +18 -3
  72. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  73. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  74. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  75. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  76. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  77. flwr/server/superlink/state/in_memory_state.py +11 -3
  78. flwr/server/superlink/state/sqlite_state.py +23 -8
  79. flwr/server/superlink/state/state.py +7 -2
  80. flwr/server/typing.py +2 -0
  81. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  82. flwr/simulation/__init__.py +1 -1
  83. flwr/simulation/app.py +4 -3
  84. flwr/simulation/ray_transport/ray_actor.py +15 -19
  85. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  86. flwr/simulation/run_simulation.py +269 -70
  87. flwr/superexec/app.py +17 -11
  88. flwr/superexec/deployment.py +111 -35
  89. flwr/superexec/exec_grpc.py +5 -1
  90. flwr/superexec/exec_servicer.py +6 -1
  91. flwr/superexec/executor.py +21 -0
  92. flwr/superexec/simulation.py +181 -0
  93. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
  94. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
  95. flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
  96. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
  97. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
  98. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
  99. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
@@ -14,24 +14,33 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine API."""
16
16
 
17
- import asyncio
17
+
18
18
  import json
19
- import sys
19
+ import threading
20
20
  import time
21
21
  import traceback
22
+ from concurrent.futures import ThreadPoolExecutor
22
23
  from logging import DEBUG, ERROR, INFO, WARN
23
24
  from pathlib import Path
24
- from typing import Callable, Dict, List, Optional
25
+ from queue import Empty, Queue
26
+ from time import sleep
27
+ from typing import Callable, Dict, Optional
25
28
 
26
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
27
30
  from flwr.client.node_state import NodeState
28
- from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
31
+ from flwr.client.supernode.app import _get_load_client_app_fn
32
+ from flwr.common.constant import (
33
+ NUM_PARTITIONS_KEY,
34
+ PARTITION_ID_KEY,
35
+ PING_MAX_INTERVAL,
36
+ ErrorCode,
37
+ )
29
38
  from flwr.common.logger import log
30
39
  from flwr.common.message import Error
31
- from flwr.common.object_ref import load_app
32
40
  from flwr.common.serde import message_from_taskins, message_to_taskres
33
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
34
- from flwr.server.superlink.state import StateFactory
41
+ from flwr.common.typing import Run
42
+ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
43
+ from flwr.server.superlink.state import State, StateFactory
35
44
 
36
45
  from .backend import Backend, error_messages_backends, supported_backends
37
46
 
@@ -51,31 +60,57 @@ def _register_nodes(
51
60
  return nodes_mapping
52
61
 
53
62
 
63
+ def _register_node_states(
64
+ nodes_mapping: NodeToPartitionMapping,
65
+ run: Run,
66
+ app_dir: Optional[str] = None,
67
+ ) -> Dict[int, NodeState]:
68
+ """Create NodeState objects and pre-register the context for the run."""
69
+ node_states: Dict[int, NodeState] = {}
70
+ num_partitions = len(set(nodes_mapping.values()))
71
+ for node_id, partition_id in nodes_mapping.items():
72
+ node_states[node_id] = NodeState(
73
+ node_id=node_id,
74
+ node_config={
75
+ PARTITION_ID_KEY: partition_id,
76
+ NUM_PARTITIONS_KEY: num_partitions,
77
+ },
78
+ )
79
+
80
+ # Pre-register Context objects
81
+ node_states[node_id].register_context(
82
+ run_id=run.run_id, run=run, app_dir=app_dir
83
+ )
84
+
85
+ return node_states
86
+
87
+
54
88
  # pylint: disable=too-many-arguments,too-many-locals
55
- async def worker(
89
+ def worker(
56
90
  app_fn: Callable[[], ClientApp],
57
- queue: "asyncio.Queue[TaskIns]",
91
+ taskins_queue: "Queue[TaskIns]",
92
+ taskres_queue: "Queue[TaskRes]",
58
93
  node_states: Dict[int, NodeState],
59
- state_factory: StateFactory,
60
94
  backend: Backend,
95
+ f_stop: threading.Event,
61
96
  ) -> None:
62
97
  """Get TaskIns from queue and pass it to an actor in the pool to execute it."""
63
- state = state_factory.state()
64
- while True:
98
+ while not f_stop.is_set():
65
99
  out_mssg = None
66
100
  try:
67
- task_ins: TaskIns = await queue.get()
101
+ # Fetch from queue with timeout. We use a timeout so
102
+ # the stopping event can be evaluated even when the queue is empty.
103
+ task_ins: TaskIns = taskins_queue.get(timeout=1.0)
68
104
  node_id = task_ins.task.consumer.node_id
69
105
 
70
- # Register and retrieve runstate
71
- node_states[node_id].register_context(run_id=task_ins.run_id)
106
+ # Retrieve context
72
107
  context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)
73
108
 
74
109
  # Convert TaskIns to Message
75
110
  message = message_from_taskins(task_ins)
76
111
 
77
112
  # Let backend process message
78
- out_mssg, updated_context = await backend.process_message(
113
+ out_mssg, updated_context = backend.process_message(
79
114
  app_fn, message, context
80
115
  )
81
116
 
@@ -83,11 +118,9 @@ async def worker(
83
118
  node_states[node_id].update_context(
84
119
  task_ins.run_id, context=updated_context
85
120
  )
86
-
87
- except asyncio.CancelledError as e:
88
- log(DEBUG, "Terminating async worker: %s", e)
89
- break
90
-
121
+ except Empty:
122
+ # An exception raised if queue.get times out
123
+ pass
91
124
  # Exceptions aren't raised but reported as an error message
92
125
  except Exception as ex: # pylint: disable=broad-exception-caught
93
126
  log(ERROR, ex)
@@ -111,67 +144,48 @@ async def worker(
111
144
  task_res = message_to_taskres(out_mssg)
112
145
  # Store TaskRes in state
113
146
  task_res.task.pushed_at = time.time()
114
- state.store_task_res(task_res)
147
+ taskres_queue.put(task_res)
115
148
 
116
149
 
117
- async def add_taskins_to_queue(
118
- queue: "asyncio.Queue[TaskIns]",
119
- state_factory: StateFactory,
150
+ def add_taskins_to_queue(
151
+ state: State,
152
+ queue: "Queue[TaskIns]",
120
153
  nodes_mapping: NodeToPartitionMapping,
121
- backend: Backend,
122
- consumers: List["asyncio.Task[None]"],
123
- f_stop: asyncio.Event,
154
+ f_stop: threading.Event,
124
155
  ) -> None:
125
- """Retrieve TaskIns and add it to the queue."""
126
- state = state_factory.state()
127
- num_initial_consumers = len(consumers)
156
+ """Put TaskIns in a queue from State."""
128
157
  while not f_stop.is_set():
129
158
  for node_id in nodes_mapping.keys():
130
- task_ins = state.get_task_ins(node_id=node_id, limit=1)
131
- if task_ins:
132
- await queue.put(task_ins[0])
133
-
134
- # Count consumers that are running
135
- num_active = sum(not (cc.done()) for cc in consumers)
136
-
137
- # Alert if number of consumers decreased by half
138
- if num_active < num_initial_consumers // 2:
139
- log(
140
- WARN,
141
- "Number of active workers has more than halved: (%i/%i active)",
142
- num_active,
143
- num_initial_consumers,
144
- )
159
+ task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
160
+ for task_ins in task_ins_list:
161
+ queue.put(task_ins)
162
+ sleep(0.1)
145
163
 
146
- # Break if consumers died
147
- if num_active == 0:
148
- raise RuntimeError("All workers have died. Ending Simulation.")
149
164
 
150
- # Log some stats
151
- log(
152
- DEBUG,
153
- "Simulation Engine stats: "
154
- "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
155
- num_active,
156
- num_initial_consumers,
157
- backend.__class__.__name__,
158
- backend.num_workers,
159
- queue.qsize(),
160
- )
161
- await asyncio.sleep(1.0)
162
- log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
165
+ def put_taskres_into_state(
166
+ state: State, queue: "Queue[TaskRes]", f_stop: threading.Event
167
+ ) -> None:
168
+ """Put TaskRes into State from a queue."""
169
+ while not f_stop.is_set():
170
+ try:
171
+ taskres = queue.get(timeout=1.0)
172
+ state.store_task_res(taskres)
173
+ except Empty:
174
+ # queue is empty when timeout was triggered
175
+ pass
163
176
 
164
177
 
165
- async def run(
178
+ def run_api(
166
179
  app_fn: Callable[[], ClientApp],
167
180
  backend_fn: Callable[[], Backend],
168
181
  nodes_mapping: NodeToPartitionMapping,
169
182
  state_factory: StateFactory,
170
183
  node_states: Dict[int, NodeState],
171
- f_stop: asyncio.Event,
184
+ f_stop: threading.Event,
172
185
  ) -> None:
173
- """Run the VCE async."""
174
- queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)
186
+ """Run the VCE."""
187
+ taskins_queue: "Queue[TaskIns]" = Queue()
188
+ taskres_queue: "Queue[TaskRes]" = Queue()
175
189
 
176
190
  try:
177
191
 
@@ -179,27 +193,48 @@ async def run(
179
193
  backend = backend_fn()
180
194
 
181
195
  # Build backend
182
- await backend.build()
196
+ backend.build()
183
197
 
184
198
  # Add workers (they submit Messages to Backend)
185
- worker_tasks = [
186
- asyncio.create_task(
187
- worker(app_fn, queue, node_states, state_factory, backend)
188
- )
189
- for _ in range(backend.num_workers)
190
- ]
191
- # Create producer (adds TaskIns into Queue)
192
- producer = asyncio.create_task(
193
- add_taskins_to_queue(
194
- queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
195
- )
199
+ state = state_factory.state()
200
+
201
+ extractor_th = threading.Thread(
202
+ target=add_taskins_to_queue,
203
+ args=(
204
+ state,
205
+ taskins_queue,
206
+ nodes_mapping,
207
+ f_stop,
208
+ ),
196
209
  )
210
+ extractor_th.start()
197
211
 
198
- # Wait for producer to finish
199
- # The producer runs forever until f_stop is set or until
200
- # all worker (consumer) coroutines are completed. Workers
201
- # also run forever and only end if an exception is raised.
202
- await asyncio.gather(producer)
212
+ injector_th = threading.Thread(
213
+ target=put_taskres_into_state,
214
+ args=(
215
+ state,
216
+ taskres_queue,
217
+ f_stop,
218
+ ),
219
+ )
220
+ injector_th.start()
221
+
222
+ with ThreadPoolExecutor() as executor:
223
+ _ = [
224
+ executor.submit(
225
+ worker,
226
+ app_fn,
227
+ taskins_queue,
228
+ taskres_queue,
229
+ node_states,
230
+ backend,
231
+ f_stop,
232
+ )
233
+ for _ in range(backend.num_workers)
234
+ ]
235
+
236
+ extractor_th.join()
237
+ injector_th.join()
203
238
 
204
239
  except Exception as ex:
205
240
 
@@ -214,18 +249,9 @@ async def run(
214
249
  raise RuntimeError("Simulation Engine crashed.") from ex
215
250
 
216
251
  finally:
217
- # Produced task terminated, now cancel worker tasks
218
- for w_t in worker_tasks:
219
- _ = w_t.cancel()
220
-
221
- while not all(w_t.done() for w_t in worker_tasks):
222
- log(DEBUG, "Terminating async workers...")
223
- await asyncio.sleep(0.5)
224
-
225
- await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])
226
252
 
227
253
  # Terminate backend
228
- await backend.terminate()
254
+ backend.terminate()
229
255
 
230
256
 
231
257
  # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
@@ -234,7 +260,10 @@ def start_vce(
234
260
  backend_name: str,
235
261
  backend_config_json_stream: str,
236
262
  app_dir: str,
237
- f_stop: asyncio.Event,
263
+ is_app: bool,
264
+ f_stop: threading.Event,
265
+ run: Run,
266
+ flwr_dir: Optional[str] = None,
238
267
  client_app: Optional[ClientApp] = None,
239
268
  client_app_attr: Optional[str] = None,
240
269
  num_supernodes: Optional[int] = None,
@@ -285,9 +314,9 @@ def start_vce(
285
314
  )
286
315
 
287
316
  # Construct mapping of NodeStates
288
- node_states: Dict[int, NodeState] = {}
289
- for node_id, partition_id in nodes_mapping.items():
290
- node_states[node_id] = NodeState(partition_id=partition_id)
317
+ node_states = _register_node_states(
318
+ nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
319
+ )
291
320
 
292
321
  # Load backend config
293
322
  log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
@@ -316,16 +345,12 @@ def start_vce(
316
345
  def _load() -> ClientApp:
317
346
 
318
347
  if client_app_attr:
319
-
320
- if app_dir is not None:
321
- sys.path.insert(0, app_dir)
322
-
323
- app: ClientApp = load_app(client_app_attr, LoadClientAppError, app_dir)
324
-
325
- if not isinstance(app, ClientApp):
326
- raise LoadClientAppError(
327
- f"Attribute {client_app_attr} is not of type {ClientApp}",
328
- ) from None
348
+ app = _get_load_client_app_fn(
349
+ default_app_ref=client_app_attr,
350
+ project_dir=app_dir,
351
+ flwr_dir=flwr_dir,
352
+ multi_app=True,
353
+ )(run.fab_id, run.fab_version)
329
354
 
330
355
  if client_app:
331
356
  app = client_app
@@ -335,18 +360,25 @@ def start_vce(
335
360
 
336
361
  try:
337
362
  # Test if ClientApp can be loaded
338
- _ = app_fn()
363
+ client_app = app_fn()
364
+
365
+ # Cache `ClientApp`
366
+ if client_app_attr:
367
+ # Now wrap the loaded ClientApp in a dummy function
368
+ # this prevent unnecesary low-level loading of ClientApp
369
+ def _load_client_app() -> ClientApp:
370
+ return client_app
371
+
372
+ app_fn = _load_client_app
339
373
 
340
374
  # Run main simulation loop
341
- asyncio.run(
342
- run(
343
- app_fn,
344
- backend_fn,
345
- nodes_mapping,
346
- state_factory,
347
- node_states,
348
- f_stop,
349
- )
375
+ run_api(
376
+ app_fn,
377
+ backend_fn,
378
+ nodes_mapping,
379
+ state_factory,
380
+ node_states,
381
+ f_stop,
350
382
  )
351
383
  except LoadClientAppError as loadapp_ex:
352
384
  f_stop_delay = 10
@@ -23,7 +23,7 @@ from uuid import UUID, uuid4
23
23
 
24
24
  from flwr.common import log, now
25
25
  from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
26
- from flwr.common.typing import Run
26
+ from flwr.common.typing import Run, UserConfig
27
27
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
28
28
  from flwr.server.superlink.state.state import State
29
29
  from flwr.server.utils import validate_task_ins_or_res
@@ -275,7 +275,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
275
275
  """Retrieve stored `node_id` filtered by `client_public_keys`."""
276
276
  return self.public_key_to_node_id.get(client_public_key)
277
277
 
278
- def create_run(self, fab_id: str, fab_version: str) -> int:
278
+ def create_run(
279
+ self,
280
+ fab_id: str,
281
+ fab_version: str,
282
+ override_config: UserConfig,
283
+ ) -> int:
279
284
  """Create a new run for the specified `fab_id` and `fab_version`."""
280
285
  # Sample a random int64 as run_id
281
286
  with self.lock:
@@ -283,7 +288,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
283
288
 
284
289
  if run_id not in self.run_ids:
285
290
  self.run_ids[run_id] = Run(
286
- run_id=run_id, fab_id=fab_id, fab_version=fab_version
291
+ run_id=run_id,
292
+ fab_id=fab_id,
293
+ fab_version=fab_version,
294
+ override_config=override_config,
287
295
  )
288
296
  return run_id
289
297
  log(ERROR, "Unexpected run creation failure.")
@@ -15,6 +15,7 @@
15
15
  """SQLite based implemenation of server state."""
16
16
 
17
17
 
18
+ import json
18
19
  import re
19
20
  import sqlite3
20
21
  import time
@@ -24,7 +25,7 @@ from uuid import UUID, uuid4
24
25
 
25
26
  from flwr.common import log, now
26
27
  from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
27
- from flwr.common.typing import Run
28
+ from flwr.common.typing import Run, UserConfig
28
29
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
30
  from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
30
31
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
@@ -61,9 +62,10 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
61
62
 
62
63
  SQL_CREATE_TABLE_RUN = """
63
64
  CREATE TABLE IF NOT EXISTS run(
64
- run_id INTEGER UNIQUE,
65
- fab_id TEXT,
66
- fab_version TEXT
65
+ run_id INTEGER UNIQUE,
66
+ fab_id TEXT,
67
+ fab_version TEXT,
68
+ override_config TEXT
67
69
  );
68
70
  """
69
71
 
@@ -613,7 +615,12 @@ class SqliteState(State): # pylint: disable=R0904
613
615
  return node_id
614
616
  return None
615
617
 
616
- def create_run(self, fab_id: str, fab_version: str) -> int:
618
+ def create_run(
619
+ self,
620
+ fab_id: str,
621
+ fab_version: str,
622
+ override_config: UserConfig,
623
+ ) -> int:
617
624
  """Create a new run for the specified `fab_id` and `fab_version`."""
618
625
  # Sample a random int64 as run_id
619
626
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
@@ -622,8 +629,13 @@ class SqliteState(State): # pylint: disable=R0904
622
629
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
623
630
  # If run_id does not exist
624
631
  if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
625
- query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
626
- self.query(query, (run_id, fab_id, fab_version))
632
+ query = (
633
+ "INSERT INTO run (run_id, fab_id, fab_version, override_config)"
634
+ "VALUES (?, ?, ?, ?);"
635
+ )
636
+ self.query(
637
+ query, (run_id, fab_id, fab_version, json.dumps(override_config))
638
+ )
627
639
  return run_id
628
640
  log(ERROR, "Unexpected run creation failure.")
629
641
  return 0
@@ -687,7 +699,10 @@ class SqliteState(State): # pylint: disable=R0904
687
699
  try:
688
700
  row = self.query(query, (run_id,))[0]
689
701
  return Run(
690
- run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
702
+ run_id=run_id,
703
+ fab_id=row["fab_id"],
704
+ fab_version=row["fab_version"],
705
+ override_config=json.loads(row["override_config"]),
691
706
  )
692
707
  except sqlite3.IntegrityError:
693
708
  log(ERROR, "`run_id` does not exist.")
@@ -19,7 +19,7 @@ import abc
19
19
  from typing import List, Optional, Set
20
20
  from uuid import UUID
21
21
 
22
- from flwr.common.typing import Run
22
+ from flwr.common.typing import Run, UserConfig
23
23
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
24
24
 
25
25
 
@@ -157,7 +157,12 @@ class State(abc.ABC): # pylint: disable=R0904
157
157
  """Retrieve stored `node_id` filtered by `client_public_keys`."""
158
158
 
159
159
  @abc.abstractmethod
160
- def create_run(self, fab_id: str, fab_version: str) -> int:
160
+ def create_run(
161
+ self,
162
+ fab_id: str,
163
+ fab_version: str,
164
+ override_config: UserConfig,
165
+ ) -> int:
161
166
  """Create a new run for the specified `fab_id` and `fab_version`."""
162
167
 
163
168
  @abc.abstractmethod
flwr/server/typing.py CHANGED
@@ -20,6 +20,8 @@ from typing import Callable
20
20
  from flwr.common import Context
21
21
 
22
22
  from .driver import Driver
23
+ from .serverapp_components import ServerAppComponents
23
24
 
24
25
  ServerAppCallable = Callable[[Driver, Context], None]
25
26
  Workflow = Callable[[Driver, Context], None]
27
+ ServerFn = Callable[[Context], ServerAppComponents]
@@ -81,6 +81,7 @@ class WorkflowState: # pylint: disable=R0902
81
81
  forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
82
82
  aggregate_ndarrays: NDArrays = field(default_factory=list)
83
83
  legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
84
+ failures: List[Exception] = field(default_factory=list)
84
85
 
85
86
 
86
87
  class SecAggPlusWorkflow:
@@ -394,6 +395,7 @@ class SecAggPlusWorkflow:
394
395
 
395
396
  for msg in msgs:
396
397
  if msg.has_error():
398
+ state.failures.append(Exception(msg.error))
397
399
  continue
398
400
  key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
399
401
  node_id = msg.metadata.src_node_id
@@ -451,6 +453,9 @@ class SecAggPlusWorkflow:
451
453
  nid: [] for nid in state.active_node_ids
452
454
  } # dest node ID -> list of src node IDs
453
455
  for msg in msgs:
456
+ if msg.has_error():
457
+ state.failures.append(Exception(msg.error))
458
+ continue
454
459
  node_id = msg.metadata.src_node_id
455
460
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
456
461
  dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
@@ -515,6 +520,9 @@ class SecAggPlusWorkflow:
515
520
  # Sum collected masked vectors and compute active/dead node IDs
516
521
  masked_vector = None
517
522
  for msg in msgs:
523
+ if msg.has_error():
524
+ state.failures.append(Exception(msg.error))
525
+ continue
518
526
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
519
527
  bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
520
528
  client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
@@ -528,6 +536,9 @@ class SecAggPlusWorkflow:
528
536
 
529
537
  # Backward compatibility with Strategy
530
538
  for msg in msgs:
539
+ if msg.has_error():
540
+ state.failures.append(Exception(msg.error))
541
+ continue
531
542
  fitres = compat.recordset_to_fitres(msg.content, True)
532
543
  proxy = state.nid_to_proxies[msg.metadata.src_node_id]
533
544
  state.legacy_results.append((proxy, fitres))
@@ -584,6 +595,9 @@ class SecAggPlusWorkflow:
584
595
  for nid in state.sampled_node_ids:
585
596
  collected_shares_dict[nid] = []
586
597
  for msg in msgs:
598
+ if msg.has_error():
599
+ state.failures.append(Exception(msg.error))
600
+ continue
587
601
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
588
602
  nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
589
603
  shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
@@ -652,9 +666,11 @@ class SecAggPlusWorkflow:
652
666
  INFO,
653
667
  "aggregate_fit: received %s results and %s failures",
654
668
  len(results),
655
- 0,
669
+ len(state.failures),
670
+ )
671
+ aggregated_result = context.strategy.aggregate_fit(
672
+ current_round, results, state.failures # type: ignore
656
673
  )
657
- aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
658
674
  parameters_aggregated, metrics_aggregated = aggregated_result
659
675
 
660
676
  # Update the parameters and write history
@@ -28,7 +28,7 @@ else:
28
28
 
29
29
  To install the necessary dependencies, install `flwr` with the `simulation` extra:
30
30
 
31
- pip install -U flwr["simulation"]
31
+ pip install -U "flwr[simulation]"
32
32
  """
33
33
 
34
34
  def start_simulation(*args, **kwargs): # type: ignore
flwr/simulation/app.py CHANGED
@@ -111,9 +111,9 @@ def start_simulation(
111
111
  Parameters
112
112
  ----------
113
113
  client_fn : ClientFnExt
114
- A function creating Client instances. The function must have the signature
115
- `client_fn(node_id: int, partition_id: Optional[int]). It should return
116
- a single client instance of type Client. Note that the created client
114
+ A function creating `Client` instances. The function must have the signature
115
+ `client_fn(context: Context). It should return
116
+ a single client instance of type `Client`. Note that the created client
117
117
  instances are ephemeral and will often be destroyed after a single method
118
118
  invocation. Since client instances are not long-lived, they should not attempt
119
119
  to carry state over method invocations. Any state required by the instance
@@ -327,6 +327,7 @@ def start_simulation(
327
327
  client_fn=client_fn,
328
328
  node_id=node_id,
329
329
  partition_id=partition_id,
330
+ num_partitions=num_clients,
330
331
  actor_pool=pool,
331
332
  )
332
333
  initialized_server.client_manager().register(client=client_proxy)