flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/app.py +2 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +15 -2
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +9 -13
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +36 -8
- flwr/client/grpc_rere_client/connection.py +1 -12
- flwr/client/rest_client/connection.py +3 -0
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +29 -0
- flwr/clientapp/mod/centraldp_mods.py +248 -0
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -4
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +30 -7
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +0 -11
- flwr/common/inflatable_utils.py +1 -1
- flwr/common/logger.py +1 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/telemetry.py +4 -0
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +25 -17
- flwr/proto/appio_pb2.pyi +46 -2
- flwr/proto/clientappio_pb2.py +3 -11
- flwr/proto/clientappio_pb2.pyi +0 -47
- flwr/proto/clientappio_pb2_grpc.py +19 -20
- flwr/proto/clientappio_pb2_grpc.pyi +10 -11
- flwr/proto/control_pb2.py +66 -0
- flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
- flwr/proto/control_pb2_grpc.pyi +106 -0
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +68 -0
- flwr/proto/serverappio_pb2_grpc.pyi +26 -0
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +142 -152
- flwr/server/grid/grpc_grid.py +3 -0
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +157 -146
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
- flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +64 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
- flwr/serverapp/strategy/fedadagrad.py +159 -0
- flwr/serverapp/strategy/fedadam.py +178 -0
- flwr/serverapp/strategy/fedavg.py +320 -0
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +170 -0
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +299 -0
- flwr/simulation/app.py +161 -164
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/app_utils.py +58 -0
- flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
- flwr/supercore/cli/flower_superexec.py +166 -0
- flwr/supercore/constant.py +19 -0
- flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/supercore/grpc_health/__init__.py +3 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +199 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
- flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
- flwr/supernode/cli/flower_supernode.py +3 -0
- flwr/supernode/cli/flwr_clientapp.py +18 -21
- flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
- flwr/supernode/nodestate/nodestate.py +3 -59
- flwr/supernode/runtime/run_clientapp.py +39 -102
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
- flwr/supernode/start_client_internal.py +35 -76
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
- flwr/proto/exec_pb2.py +0 -62
- flwr/proto/exec_pb2_grpc.pyi +0 -93
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -191
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -129
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
|
@@ -43,6 +43,8 @@ from flwr.common.serde import (
|
|
|
43
43
|
from flwr.common.typing import Fab, RunStatus
|
|
44
44
|
from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
|
|
45
45
|
from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
|
|
46
|
+
ListAppsToLaunchRequest,
|
|
47
|
+
ListAppsToLaunchResponse,
|
|
46
48
|
PullAppInputsRequest,
|
|
47
49
|
PullAppInputsResponse,
|
|
48
50
|
PullAppMessagesRequest,
|
|
@@ -51,6 +53,8 @@ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
|
|
|
51
53
|
PushAppMessagesResponse,
|
|
52
54
|
PushAppOutputsRequest,
|
|
53
55
|
PushAppOutputsResponse,
|
|
56
|
+
RequestTokenRequest,
|
|
57
|
+
RequestTokenResponse,
|
|
54
58
|
)
|
|
55
59
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
56
60
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
@@ -104,6 +108,42 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
104
108
|
self.objectstore_factory = objectstore_factory
|
|
105
109
|
self.lock = threading.RLock()
|
|
106
110
|
|
|
111
|
+
def ListAppsToLaunch(
|
|
112
|
+
self,
|
|
113
|
+
request: ListAppsToLaunchRequest,
|
|
114
|
+
context: grpc.ServicerContext,
|
|
115
|
+
) -> ListAppsToLaunchResponse:
|
|
116
|
+
"""Get run IDs with pending messages."""
|
|
117
|
+
log(DEBUG, "ServerAppIoServicer.ListAppsToLaunch")
|
|
118
|
+
|
|
119
|
+
# Initialize state connection
|
|
120
|
+
state = self.state_factory.state()
|
|
121
|
+
|
|
122
|
+
# Get IDs of runs in pending status
|
|
123
|
+
run_ids = state.get_run_ids(flwr_aid=None)
|
|
124
|
+
pending_run_ids = []
|
|
125
|
+
for run_id, status in state.get_run_status(run_ids).items():
|
|
126
|
+
if status.status == Status.PENDING:
|
|
127
|
+
pending_run_ids.append(run_id)
|
|
128
|
+
|
|
129
|
+
# Return run IDs
|
|
130
|
+
return ListAppsToLaunchResponse(run_ids=pending_run_ids)
|
|
131
|
+
|
|
132
|
+
def RequestToken(
|
|
133
|
+
self, request: RequestTokenRequest, context: grpc.ServicerContext
|
|
134
|
+
) -> RequestTokenResponse:
|
|
135
|
+
"""Request token."""
|
|
136
|
+
log(DEBUG, "ServerAppIoServicer.RequestToken")
|
|
137
|
+
|
|
138
|
+
# Initialize state connection
|
|
139
|
+
state = self.state_factory.state()
|
|
140
|
+
|
|
141
|
+
# Attempt to create a token for the provided run ID
|
|
142
|
+
token = state.create_token(request.run_id)
|
|
143
|
+
|
|
144
|
+
# Return the token
|
|
145
|
+
return RequestTokenResponse(token=token or "")
|
|
146
|
+
|
|
107
147
|
def GetNodes(
|
|
108
148
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
|
109
149
|
) -> GetNodesResponse:
|
|
@@ -289,14 +329,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
289
329
|
# Init access to LinkState
|
|
290
330
|
state = self.state_factory.state()
|
|
291
331
|
|
|
332
|
+
# Validate the token
|
|
333
|
+
run_id = self._verify_token(request.token, context)
|
|
334
|
+
|
|
292
335
|
# Lock access to LinkState, preventing obtaining the same pending run_id
|
|
293
336
|
with self.lock:
|
|
294
|
-
# Attempt getting the run_id of a pending run
|
|
295
|
-
run_id = state.get_pending_run_id()
|
|
296
|
-
# If there's no pending run, return an empty response
|
|
297
|
-
if run_id is None:
|
|
298
|
-
return PullAppInputsResponse()
|
|
299
|
-
|
|
300
337
|
# Init access to Ffs
|
|
301
338
|
ffs = self.ffs_factory.ffs()
|
|
302
339
|
|
|
@@ -327,6 +364,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
327
364
|
"""Push ServerApp process outputs."""
|
|
328
365
|
log(DEBUG, "ServerAppIoServicer.PushAppOutputs")
|
|
329
366
|
|
|
367
|
+
# Validate the token
|
|
368
|
+
run_id = self._verify_token(request.token, context)
|
|
369
|
+
|
|
330
370
|
# Init state and store
|
|
331
371
|
state = self.state_factory.state()
|
|
332
372
|
store = self.objectstore_factory.store()
|
|
@@ -341,6 +381,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
341
381
|
)
|
|
342
382
|
|
|
343
383
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
384
|
+
|
|
385
|
+
# Remove the token
|
|
386
|
+
state.delete_token(run_id)
|
|
344
387
|
return PushAppOutputsResponse()
|
|
345
388
|
|
|
346
389
|
def UpdateRunStatus(
|
|
@@ -508,6 +551,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
508
551
|
|
|
509
552
|
return ConfirmMessageReceivedResponse()
|
|
510
553
|
|
|
554
|
+
def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
|
|
555
|
+
"""Verify the token and return the associated run ID."""
|
|
556
|
+
state = self.state_factory.state()
|
|
557
|
+
run_id = state.get_run_id_by_token(token)
|
|
558
|
+
if run_id is None or not state.verify_token(run_id, token):
|
|
559
|
+
context.abort(
|
|
560
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
|
561
|
+
"Invalid token.",
|
|
562
|
+
)
|
|
563
|
+
raise RuntimeError("This line should never be reached.")
|
|
564
|
+
return run_id
|
|
565
|
+
|
|
511
566
|
|
|
512
567
|
def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
|
|
513
568
|
"""Raise a `ValueError` with a detailed message if a validation error occurs."""
|
|
@@ -34,6 +34,16 @@ from flwr.common.serde import (
|
|
|
34
34
|
)
|
|
35
35
|
from flwr.common.typing import Fab, RunStatus
|
|
36
36
|
from flwr.proto import simulationio_pb2_grpc
|
|
37
|
+
from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
|
|
38
|
+
ListAppsToLaunchRequest,
|
|
39
|
+
ListAppsToLaunchResponse,
|
|
40
|
+
PullAppInputsRequest,
|
|
41
|
+
PullAppInputsResponse,
|
|
42
|
+
PushAppOutputsRequest,
|
|
43
|
+
PushAppOutputsResponse,
|
|
44
|
+
RequestTokenRequest,
|
|
45
|
+
RequestTokenResponse,
|
|
46
|
+
)
|
|
37
47
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
38
48
|
SendAppHeartbeatRequest,
|
|
39
49
|
SendAppHeartbeatResponse,
|
|
@@ -45,17 +55,13 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
|
45
55
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
46
56
|
GetFederationOptionsRequest,
|
|
47
57
|
GetFederationOptionsResponse,
|
|
58
|
+
GetRunRequest,
|
|
59
|
+
GetRunResponse,
|
|
48
60
|
GetRunStatusRequest,
|
|
49
61
|
GetRunStatusResponse,
|
|
50
62
|
UpdateRunStatusRequest,
|
|
51
63
|
UpdateRunStatusResponse,
|
|
52
64
|
)
|
|
53
|
-
from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611
|
|
54
|
-
PullSimulationInputsRequest,
|
|
55
|
-
PullSimulationInputsResponse,
|
|
56
|
-
PushSimulationOutputsRequest,
|
|
57
|
-
PushSimulationOutputsResponse,
|
|
58
|
-
)
|
|
59
65
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
60
66
|
from flwr.server.superlink.utils import abort_if
|
|
61
67
|
from flwr.supercore.ffs import FfsFactory
|
|
@@ -71,23 +77,73 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
71
77
|
self.ffs_factory = ffs_factory
|
|
72
78
|
self.lock = threading.RLock()
|
|
73
79
|
|
|
74
|
-
def
|
|
75
|
-
self,
|
|
76
|
-
|
|
80
|
+
def ListAppsToLaunch(
|
|
81
|
+
self,
|
|
82
|
+
request: ListAppsToLaunchRequest,
|
|
83
|
+
context: grpc.ServicerContext,
|
|
84
|
+
) -> ListAppsToLaunchResponse:
|
|
85
|
+
"""Get run IDs with pending messages."""
|
|
86
|
+
log(DEBUG, "SimulationIoServicer.ListAppsToLaunch")
|
|
87
|
+
|
|
88
|
+
# Initialize state connection
|
|
89
|
+
state = self.state_factory.state()
|
|
90
|
+
|
|
91
|
+
# Get IDs of runs in pending status
|
|
92
|
+
run_ids = state.get_run_ids(flwr_aid=None)
|
|
93
|
+
pending_run_ids = []
|
|
94
|
+
for run_id, status in state.get_run_status(run_ids).items():
|
|
95
|
+
if status.status == Status.PENDING:
|
|
96
|
+
pending_run_ids.append(run_id)
|
|
97
|
+
|
|
98
|
+
# Return run IDs
|
|
99
|
+
return ListAppsToLaunchResponse(run_ids=pending_run_ids)
|
|
100
|
+
|
|
101
|
+
def RequestToken(
|
|
102
|
+
self, request: RequestTokenRequest, context: grpc.ServicerContext
|
|
103
|
+
) -> RequestTokenResponse:
|
|
104
|
+
"""Request token."""
|
|
105
|
+
log(DEBUG, "SimulationIoServicer.RequestToken")
|
|
106
|
+
|
|
107
|
+
# Initialize state connection
|
|
108
|
+
state = self.state_factory.state()
|
|
109
|
+
|
|
110
|
+
# Attempt to create a token for the provided run ID
|
|
111
|
+
token = state.create_token(request.run_id)
|
|
112
|
+
|
|
113
|
+
# Return the token
|
|
114
|
+
return RequestTokenResponse(token=token or "")
|
|
115
|
+
|
|
116
|
+
def GetRun(
|
|
117
|
+
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
118
|
+
) -> GetRunResponse:
|
|
119
|
+
"""Get run information."""
|
|
120
|
+
log(DEBUG, "SimulationIoServicer.GetRun")
|
|
121
|
+
|
|
122
|
+
# Init state
|
|
123
|
+
state = self.state_factory.state()
|
|
124
|
+
|
|
125
|
+
# Retrieve run information
|
|
126
|
+
run = state.get_run(request.run_id)
|
|
127
|
+
|
|
128
|
+
if run is None:
|
|
129
|
+
return GetRunResponse()
|
|
130
|
+
|
|
131
|
+
return GetRunResponse(run=run_to_proto(run))
|
|
132
|
+
|
|
133
|
+
def PullAppInputs(
|
|
134
|
+
self, request: PullAppInputsRequest, context: ServicerContext
|
|
135
|
+
) -> PullAppInputsResponse:
|
|
77
136
|
"""Pull SimultionIo process inputs."""
|
|
78
137
|
log(DEBUG, "SimultionIoServicer.SimultionIoInputs")
|
|
79
138
|
# Init access to LinkState and Ffs
|
|
80
139
|
state = self.state_factory.state()
|
|
81
140
|
ffs = self.ffs_factory.ffs()
|
|
82
141
|
|
|
142
|
+
# Validate the token
|
|
143
|
+
run_id = self._verify_token(request.token, context)
|
|
144
|
+
|
|
83
145
|
# Lock access to LinkState, preventing obtaining the same pending run_id
|
|
84
146
|
with self.lock:
|
|
85
|
-
# Attempt getting the run_id of a pending run
|
|
86
|
-
run_id = state.get_pending_run_id()
|
|
87
|
-
# If there's no pending run, return an empty response
|
|
88
|
-
if run_id is None:
|
|
89
|
-
return PullSimulationInputsResponse()
|
|
90
|
-
|
|
91
147
|
# Retrieve Context, Run and Fab for the run_id
|
|
92
148
|
serverapp_ctxt = state.get_serverapp_context(run_id)
|
|
93
149
|
run = state.get_run(run_id)
|
|
@@ -99,7 +155,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
99
155
|
# Update run status to STARTING
|
|
100
156
|
if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
|
|
101
157
|
log(INFO, "Starting run %d", run_id)
|
|
102
|
-
return
|
|
158
|
+
return PullAppInputsResponse(
|
|
103
159
|
context=context_to_proto(serverapp_ctxt),
|
|
104
160
|
run=run_to_proto(run),
|
|
105
161
|
fab=fab_to_proto(fab),
|
|
@@ -109,11 +165,16 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
109
165
|
# or if the status cannot be updated to STARTING
|
|
110
166
|
raise RuntimeError(f"Failed to start run {run_id}")
|
|
111
167
|
|
|
112
|
-
def
|
|
113
|
-
self, request:
|
|
114
|
-
) ->
|
|
168
|
+
def PushAppOutputs(
|
|
169
|
+
self, request: PushAppOutputsRequest, context: ServicerContext
|
|
170
|
+
) -> PushAppOutputsResponse:
|
|
115
171
|
"""Push Simulation process outputs."""
|
|
116
|
-
log(DEBUG, "SimultionIoServicer.
|
|
172
|
+
log(DEBUG, "SimultionIoServicer.PushAppOutputs")
|
|
173
|
+
|
|
174
|
+
# Validate the token
|
|
175
|
+
run_id = self._verify_token(request.token, context)
|
|
176
|
+
|
|
177
|
+
# Init access to LinkState
|
|
117
178
|
state = self.state_factory.state()
|
|
118
179
|
|
|
119
180
|
# Abort if the run is not running
|
|
@@ -126,7 +187,10 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
126
187
|
)
|
|
127
188
|
|
|
128
189
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
129
|
-
|
|
190
|
+
|
|
191
|
+
# Remove the token
|
|
192
|
+
state.delete_token(run_id)
|
|
193
|
+
return PushAppOutputsResponse()
|
|
130
194
|
|
|
131
195
|
def UpdateRunStatus(
|
|
132
196
|
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
|
|
@@ -208,3 +272,15 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
208
272
|
)
|
|
209
273
|
|
|
210
274
|
return SendAppHeartbeatResponse(success=success)
|
|
275
|
+
|
|
276
|
+
def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
|
|
277
|
+
"""Verify the token and return the associated run ID."""
|
|
278
|
+
state = self.state_factory.state()
|
|
279
|
+
run_id = state.get_run_id_by_token(token)
|
|
280
|
+
if run_id is None or not state.verify_token(run_id, token):
|
|
281
|
+
context.abort(
|
|
282
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
|
283
|
+
"Invalid token.",
|
|
284
|
+
)
|
|
285
|
+
raise RuntimeError("This line should never be reached.")
|
|
286
|
+
return run_id
|
flwr/serverapp/__init__.py
CHANGED
|
@@ -13,3 +13,15 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Public Flower ServerApp APIs."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from flwr.server.grid import Grid
|
|
19
|
+
from flwr.server.server_app import ServerApp
|
|
20
|
+
|
|
21
|
+
from . import strategy
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"Grid",
|
|
25
|
+
"ServerApp",
|
|
26
|
+
"strategy",
|
|
27
|
+
]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower ServerApp exceptions."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from flwr.app.exception import AppExitException
|
|
19
|
+
from flwr.common.exit import ExitCode
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class InconsistentMessageReplies(AppExitException):
|
|
23
|
+
"""Exception triggered when replies are inconsistent and therefore aggregation must
|
|
24
|
+
be skipped."""
|
|
25
|
+
|
|
26
|
+
exit_code = ExitCode.SERVERAPP_STRATEGY_PRECONDITION_UNMET
|
|
27
|
+
|
|
28
|
+
def __init__(self, reason: str):
|
|
29
|
+
super().__init__(reason)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AggregationError(AppExitException):
|
|
33
|
+
"""Exception triggered when aggregation fails."""
|
|
34
|
+
|
|
35
|
+
exit_code = ExitCode.SERVERAPP_STRATEGY_AGGREGATION_ERROR
|
|
36
|
+
|
|
37
|
+
def __init__(self, reason: str):
|
|
38
|
+
super().__init__(reason)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""ServerApp strategies."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .bulyan import Bulyan
|
|
19
|
+
from .dp_adaptive_clipping import (
|
|
20
|
+
DifferentialPrivacyClientSideAdaptiveClipping,
|
|
21
|
+
DifferentialPrivacyServerSideAdaptiveClipping,
|
|
22
|
+
)
|
|
23
|
+
from .dp_fixed_clipping import (
|
|
24
|
+
DifferentialPrivacyClientSideFixedClipping,
|
|
25
|
+
DifferentialPrivacyServerSideFixedClipping,
|
|
26
|
+
)
|
|
27
|
+
from .fedadagrad import FedAdagrad
|
|
28
|
+
from .fedadam import FedAdam
|
|
29
|
+
from .fedavg import FedAvg
|
|
30
|
+
from .fedavgm import FedAvgM
|
|
31
|
+
from .fedmedian import FedMedian
|
|
32
|
+
from .fedprox import FedProx
|
|
33
|
+
from .fedtrimmedavg import FedTrimmedAvg
|
|
34
|
+
from .fedxgb_bagging import FedXgbBagging
|
|
35
|
+
from .fedxgb_cyclic import FedXgbCyclic
|
|
36
|
+
from .fedyogi import FedYogi
|
|
37
|
+
from .krum import Krum
|
|
38
|
+
from .multikrum import MultiKrum
|
|
39
|
+
from .qfedavg import QFedAvg
|
|
40
|
+
from .result import Result
|
|
41
|
+
from .strategy import Strategy
|
|
42
|
+
|
|
43
|
+
__all__ = [
|
|
44
|
+
"Bulyan",
|
|
45
|
+
"DifferentialPrivacyClientSideAdaptiveClipping",
|
|
46
|
+
"DifferentialPrivacyClientSideFixedClipping",
|
|
47
|
+
"DifferentialPrivacyServerSideAdaptiveClipping",
|
|
48
|
+
"DifferentialPrivacyServerSideFixedClipping",
|
|
49
|
+
"FedAdagrad",
|
|
50
|
+
"FedAdam",
|
|
51
|
+
"FedAvg",
|
|
52
|
+
"FedAvgM",
|
|
53
|
+
"FedMedian",
|
|
54
|
+
"FedProx",
|
|
55
|
+
"FedTrimmedAvg",
|
|
56
|
+
"FedXgbBagging",
|
|
57
|
+
"FedXgbCyclic",
|
|
58
|
+
"FedYogi",
|
|
59
|
+
"Krum",
|
|
60
|
+
"MultiKrum",
|
|
61
|
+
"QFedAvg",
|
|
62
|
+
"Result",
|
|
63
|
+
"Strategy",
|
|
64
|
+
]
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Bulyan [El Mhamdi et al., 2018] strategy.
|
|
16
|
+
|
|
17
|
+
Paper: arxiv.org/abs/1802.07927
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from collections import OrderedDict
|
|
22
|
+
from collections.abc import Iterable
|
|
23
|
+
from logging import INFO, WARN
|
|
24
|
+
from typing import Callable, Optional, cast
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
from flwr.common import (
|
|
29
|
+
Array,
|
|
30
|
+
ArrayRecord,
|
|
31
|
+
Message,
|
|
32
|
+
MetricRecord,
|
|
33
|
+
NDArrays,
|
|
34
|
+
RecordDict,
|
|
35
|
+
log,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from .fedavg import FedAvg
|
|
39
|
+
from .multikrum import select_multikrum
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# pylint: disable=too-many-instance-attributes
|
|
43
|
+
class Bulyan(FedAvg):
|
|
44
|
+
"""Bulyan strategy.
|
|
45
|
+
|
|
46
|
+
Implementation based on https://arxiv.org/abs/1802.07927.
|
|
47
|
+
|
|
48
|
+
Parameters
|
|
49
|
+
----------
|
|
50
|
+
fraction_train : float (default: 1.0)
|
|
51
|
+
Fraction of nodes used during training. In case `min_train_nodes`
|
|
52
|
+
is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
|
|
53
|
+
will still be sampled.
|
|
54
|
+
fraction_evaluate : float (default: 1.0)
|
|
55
|
+
Fraction of nodes used during validation. In case `min_evaluate_nodes`
|
|
56
|
+
is larger than `fraction_evaluate * total_connected_nodes`,
|
|
57
|
+
`min_evaluate_nodes` will still be sampled.
|
|
58
|
+
min_train_nodes : int (default: 2)
|
|
59
|
+
Minimum number of nodes used during training.
|
|
60
|
+
min_evaluate_nodes : int (default: 2)
|
|
61
|
+
Minimum number of nodes used during validation.
|
|
62
|
+
min_available_nodes : int (default: 2)
|
|
63
|
+
Minimum number of total nodes in the system.
|
|
64
|
+
num_malicious_nodes : int (default: 0)
|
|
65
|
+
Number of malicious nodes in the system.
|
|
66
|
+
weighted_by_key : str (default: "num-examples")
|
|
67
|
+
The key within each MetricRecord whose value is used as the weight when
|
|
68
|
+
computing weighted averages for MetricRecords.
|
|
69
|
+
arrayrecord_key : str (default: "arrays")
|
|
70
|
+
Key used to store the ArrayRecord when constructing Messages.
|
|
71
|
+
configrecord_key : str (default: "config")
|
|
72
|
+
Key used to store the ConfigRecord when constructing Messages.
|
|
73
|
+
train_metrics_aggr_fn : Optional[callable] (default: None)
|
|
74
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
|
75
|
+
used to aggregate MetricRecords from training round replies.
|
|
76
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
|
77
|
+
average using the provided weight factor key.
|
|
78
|
+
evaluate_metrics_aggr_fn : Optional[callable] (default: None)
|
|
79
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
|
80
|
+
used to aggregate MetricRecords from training round replies.
|
|
81
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
|
82
|
+
average using the provided weight factor key.
|
|
83
|
+
selection_rule : Optional[Callable] (default: None)
|
|
84
|
+
Function with signature (list[RecordDict], int, int) -> list[RecordDict].
|
|
85
|
+
The inputs are:
|
|
86
|
+
- a list of contents from reply messages,
|
|
87
|
+
- the assumed number of malicious nodes (`num_malicious_nodes`),
|
|
88
|
+
- the number of nodes to select (`num_nodes_to_select`).
|
|
89
|
+
|
|
90
|
+
The function should implement a Byzantine-resilient selection rule that
|
|
91
|
+
serves as the first step of Bulyan. If None, defaults to `select_multikrum`,
|
|
92
|
+
which selects nodes according to the Multi-Krum algorithm.
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
96
|
+
def __init__(
    self,
    fraction_train: float = 1.0,
    fraction_evaluate: float = 1.0,
    min_train_nodes: int = 2,
    min_evaluate_nodes: int = 2,
    min_available_nodes: int = 2,
    num_malicious_nodes: int = 0,
    weighted_by_key: str = "num-examples",
    arrayrecord_key: str = "arrays",
    configrecord_key: str = "config",
    train_metrics_aggr_fn: Optional[
        Callable[[list[RecordDict], str], MetricRecord]
    ] = None,
    evaluate_metrics_aggr_fn: Optional[
        Callable[[list[RecordDict], str], MetricRecord]
    ] = None,
    selection_rule: Optional[
        Callable[[list[RecordDict], int, int], list[RecordDict]]
    ] = None,
) -> None:
    """Initialize the strategy.

    Every argument except `num_malicious_nodes` and `selection_rule` is
    forwarded unchanged to the parent class initializer. The two remaining
    arguments are stored on the instance; when no selection rule is given,
    `select_multikrum` is used.
    """
    # Node sampling, record keys and metric aggregation are configured by
    # the parent class.
    super().__init__(
        fraction_train=fraction_train,
        fraction_evaluate=fraction_evaluate,
        min_train_nodes=min_train_nodes,
        min_evaluate_nodes=min_evaluate_nodes,
        min_available_nodes=min_available_nodes,
        weighted_by_key=weighted_by_key,
        arrayrecord_key=arrayrecord_key,
        configrecord_key=configrecord_key,
        train_metrics_aggr_fn=train_metrics_aggr_fn,
        evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
    )

    # Bulyan-specific settings
    self.num_malicious_nodes = num_malicious_nodes
    if selection_rule:
        self.selection_rule = selection_rule
    else:
        # Default first-stage rule
        self.selection_rule = select_multikrum
|
|
131
|
+
|
|
132
|
+
def summary(self) -> None:
    """Log summary configuration of the strategy.

    Logs the Bulyan-specific settings (the number of malicious nodes and the
    selection rule) at INFO level, then delegates to the parent class's
    `summary`.
    """
    rule = self.selection_rule
    # Not every callable exposes `__name__` (e.g. `functools.partial` objects
    # or instances implementing `__call__`). Fall back to `repr` so that
    # logging the summary cannot raise `AttributeError`.
    rule_name = getattr(rule, "__name__", None) or repr(rule)

    log(INFO, "\t├──> Bulyan settings:")
    log(INFO, "\t│\t├── Number of malicious nodes: %d", self.num_malicious_nodes)
    log(INFO, "\t│\t└── Selection rule: %s", rule_name)
    super().summary()
|
|
138
|
+
|
|
139
|
+
def aggregate_train(
    self,
    server_round: int,
    replies: Iterable[Message],
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
    """Aggregate ArrayRecords and MetricRecords in the received Messages.

    The replies are first filtered by `_check_and_log_replies`. If fewer than
    `4 * num_malicious_nodes + 3` valid replies remain, a warning is logged
    and `(None, None)` is returned. Otherwise the selection rule picks
    `theta = n - 2 * num_malicious_nodes` reply contents, and for every
    coordinate the `beta = theta - 2 * num_malicious_nodes` selected values
    closest to the coordinate-wise median are averaged.

    Parameters
    ----------
    server_round : int
        Current server round (not used by this method's body).
    replies : Iterable[Message]
        Reply messages received from the nodes.

    Returns
    -------
    tuple[Optional[ArrayRecord], Optional[MetricRecord]]
        The aggregated ArrayRecord and the MetricRecord aggregated over the
        selected reply contents, or `(None, None)` when there are too few
        valid replies.
    """
    valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
    num_valid = len(valid_replies)
    num_malicious = self.num_malicious_nodes

    # Bail out early when there are too few replies for Bulyan
    min_required = 4 * num_malicious + 3
    if num_valid < min_required:
        log(
            WARN,
            "Insufficient replies, skipping Bulyan aggregation: "
            "Required at least %d (4*num_malicious_nodes + 3), but received %d.",
            min_required,
            num_valid,
        )
        return None, None

    contents = [reply.content for reply in valid_replies]

    # theta: number of contents kept by the selection rule
    # beta: number of values averaged per coordinate
    theta = num_valid - 2 * num_malicious
    beta = theta - 2 * num_malicious

    # First step: Byzantine-resilient selection
    selected_contents = self.selection_rule(contents, num_malicious, theta)

    # Read the record key and array names from the first selected content
    # before converting the records to NumPy below.
    first_content = selected_contents[0]
    key = list(first_content.array_records.keys())[0]
    array_keys = list(first_content[key].keys())

    # One list of NDArray per selected content
    selected_ndarrays = []
    for content in selected_contents:
        record = cast(ArrayRecord, content[key])
        selected_ndarrays.append(record.to_numpy_ndarrays(keep_input=False))

    # Second step: per-layer median across the selected contents, then the
    # element-wise mean of the `beta` values closest to that median
    median_ndarrays = [
        np.median(layer_group, axis=0) for layer_group in zip(*selected_ndarrays)
    ]
    aggregated_ndarrays = aggregate_n_closest_weights(
        median_ndarrays, selected_ndarrays, beta
    )

    # Re-attach the original array names
    named_arrays = OrderedDict(
        (name, Array(ndarray))
        for name, ndarray in zip(array_keys, aggregated_ndarrays)
    )
    arrays = ArrayRecord(named_arrays)

    # Metrics are aggregated over the selected contents only
    metrics = self.train_metrics_aggr_fn(selected_contents, self.weighted_by_key)
    return arrays, metrics
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def aggregate_n_closest_weights(
    ref_weights: NDArrays, weights_list: list[NDArrays], beta: int
) -> NDArrays:
    """Compute the element-wise mean of the `beta` closest weight arrays.

    For each element (i-th coordinate) of each layer, the output is the
    average of the `beta` values, taken across `weights_list`, that have the
    smallest absolute difference to the corresponding element of the
    reference weights.

    Parameters
    ----------
    ref_weights : NDArrays
        Reference weights used to compute distances.
    weights_list : list[NDArrays]
        List of weight arrays (e.g., from selected nodes).
    beta : int
        Number of closest weight arrays to include in the averaging.

    Returns
    -------
    aggregated_weights : NDArrays
        Element-wise average of the `beta` closest weight arrays to the
        reference weights. Every entry is a NumPy ndarray, including for
        0-d (scalar) layers.

    Raises
    ------
    ValueError
        If `beta` is smaller than 1 or larger than `len(weights_list)`.
    """
    num_models = len(weights_list)
    # Fail early with a clear message: `beta <= 0` would otherwise silently
    # produce NaN or wrong results, and `beta > num_models` would surface as
    # an opaque out-of-bounds error from NumPy.
    if not 1 <= beta <= num_models:
        raise ValueError(
            f"`beta` must be between 1 and the number of weight arrays "
            f"({num_models}), but got {beta}."
        )

    aggregated_weights = []
    for layer_id, ref_layer in enumerate(ref_weights):
        # Shape: (n_models, *layer_shape)
        layer_stack = np.stack([weights[layer_id] for weights in weights_list])

        # Compute absolute differences: shape (n_models, *layer_shape)
        diffs = np.abs(layer_stack - ref_layer)

        # Find indices of `beta` smallest per coordinate
        idx = np.argpartition(diffs, beta - 1, axis=0)[:beta]

        # Gather the closest weights
        closest = np.take_along_axis(layer_stack, idx, axis=0)

        # Average them. For 0-d layers `np.mean` returns a NumPy scalar rather
        # than an ndarray; `np.asarray` keeps the result an ndarray.
        aggregated_weights.append(np.asarray(np.mean(closest, axis=0)))

    return aggregated_weights
|