flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/app.py +2 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +15 -2
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  14. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
  15. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  16. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  17. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  18. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  19. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  20. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  21. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  23. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  24. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  26. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  27. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  28. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  29. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  30. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  31. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  32. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  33. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  34. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  35. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  36. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  37. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  38. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  39. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  40. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  41. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  42. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  43. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  44. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
  45. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  46. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  47. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  49. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  50. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  52. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  53. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  54. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
  55. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  56. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  57. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  58. flwr/cli/pull.py +100 -0
  59. flwr/cli/run/run.py +9 -13
  60. flwr/cli/stop.py +7 -4
  61. flwr/cli/utils.py +36 -8
  62. flwr/client/grpc_rere_client/connection.py +1 -12
  63. flwr/client/rest_client/connection.py +3 -0
  64. flwr/clientapp/__init__.py +10 -0
  65. flwr/clientapp/mod/__init__.py +29 -0
  66. flwr/clientapp/mod/centraldp_mods.py +248 -0
  67. flwr/clientapp/mod/localdp_mod.py +169 -0
  68. flwr/clientapp/typing.py +22 -0
  69. flwr/common/args.py +20 -6
  70. flwr/common/auth_plugin/__init__.py +4 -4
  71. flwr/common/auth_plugin/auth_plugin.py +7 -7
  72. flwr/common/constant.py +26 -4
  73. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  74. flwr/common/exit/__init__.py +4 -0
  75. flwr/common/exit/exit.py +8 -1
  76. flwr/common/exit/exit_code.py +30 -7
  77. flwr/common/exit/exit_handler.py +62 -0
  78. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  79. flwr/common/grpc.py +0 -11
  80. flwr/common/inflatable_utils.py +1 -1
  81. flwr/common/logger.py +1 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/retry_invoker.py +30 -11
  84. flwr/common/telemetry.py +4 -0
  85. flwr/compat/server/app.py +2 -2
  86. flwr/proto/appio_pb2.py +25 -17
  87. flwr/proto/appio_pb2.pyi +46 -2
  88. flwr/proto/clientappio_pb2.py +3 -11
  89. flwr/proto/clientappio_pb2.pyi +0 -47
  90. flwr/proto/clientappio_pb2_grpc.py +19 -20
  91. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  92. flwr/proto/control_pb2.py +66 -0
  93. flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
  94. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
  95. flwr/proto/control_pb2_grpc.pyi +106 -0
  96. flwr/proto/serverappio_pb2.py +2 -2
  97. flwr/proto/serverappio_pb2_grpc.py +68 -0
  98. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  99. flwr/proto/simulationio_pb2.py +4 -11
  100. flwr/proto/simulationio_pb2.pyi +0 -58
  101. flwr/proto/simulationio_pb2_grpc.py +129 -27
  102. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  103. flwr/server/app.py +142 -152
  104. flwr/server/grid/grpc_grid.py +3 -0
  105. flwr/server/grid/inmemory_grid.py +1 -0
  106. flwr/server/serverapp/app.py +157 -146
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  110. flwr/server/superlink/linkstate/linkstate.py +2 -1
  111. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  112. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  113. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  114. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  115. flwr/serverapp/__init__.py +12 -0
  116. flwr/serverapp/exception.py +38 -0
  117. flwr/serverapp/strategy/__init__.py +64 -0
  118. flwr/serverapp/strategy/bulyan.py +238 -0
  119. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  120. flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
  121. flwr/serverapp/strategy/fedadagrad.py +159 -0
  122. flwr/serverapp/strategy/fedadam.py +178 -0
  123. flwr/serverapp/strategy/fedavg.py +320 -0
  124. flwr/serverapp/strategy/fedavgm.py +198 -0
  125. flwr/serverapp/strategy/fedmedian.py +105 -0
  126. flwr/serverapp/strategy/fedopt.py +218 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +170 -0
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/result.py +105 -0
  136. flwr/serverapp/strategy/strategy.py +285 -0
  137. flwr/serverapp/strategy/strategy_utils.py +299 -0
  138. flwr/simulation/app.py +161 -164
  139. flwr/simulation/run_simulation.py +25 -30
  140. flwr/supercore/app_utils.py +58 -0
  141. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  142. flwr/supercore/cli/flower_superexec.py +166 -0
  143. flwr/supercore/constant.py +19 -0
  144. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  145. flwr/supercore/corestate/corestate.py +81 -0
  146. flwr/supercore/grpc_health/__init__.py +3 -0
  147. flwr/supercore/grpc_health/health_server.py +53 -0
  148. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  149. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  150. flwr/supercore/superexec/plugin/__init__.py +28 -0
  151. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  152. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  153. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
  154. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  155. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  156. flwr/supercore/superexec/run_superexec.py +199 -0
  157. flwr/superlink/artifact_provider/__init__.py +22 -0
  158. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  159. flwr/superlink/servicer/__init__.py +15 -0
  160. flwr/superlink/servicer/control/__init__.py +22 -0
  161. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  162. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
  163. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  164. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
  165. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  166. flwr/supernode/cli/flower_supernode.py +3 -0
  167. flwr/supernode/cli/flwr_clientapp.py +18 -21
  168. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  169. flwr/supernode/nodestate/nodestate.py +3 -59
  170. flwr/supernode/runtime/run_clientapp.py +39 -102
  171. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  172. flwr/supernode/start_client_internal.py +35 -76
  173. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
  174. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
  175. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
  176. flwr/proto/exec_pb2.py +0 -62
  177. flwr/proto/exec_pb2_grpc.pyi +0 -93
  178. flwr/superexec/app.py +0 -45
  179. flwr/superexec/deployment.py +0 -191
  180. flwr/superexec/executor.py +0 -100
  181. flwr/superexec/simulation.py +0 -129
  182. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
@@ -43,6 +43,8 @@ from flwr.common.serde import (
43
43
  from flwr.common.typing import Fab, RunStatus
44
44
  from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
45
45
  from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
46
+ ListAppsToLaunchRequest,
47
+ ListAppsToLaunchResponse,
46
48
  PullAppInputsRequest,
47
49
  PullAppInputsResponse,
48
50
  PullAppMessagesRequest,
@@ -51,6 +53,8 @@ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
51
53
  PushAppMessagesResponse,
52
54
  PushAppOutputsRequest,
53
55
  PushAppOutputsResponse,
56
+ RequestTokenRequest,
57
+ RequestTokenResponse,
54
58
  )
55
59
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
56
60
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
@@ -104,6 +108,42 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
104
108
  self.objectstore_factory = objectstore_factory
105
109
  self.lock = threading.RLock()
106
110
 
111
def ListAppsToLaunch(
    self,
    request: ListAppsToLaunchRequest,
    context: grpc.ServicerContext,
) -> ListAppsToLaunchResponse:
    """Return the IDs of all runs whose status is `PENDING`.

    Runs in `PENDING` status have not yet been moved to `STARTING`, so each
    returned ID is a candidate for launching. Neither `request` nor `context`
    is inspected.
    """
    log(DEBUG, "ServerAppIoServicer.ListAppsToLaunch")

    # Initialize state connection
    state = self.state_factory.state()

    # Fetch run IDs (`flwr_aid=None`; presumably no per-account filtering —
    # confirm against LinkState) and keep only those still in PENDING status
    run_ids = state.get_run_ids(flwr_aid=None)
    pending_run_ids = [
        run_id
        for run_id, run_status in state.get_run_status(run_ids).items()
        if run_status.status == Status.PENDING
    ]

    return ListAppsToLaunchResponse(run_ids=pending_run_ids)
131
+
132
def RequestToken(
    self, request: RequestTokenRequest, context: grpc.ServicerContext
) -> RequestTokenResponse:
    """Issue a token for the run identified by `request.run_id`.

    When the state layer yields no token (a falsy value), an empty string is
    placed in the response instead.
    """
    log(DEBUG, "ServerAppIoServicer.RequestToken")

    issued = self.state_factory.state().create_token(request.run_id)
    if not issued:
        issued = ""
    return RequestTokenResponse(token=issued)
146
+
107
147
  def GetNodes(
108
148
  self, request: GetNodesRequest, context: grpc.ServicerContext
109
149
  ) -> GetNodesResponse:
@@ -289,14 +329,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
289
329
  # Init access to LinkState
290
330
  state = self.state_factory.state()
291
331
 
332
+ # Validate the token
333
+ run_id = self._verify_token(request.token, context)
334
+
292
335
  # Lock access to LinkState, preventing obtaining the same pending run_id
293
336
  with self.lock:
294
- # Attempt getting the run_id of a pending run
295
- run_id = state.get_pending_run_id()
296
- # If there's no pending run, return an empty response
297
- if run_id is None:
298
- return PullAppInputsResponse()
299
-
300
337
  # Init access to Ffs
301
338
  ffs = self.ffs_factory.ffs()
302
339
 
@@ -327,6 +364,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
327
364
  """Push ServerApp process outputs."""
328
365
  log(DEBUG, "ServerAppIoServicer.PushAppOutputs")
329
366
 
367
+ # Validate the token
368
+ run_id = self._verify_token(request.token, context)
369
+
330
370
  # Init state and store
331
371
  state = self.state_factory.state()
332
372
  store = self.objectstore_factory.store()
@@ -341,6 +381,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
341
381
  )
342
382
 
343
383
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
384
+
385
+ # Remove the token
386
+ state.delete_token(run_id)
344
387
  return PushAppOutputsResponse()
345
388
 
346
389
  def UpdateRunStatus(
@@ -508,6 +551,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
508
551
 
509
552
  return ConfirmMessageReceivedResponse()
510
553
 
554
def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
    """Verify the token and return the associated run ID.

    Parameters
    ----------
    token : str
        Token presented by the caller.
    context : grpc.ServicerContext
        gRPC context, used to abort the RPC when verification fails.

    Returns
    -------
    int
        The run ID the token is associated with.

    Notes
    -----
    On failure (empty token, unknown token, or a token that does not verify
    against its run) the RPC is aborted with `PERMISSION_DENIED` through
    `context.abort`. The trailing `RuntimeError` is only a safeguard in case
    `abort` returns.
    """
    # `RequestToken` answers with "" when no token could be issued, so an empty
    # token is never valid; reject it without consulting the state.
    if token:
        state = self.state_factory.state()
        run_id = state.get_run_id_by_token(token)
        if run_id is not None and state.verify_token(run_id, token):
            return run_id
    context.abort(
        grpc.StatusCode.PERMISSION_DENIED,
        "Invalid token.",
    )
    raise RuntimeError("This line should never be reached.")
565
+
511
566
 
512
567
  def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
513
568
  """Raise a `ValueError` with a detailed message if a validation error occurs."""
@@ -34,6 +34,16 @@ from flwr.common.serde import (
34
34
  )
35
35
  from flwr.common.typing import Fab, RunStatus
36
36
  from flwr.proto import simulationio_pb2_grpc
37
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
38
+ ListAppsToLaunchRequest,
39
+ ListAppsToLaunchResponse,
40
+ PullAppInputsRequest,
41
+ PullAppInputsResponse,
42
+ PushAppOutputsRequest,
43
+ PushAppOutputsResponse,
44
+ RequestTokenRequest,
45
+ RequestTokenResponse,
46
+ )
37
47
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
38
48
  SendAppHeartbeatRequest,
39
49
  SendAppHeartbeatResponse,
@@ -45,17 +55,13 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
45
55
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
46
56
  GetFederationOptionsRequest,
47
57
  GetFederationOptionsResponse,
58
+ GetRunRequest,
59
+ GetRunResponse,
48
60
  GetRunStatusRequest,
49
61
  GetRunStatusResponse,
50
62
  UpdateRunStatusRequest,
51
63
  UpdateRunStatusResponse,
52
64
  )
53
- from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611
54
- PullSimulationInputsRequest,
55
- PullSimulationInputsResponse,
56
- PushSimulationOutputsRequest,
57
- PushSimulationOutputsResponse,
58
- )
59
65
  from flwr.server.superlink.linkstate import LinkStateFactory
60
66
  from flwr.server.superlink.utils import abort_if
61
67
  from flwr.supercore.ffs import FfsFactory
@@ -71,23 +77,73 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
71
77
  self.ffs_factory = ffs_factory
72
78
  self.lock = threading.RLock()
73
79
 
74
- def PullSimulationInputs(
75
- self, request: PullSimulationInputsRequest, context: ServicerContext
76
- ) -> PullSimulationInputsResponse:
80
def ListAppsToLaunch(
    self,
    request: ListAppsToLaunchRequest,
    context: grpc.ServicerContext,
) -> ListAppsToLaunchResponse:
    """Return the IDs of all runs whose status is `PENDING`.

    Runs in `PENDING` status have not yet been moved to `STARTING` (see
    `PullAppInputs`), so each returned ID is a candidate for launching.
    Neither `request` nor `context` is inspected.
    """
    log(DEBUG, "SimulationIoServicer.ListAppsToLaunch")

    # Initialize state connection
    state = self.state_factory.state()

    # Fetch run IDs (`flwr_aid=None`; presumably no per-account filtering —
    # confirm against LinkState) and keep only those still in PENDING status
    run_ids = state.get_run_ids(flwr_aid=None)
    pending_run_ids = [
        run_id
        for run_id, run_status in state.get_run_status(run_ids).items()
        if run_status.status == Status.PENDING
    ]

    return ListAppsToLaunchResponse(run_ids=pending_run_ids)
100
+
101
def RequestToken(
    self, request: RequestTokenRequest, context: grpc.ServicerContext
) -> RequestTokenResponse:
    """Issue a token for the run identified by `request.run_id`.

    A falsy result from the state layer is reported as an empty-string token.
    """
    log(DEBUG, "SimulationIoServicer.RequestToken")

    link_state = self.state_factory.state()
    new_token = link_state.create_token(request.run_id)
    return RequestTokenResponse(token=new_token if new_token else "")
115
+
116
def GetRun(
    self, request: GetRunRequest, context: grpc.ServicerContext
) -> GetRunResponse:
    """Look up the run with ID `request.run_id`.

    Returns a `GetRunResponse` holding the run converted to protobuf, or an
    empty `GetRunResponse` when the state has no run with that ID.
    """
    log(DEBUG, "SimulationIoServicer.GetRun")

    found_run = self.state_factory.state().get_run(request.run_id)
    if found_run is not None:
        return GetRunResponse(run=run_to_proto(found_run))
    return GetRunResponse()
132
+
133
+ def PullAppInputs(
134
+ self, request: PullAppInputsRequest, context: ServicerContext
135
+ ) -> PullAppInputsResponse:
77
136
  """Pull SimultionIo process inputs."""
78
137
  log(DEBUG, "SimultionIoServicer.SimultionIoInputs")
79
138
  # Init access to LinkState and Ffs
80
139
  state = self.state_factory.state()
81
140
  ffs = self.ffs_factory.ffs()
82
141
 
142
+ # Validate the token
143
+ run_id = self._verify_token(request.token, context)
144
+
83
145
  # Lock access to LinkState, preventing obtaining the same pending run_id
84
146
  with self.lock:
85
- # Attempt getting the run_id of a pending run
86
- run_id = state.get_pending_run_id()
87
- # If there's no pending run, return an empty response
88
- if run_id is None:
89
- return PullSimulationInputsResponse()
90
-
91
147
  # Retrieve Context, Run and Fab for the run_id
92
148
  serverapp_ctxt = state.get_serverapp_context(run_id)
93
149
  run = state.get_run(run_id)
@@ -99,7 +155,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
99
155
  # Update run status to STARTING
100
156
  if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
101
157
  log(INFO, "Starting run %d", run_id)
102
- return PullSimulationInputsResponse(
158
+ return PullAppInputsResponse(
103
159
  context=context_to_proto(serverapp_ctxt),
104
160
  run=run_to_proto(run),
105
161
  fab=fab_to_proto(fab),
@@ -109,11 +165,16 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
109
165
  # or if the status cannot be updated to STARTING
110
166
  raise RuntimeError(f"Failed to start run {run_id}")
111
167
 
112
- def PushSimulationOutputs(
113
- self, request: PushSimulationOutputsRequest, context: ServicerContext
114
- ) -> PushSimulationOutputsResponse:
168
+ def PushAppOutputs(
169
+ self, request: PushAppOutputsRequest, context: ServicerContext
170
+ ) -> PushAppOutputsResponse:
115
171
  """Push Simulation process outputs."""
116
- log(DEBUG, "SimultionIoServicer.PushSimulationOutputs")
172
+ log(DEBUG, "SimultionIoServicer.PushAppOutputs")
173
+
174
+ # Validate the token
175
+ run_id = self._verify_token(request.token, context)
176
+
177
+ # Init access to LinkState
117
178
  state = self.state_factory.state()
118
179
 
119
180
  # Abort if the run is not running
@@ -126,7 +187,10 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
126
187
  )
127
188
 
128
189
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
129
- return PushSimulationOutputsResponse()
190
+
191
+ # Remove the token
192
+ state.delete_token(run_id)
193
+ return PushAppOutputsResponse()
130
194
 
131
195
  def UpdateRunStatus(
132
196
  self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
@@ -208,3 +272,15 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
208
272
  )
209
273
 
210
274
  return SendAppHeartbeatResponse(success=success)
275
+
276
def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
    """Verify the token and return the associated run ID.

    Parameters
    ----------
    token : str
        Token presented by the caller.
    context : grpc.ServicerContext
        gRPC context, used to abort the RPC when verification fails.

    Returns
    -------
    int
        The run ID the token is associated with.

    Notes
    -----
    On failure (empty token, unknown token, or a token that does not verify
    against its run) the RPC is aborted with `PERMISSION_DENIED` through
    `context.abort`. The trailing `RuntimeError` is only a safeguard in case
    `abort` returns.
    """
    # `RequestToken` answers with "" when no token could be issued, so an empty
    # token is never valid; reject it without consulting the state.
    if token:
        state = self.state_factory.state()
        run_id = state.get_run_id_by_token(token)
        if run_id is not None and state.verify_token(run_id, token):
            return run_id
    context.abort(
        grpc.StatusCode.PERMISSION_DENIED,
        "Invalid token.",
    )
    raise RuntimeError("This line should never be reached.")
@@ -13,3 +13,15 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  """Public Flower ServerApp APIs."""
16
+
17
+
18
+ from flwr.server.grid import Grid
19
+ from flwr.server.server_app import ServerApp
20
+
21
+ from . import strategy
22
+
23
+ __all__ = [
24
+ "Grid",
25
+ "ServerApp",
26
+ "strategy",
27
+ ]
@@ -0,0 +1,38 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower ServerApp exceptions."""
16
+
17
+
18
+ from flwr.app.exception import AppExitException
19
+ from flwr.common.exit import ExitCode
20
+
21
+
22
class InconsistentMessageReplies(AppExitException):
    """Exception triggered when replies are inconsistent.

    Aggregation must be skipped when this is raised. The associated exit code is
    `ExitCode.SERVERAPP_STRATEGY_PRECONDITION_UNMET`.

    Parameters
    ----------
    reason : str
        Description of the inconsistency, forwarded to `AppExitException`.
    """

    exit_code = ExitCode.SERVERAPP_STRATEGY_PRECONDITION_UNMET

    def __init__(self, reason: str) -> None:
        super().__init__(reason)
30
+
31
+
32
class AggregationError(AppExitException):
    """Exception triggered when aggregation fails.

    The associated exit code is `ExitCode.SERVERAPP_STRATEGY_AGGREGATION_ERROR`.

    Parameters
    ----------
    reason : str
        Description of the failure, forwarded to `AppExitException`.
    """

    exit_code = ExitCode.SERVERAPP_STRATEGY_AGGREGATION_ERROR

    def __init__(self, reason: str) -> None:
        super().__init__(reason)
@@ -0,0 +1,64 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ServerApp strategies."""
16
+
17
+
18
+ from .bulyan import Bulyan
19
+ from .dp_adaptive_clipping import (
20
+ DifferentialPrivacyClientSideAdaptiveClipping,
21
+ DifferentialPrivacyServerSideAdaptiveClipping,
22
+ )
23
+ from .dp_fixed_clipping import (
24
+ DifferentialPrivacyClientSideFixedClipping,
25
+ DifferentialPrivacyServerSideFixedClipping,
26
+ )
27
+ from .fedadagrad import FedAdagrad
28
+ from .fedadam import FedAdam
29
+ from .fedavg import FedAvg
30
+ from .fedavgm import FedAvgM
31
+ from .fedmedian import FedMedian
32
+ from .fedprox import FedProx
33
+ from .fedtrimmedavg import FedTrimmedAvg
34
+ from .fedxgb_bagging import FedXgbBagging
35
+ from .fedxgb_cyclic import FedXgbCyclic
36
+ from .fedyogi import FedYogi
37
+ from .krum import Krum
38
+ from .multikrum import MultiKrum
39
+ from .qfedavg import QFedAvg
40
+ from .result import Result
41
+ from .strategy import Strategy
42
+
43
+ __all__ = [
44
+ "Bulyan",
45
+ "DifferentialPrivacyClientSideAdaptiveClipping",
46
+ "DifferentialPrivacyClientSideFixedClipping",
47
+ "DifferentialPrivacyServerSideAdaptiveClipping",
48
+ "DifferentialPrivacyServerSideFixedClipping",
49
+ "FedAdagrad",
50
+ "FedAdam",
51
+ "FedAvg",
52
+ "FedAvgM",
53
+ "FedMedian",
54
+ "FedProx",
55
+ "FedTrimmedAvg",
56
+ "FedXgbBagging",
57
+ "FedXgbCyclic",
58
+ "FedYogi",
59
+ "Krum",
60
+ "MultiKrum",
61
+ "QFedAvg",
62
+ "Result",
63
+ "Strategy",
64
+ ]
@@ -0,0 +1,238 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Bulyan [El Mhamdi et al., 2018] strategy.
16
+
17
+ Paper: arxiv.org/abs/1802.07927
18
+ """
19
+
20
+
21
+ from collections import OrderedDict
22
+ from collections.abc import Iterable
23
+ from logging import INFO, WARN
24
+ from typing import Callable, Optional, cast
25
+
26
+ import numpy as np
27
+
28
+ from flwr.common import (
29
+ Array,
30
+ ArrayRecord,
31
+ Message,
32
+ MetricRecord,
33
+ NDArrays,
34
+ RecordDict,
35
+ log,
36
+ )
37
+
38
+ from .fedavg import FedAvg
39
+ from .multikrum import select_multikrum
40
+
41
+
42
# pylint: disable=too-many-instance-attributes
class Bulyan(FedAvg):
    """Bulyan strategy.

    Implementation based on https://arxiv.org/abs/1802.07927.

    With ``n`` valid training replies and ``f = num_malicious_nodes``, Bulyan
    first applies a Byzantine-resilient selection rule (Multi-Krum by default)
    to keep ``theta = n - 2 * f`` replies. It then averages, per coordinate,
    the ``beta = theta - 2 * f`` selected values closest to the coordinate-wise
    median. Training aggregation is skipped (``(None, None)`` is returned)
    unless at least ``4 * f + 3`` valid replies are received.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    num_malicious_nodes : int (default: 0)
        Number of malicious nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    selection_rule : Optional[Callable] (default: None)
        Function with signature (list[RecordDict], int, int) -> list[RecordDict].
        The inputs are:
            - a list of contents from reply messages,
            - the assumed number of malicious nodes (`num_malicious_nodes`),
            - the number of nodes to select (`num_nodes_to_select`).

        The function should implement a Byzantine-resilient selection rule that
        serves as the first step of Bulyan. If None, defaults to `select_multikrum`,
        which selects nodes according to the Multi-Krum algorithm.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        num_malicious_nodes: int = 0,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        selection_rule: Optional[
            Callable[[list[RecordDict], int, int], list[RecordDict]]
        ] = None,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )
        self.num_malicious_nodes = num_malicious_nodes
        self.selection_rule = selection_rule or select_multikrum

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        # `functools.partial` objects and callable instances have no `__name__`;
        # fall back to their repr instead of raising AttributeError.
        rule_name = getattr(
            self.selection_rule, "__name__", repr(self.selection_rule)
        )
        log(INFO, "\t├──> Bulyan settings:")
        log(INFO, "\t│\t├── Number of malicious nodes: %d", self.num_malicious_nodes)
        log(INFO, "\t│\t└── Selection rule: %s", rule_name)
        super().summary()

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Returns ``(None, None)`` when fewer than ``4 * num_malicious_nodes + 3``
        valid replies were received.
        """
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        # Bulyan requires at least 4f + 3 replies
        min_replies = 4 * self.num_malicious_nodes + 3
        if len(valid_replies) < min_replies:
            log(
                WARN,
                "Insufficient replies, skipping Bulyan aggregation: "
                "Required at least %d (4*num_malicious_nodes + 3), but received %d.",
                min_replies,
                len(valid_replies),
            )
            return None, None

        reply_contents = [msg.content for msg in valid_replies]

        # theta: number of replies kept by the selection rule
        # beta: number of values averaged per coordinate around the median
        theta = len(valid_replies) - 2 * self.num_malicious_nodes
        beta = theta - 2 * self.num_malicious_nodes

        # Byzantine-resilient selection rule
        selected_contents = self.selection_rule(
            reply_contents, self.num_malicious_nodes, theta
        )

        # Convert each ArrayRecord to a list of NDArray for easier computation.
        # The ArrayRecord key and array names are taken from the first selected
        # reply and assumed to be shared by all selected replies.
        key = list(selected_contents[0].array_records.keys())[0]
        array_keys = list(selected_contents[0][key].keys())
        selected_ndarrays = [
            cast(ArrayRecord, ctnt[key]).to_numpy_ndarrays(keep_input=False)
            for ctnt in selected_contents
        ]

        # Coordinate-wise median across the selected replies, per layer
        median_ndarrays = [np.median(arr, axis=0) for arr in zip(*selected_ndarrays)]

        # Aggregate the beta closest weights element-wise
        aggregated_ndarrays = aggregate_n_closest_weights(
            median_ndarrays, selected_ndarrays, beta
        )

        # Convert to ArrayRecord, restoring the original array names
        arrays = ArrayRecord(
            OrderedDict(zip(array_keys, map(Array, aggregated_ndarrays)))
        )

        # Aggregate MetricRecords of the selected replies only
        metrics = self.train_metrics_aggr_fn(
            selected_contents,
            self.weighted_by_key,
        )
        return arrays, metrics
196
+
197
+
198
def aggregate_n_closest_weights(
    ref_weights: NDArrays, weights_list: list[NDArrays], beta: int
) -> NDArrays:
    """Compute the element-wise mean of the `beta` closest weight values.

    For each element (i-th coordinate), the output is the average of the
    `beta` values at that coordinate whose absolute difference to the
    reference weights is smallest. Different coordinates may draw their
    values from different entries of `weights_list`.

    Parameters
    ----------
    ref_weights : NDArrays
        Reference weights used to compute distances.
    weights_list : list[NDArrays]
        List of weight arrays (e.g., from selected nodes). Each entry holds one
        array per layer of `ref_weights`, with matching shapes.
    beta : int
        Number of closest values to include in the averaging per coordinate.
        Must satisfy ``1 <= beta <= len(weights_list)``.

    Returns
    -------
    aggregated_weights : NDArrays
        Element-wise average of the `beta` closest values to the reference
        weights, one array per layer of `ref_weights`.

    Raises
    ------
    ValueError
        If `beta` is not in the range ``[1, len(weights_list)]``.
    """
    # Without this check, beta == 0 silently yields NaN (mean over an empty
    # selection), a negative beta averages the wrong number of values through
    # negative slicing, and beta > len(weights_list) fails inside
    # `np.argpartition` with a less descriptive error.
    if not 1 <= beta <= len(weights_list):
        raise ValueError(
            "`beta` must be between 1 and the number of weight arrays "
            f"({len(weights_list)}), but got {beta}."
        )

    aggregated_weights = []
    for layer_id, ref_layer in enumerate(ref_weights):
        # Shape: (n_models, *layer_shape)
        layer_stack = np.stack([weights[layer_id] for weights in weights_list])

        # Absolute distance to the reference: shape (n_models, *layer_shape)
        diffs = np.abs(layer_stack - ref_layer)

        # Indices of the `beta` smallest distances along the model axis; their
        # relative order is irrelevant since they are averaged afterwards
        idx = np.argpartition(diffs, beta - 1, axis=0)[:beta]

        # Gather the closest values and average them
        closest = np.take_along_axis(layer_stack, idx, axis=0)
        aggregated_weights.append(np.mean(closest, axis=0))

    return aggregated_weights