flwr 1.21.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +2 -0
- flwr/cli/new/new.py +9 -7
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/utils.py +17 -0
- flwr/clientapp/mod/__init__.py +4 -1
- flwr/clientapp/mod/centraldp_mods.py +156 -40
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/common/constant.py +3 -0
- flwr/common/exit/exit_code.py +4 -0
- flwr/common/record/typeddict.py +12 -0
- flwr/proto/control_pb2.py +7 -3
- flwr/proto/control_pb2.pyi +24 -0
- flwr/proto/control_pb2_grpc.py +34 -0
- flwr/proto/control_pb2_grpc.pyi +13 -0
- flwr/server/app.py +13 -0
- flwr/serverapp/strategy/__init__.py +26 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
- flwr/serverapp/strategy/fedadagrad.py +0 -3
- flwr/serverapp/strategy/fedadam.py +0 -3
- flwr/serverapp/strategy/fedavg.py +89 -64
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +0 -3
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +19 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/control/control_grpc.py +3 -0
- flwr/superlink/servicer/control/control_servicer.py +59 -2
- {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/METADATA +6 -16
- {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/RECORD +93 -74
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- flwr/serverapp/dp_fixed_clipping.py +0 -352
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
- /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
- {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Fair Resource Allocation in Federated Learning [Li et al., 2020] strategy.
|
|
16
|
+
|
|
17
|
+
Paper: openreview.net/pdf?id=ByexElSYDr
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from collections import OrderedDict
|
|
22
|
+
from collections.abc import Iterable
|
|
23
|
+
from logging import INFO
|
|
24
|
+
from typing import Callable, Optional
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
from flwr.common import (
|
|
29
|
+
Array,
|
|
30
|
+
ArrayRecord,
|
|
31
|
+
ConfigRecord,
|
|
32
|
+
Message,
|
|
33
|
+
MetricRecord,
|
|
34
|
+
NDArray,
|
|
35
|
+
RecordDict,
|
|
36
|
+
)
|
|
37
|
+
from flwr.common.logger import log
|
|
38
|
+
from flwr.server import Grid
|
|
39
|
+
|
|
40
|
+
from ..exception import AggregationError
|
|
41
|
+
from .fedavg import FedAvg
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class QFedAvg(FedAvg):
    """Q-FedAvg strategy.

    Implementation based on openreview.net/pdf?id=ByexElSYDr

    Parameters
    ----------
    client_learning_rate : float
        Local learning rate used by clients during training. This value is used by
        the strategy to approximate the base Lipschitz constant L, via
        L = 1 / client_learning_rate.
    q : float (default: 0.1)
        The parameter q that controls the degree of fairness of the algorithm. Please
        tune this parameter based on your use case.
        When set to 0, q-FedAvg is equivalent to FedAvg.
    train_loss_key : str (default: "train_loss")
        The key within the MetricRecord whose value is used as the training loss when
        aggregating ArrayRecords following q-FedAvg.
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    """

    def __init__(  # pylint: disable=R0913, R0917
        self,
        client_learning_rate: float,
        q: float = 0.1,
        train_loss_key: str = "train_loss",
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )
        self.q = q
        self.client_learning_rate = client_learning_rate
        self.train_loss_key = train_loss_key
        # Copy of the global model sent out in the current round; set by
        # `configure_train` and read by `aggregate_train`.
        self.current_arrays: Optional[ArrayRecord] = None

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        log(INFO, "\t├──> q-FedAvg settings:")
        log(INFO, "\t│\t├── client_learning_rate: %s", self.client_learning_rate)
        log(INFO, "\t│\t├── q: %s", self.q)
        log(INFO, "\t│\t└── train_loss_key: '%s'", self.train_loss_key)
        super().summary()

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training.

        A copy of `arrays` is kept on the strategy because `aggregate_train`
        needs the global weights of this round to compute the q-FedAvg update.
        """
        self.current_arrays = arrays.copy()
        return super().configure_train(server_round, arrays, config, grid)

    def aggregate_train(  # pylint: disable=too-many-locals
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Returns `(None, None)` when there are no valid replies. Raises
        `AggregationError` if `configure_train` has not stored the global
        weights beforehand.
        """
        # Validate the replies; only the valid ones take part in aggregation
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        if not valid_replies:
            return None, None

        # Compute estimate of Lipschitz constant L
        L = 1.0 / self.client_learning_rate  # pylint: disable=C0103

        # q-FedAvg aggregation
        if self.current_arrays is None:
            raise AggregationError(
                "Current global model weights are not available. Make sure to call "
                "`configure_train` before calling `aggregate_train`."
            )
        array_keys = list(self.current_arrays.keys())  # Preserve keys
        global_weights = self.current_arrays.to_numpy_ndarrays(keep_input=False)
        sum_delta = None
        sum_h = 0.0

        for msg in valid_replies:
            # Extract local weights and training loss from Message
            local_weights = get_local_weights(msg)
            loss = get_train_loss(msg, self.train_loss_key)

            # Compute delta and h
            delta, h = compute_delta_and_h(
                global_weights, local_weights, self.q, L, loss
            )

            # Compute sum of deltas and sum of h
            if sum_delta is None:
                sum_delta = delta
            else:
                sum_delta = [sd + d for sd, d in zip(sum_delta, delta)]
            # Keep `sum_h` a builtin float: dividing an array by a NumPy float64
            # scalar promotes lower-precision weights to float64 on NumPy >= 2.
            sum_h += float(h)

        # Compute new global weights and convert to Array type
        # `np.asarray` can convert numpy scalars to 0-dim arrays
        assert sum_delta is not None  # Make mypy happy
        array_list = [
            Array(np.asarray(gw - (d / sum_h)))
            for gw, d in zip(global_weights, sum_delta)
        ]

        # Aggregate MetricRecords
        metrics = self.train_metrics_aggr_fn(
            [msg.content for msg in valid_replies],
            self.weighted_by_key,
        )
        return ArrayRecord(OrderedDict(zip(array_keys, array_list))), metrics
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def get_train_loss(msg: Message, loss_key: str) -> float:
    """Return the training loss stored under `loss_key` in a reply Message.

    Only the first MetricRecord of the message content is inspected. An
    `AggregationError` is raised when the key is absent or its value is not
    an int or float.
    """
    first_record = list(msg.content.metric_records.values())[0]
    value = first_record.get(loss_key)
    # A missing key yields None, which fails the isinstance check as well
    if isinstance(value, (int, float)):
        return float(value)
    raise AggregationError(
        "Missing or invalid training loss. "
        f"The strategy expected a float value for the key '{loss_key}' "
        "as the training loss in each MetricRecord from the clients. "
        f"Ensure that '{loss_key}' is present and maps to a valid float."
    )
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def get_local_weights(msg: Message) -> list[NDArray]:
    """Return the first ArrayRecord of a reply Message as a list of ndarrays."""
    records = list(msg.content.array_records.values())
    return records[0].to_numpy_ndarrays(keep_input=False)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def l2_norm(ndarrays: list[NDArray]) -> float:
    """Return the squared L2 norm, i.e. the sum of squared entries of all arrays."""
    total = 0
    for arr in ndarrays:
        total = total + np.sum(np.square(arr))
    return float(total)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def compute_delta_and_h(
    global_weights: list[NDArray],
    local_weights: list[NDArray],
    q: float,
    L: float,  # Lipschitz constant  # pylint: disable=C0103
    loss: float,
) -> tuple[list[NDArray], float]:
    """Compute delta and h used in q-FedAvg aggregation.

    The arrays in `local_weights` are overwritten in place and returned as
    `delta`; callers must not reuse them as weights afterwards.
    `global_weights` is only read.

    Parameters
    ----------
    global_weights : list[NDArray]
        Current global model weights w.
    local_weights : list[NDArray]
        Locally trained weights w_k; reused as the output buffers.
    q : float
        Fairness parameter of q-FedAvg.
    L : float
        Estimate of the Lipschitz constant.
    loss : float
        Training loss reported for this set of local weights.

    Returns
    -------
    tuple[list[NDArray], float]
        `delta_k = loss^q * L * (w - w_k)` (same list object as
        `local_weights`) and the scalar `h_k` as a builtin float.
    """
    # Compute gradient_k = L * (w - w_k), written into the local weight buffers
    for gw, lw in zip(global_weights, local_weights):
        np.subtract(gw, lw, out=lw)
        lw *= L
    grad = local_weights  # After in-place operations, local_weights is now grad
    # Compute ||gradient_k||^2 = L^2 * ||w - w_k||^2. This must happen before
    # the gradient is scaled by loss^q below.
    norm = float(sum(np.sum(np.square(g)) for g in grad))
    # Compute delta_k = loss_k^q * gradient_k (1e-10 guards against a zero loss)
    loss_pow_q: float = np.float_power(loss + 1e-10, q)
    for g in grad:
        g *= loss_pow_q
    delta = grad  # After in-place multiplication, grad is now delta
    # Compute h_k
    h = q * np.float_power(loss + 1e-10, q - 1) * norm + L * loss_pow_q
    # Return a builtin float as annotated. A NumPy float64 scalar would promote
    # lower-precision weight arrays to float64 in later arithmetic on NumPy >= 2.
    return delta, float(h)
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Flower message-based strategy utilities."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import json
|
|
18
19
|
import random
|
|
19
20
|
from collections import OrderedDict
|
|
20
21
|
from logging import INFO
|
|
@@ -249,3 +250,50 @@ def validate_message_reply_consistency(
|
|
|
249
250
|
"must be a single value (int or float), but a list was found. Skipping "
|
|
250
251
|
"aggregation."
|
|
251
252
|
)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def aggregate_bagging(
    bst_prev_org: bytes,
    bst_curr_org: bytes,
) -> bytes:
    """Conduct bagging aggregation for given trees.

    Both arguments are JSON-serialized models. The first `num_parallel_tree`
    trees of the current model are appended to the previous model, whose tree
    count, `iteration_indptr` and `tree_info` are extended to match.

    Parameters
    ----------
    bst_prev_org : bytes
        Serialized model aggregated so far. If empty, `bst_curr_org` is
        returned unchanged.
    bst_curr_org : bytes
        Serialized model whose trees are added.

    Returns
    -------
    bytes
        The UTF-8 encoded JSON of the merged model.
    """
    if bst_prev_org == b"":
        return bst_curr_org

    # Parse each model exactly once; `json.loads` accepts bytes directly
    bst_prev = json.loads(bst_prev_org)
    bst_curr = json.loads(bst_curr_org)

    previous_model = bst_prev["learner"]["gradient_booster"]["model"]
    current_model = bst_curr["learner"]["gradient_booster"]["model"]

    # Get the tree numbers
    prev_param = previous_model["gbtree_model_param"]
    curr_param = current_model["gbtree_model_param"]
    tree_num_prev = int(prev_param["num_trees"])
    paral_tree_num_curr = int(curr_param["num_parallel_tree"])

    prev_param["num_trees"] = str(tree_num_prev + paral_tree_num_curr)
    iteration_indptr = previous_model["iteration_indptr"]
    iteration_indptr.append(iteration_indptr[-1] + paral_tree_num_curr)

    # Aggregate new trees, renumbering them to follow the existing ones
    trees_curr = current_model["trees"]
    for tree_count in range(paral_tree_num_curr):
        trees_curr[tree_count]["id"] = tree_num_prev + tree_count
        previous_model["trees"].append(trees_curr[tree_count])
        previous_model["tree_info"].append(0)

    return json.dumps(bst_prev).encode("utf-8")
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _get_tree_nums(xgb_model_org: bytes) -> tuple[int, int]:
    """Return `(num_trees, num_parallel_tree)` of a JSON-serialized model.

    Both values are read from
    `learner.gradient_booster.model.gbtree_model_param` and converted to int.
    """
    # `json.loads` accepts bytes directly, so no intermediate copy is needed
    xgb_model = json.loads(xgb_model_org)

    # Access model parameters
    model_param = xgb_model["learner"]["gradient_booster"]["model"][
        "gbtree_model_param"
    ]
    # Return the number of trees and the number of parallel trees
    return int(model_param["num_trees"]), int(model_param["num_parallel_tree"])
|
flwr/simulation/app.py
CHANGED
|
@@ -245,7 +245,7 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
245
245
|
run=run,
|
|
246
246
|
enable_tf_gpu_growth=enable_tf_gpu_growth,
|
|
247
247
|
verbose_logging=verbose,
|
|
248
|
-
|
|
248
|
+
server_app_context=context,
|
|
249
249
|
is_app=True,
|
|
250
250
|
exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
|
|
251
251
|
)
|
|
@@ -143,6 +143,15 @@ def run_simulation_from_cli() -> None:
|
|
|
143
143
|
run = Run.create_empty(run_id)
|
|
144
144
|
run.override_config = override_config
|
|
145
145
|
|
|
146
|
+
# Create Context
|
|
147
|
+
server_app_context = Context(
|
|
148
|
+
run_id=run_id,
|
|
149
|
+
node_id=0,
|
|
150
|
+
node_config=UserConfig(),
|
|
151
|
+
state=RecordDict(),
|
|
152
|
+
run_config=fused_config,
|
|
153
|
+
)
|
|
154
|
+
|
|
146
155
|
_ = _run_simulation(
|
|
147
156
|
server_app_attr=server_app_attr,
|
|
148
157
|
client_app_attr=client_app_attr,
|
|
@@ -153,7 +162,7 @@ def run_simulation_from_cli() -> None:
|
|
|
153
162
|
run=run,
|
|
154
163
|
enable_tf_gpu_growth=args.enable_tf_gpu_growth,
|
|
155
164
|
verbose_logging=args.verbose,
|
|
156
|
-
|
|
165
|
+
server_app_context=server_app_context,
|
|
157
166
|
is_app=True,
|
|
158
167
|
exit_event=EventType.CLI_FLOWER_SIMULATION_LEAVE,
|
|
159
168
|
)
|
|
@@ -241,13 +250,12 @@ def run_simulation(
|
|
|
241
250
|
def run_serverapp_th(
|
|
242
251
|
server_app_attr: Optional[str],
|
|
243
252
|
server_app: Optional[ServerApp],
|
|
244
|
-
|
|
253
|
+
server_app_context: Context,
|
|
245
254
|
grid: Grid,
|
|
246
255
|
app_dir: str,
|
|
247
256
|
f_stop: threading.Event,
|
|
248
257
|
has_exception: threading.Event,
|
|
249
258
|
enable_tf_gpu_growth: bool,
|
|
250
|
-
run_id: int,
|
|
251
259
|
ctx_queue: "Queue[Context]",
|
|
252
260
|
) -> threading.Thread:
|
|
253
261
|
"""Run SeverApp in a thread."""
|
|
@@ -258,7 +266,6 @@ def run_serverapp_th(
|
|
|
258
266
|
exception_event: threading.Event,
|
|
259
267
|
_grid: Grid,
|
|
260
268
|
_server_app_dir: str,
|
|
261
|
-
_server_app_run_config: UserConfig,
|
|
262
269
|
_server_app_attr: Optional[str],
|
|
263
270
|
_server_app: Optional[ServerApp],
|
|
264
271
|
_ctx_queue: "Queue[Context]",
|
|
@@ -272,19 +279,10 @@ def run_serverapp_th(
|
|
|
272
279
|
log(INFO, "Enabling GPU growth for Tensorflow on the server thread.")
|
|
273
280
|
enable_gpu_growth()
|
|
274
281
|
|
|
275
|
-
# Initialize Context
|
|
276
|
-
context = Context(
|
|
277
|
-
run_id=run_id,
|
|
278
|
-
node_id=0,
|
|
279
|
-
node_config={},
|
|
280
|
-
state=RecordDict(),
|
|
281
|
-
run_config=_server_app_run_config,
|
|
282
|
-
)
|
|
283
|
-
|
|
284
282
|
# Run ServerApp
|
|
285
283
|
updated_context = _run(
|
|
286
284
|
grid=_grid,
|
|
287
|
-
context=
|
|
285
|
+
context=server_app_context,
|
|
288
286
|
server_app_dir=_server_app_dir,
|
|
289
287
|
server_app_attr=_server_app_attr,
|
|
290
288
|
loaded_server_app=_server_app,
|
|
@@ -310,7 +308,6 @@ def run_serverapp_th(
|
|
|
310
308
|
has_exception,
|
|
311
309
|
grid,
|
|
312
310
|
app_dir,
|
|
313
|
-
server_app_run_config,
|
|
314
311
|
server_app_attr,
|
|
315
312
|
server_app,
|
|
316
313
|
ctx_queue,
|
|
@@ -335,7 +332,7 @@ def _main_loop(
|
|
|
335
332
|
client_app_attr: Optional[str] = None,
|
|
336
333
|
server_app: Optional[ServerApp] = None,
|
|
337
334
|
server_app_attr: Optional[str] = None,
|
|
338
|
-
|
|
335
|
+
server_app_context: Optional[Context] = None,
|
|
339
336
|
) -> Context:
|
|
340
337
|
"""Start ServerApp on a separate thread, then launch Simulation Engine."""
|
|
341
338
|
# Initialize StateFactory
|
|
@@ -346,13 +343,15 @@ def _main_loop(
|
|
|
346
343
|
server_app_thread_has_exception = threading.Event()
|
|
347
344
|
serverapp_th = None
|
|
348
345
|
success = True
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
346
|
+
if server_app_context is None:
|
|
347
|
+
server_app_context = Context(
|
|
348
|
+
run_id=run.run_id,
|
|
349
|
+
node_id=0,
|
|
350
|
+
node_config=UserConfig(),
|
|
351
|
+
state=RecordDict(),
|
|
352
|
+
run_config=UserConfig(),
|
|
353
|
+
)
|
|
354
|
+
updated_context = server_app_context
|
|
356
355
|
try:
|
|
357
356
|
# Register run
|
|
358
357
|
log(DEBUG, "Pre-registering run with id %s", run.run_id)
|
|
@@ -361,9 +360,6 @@ def _main_loop(
|
|
|
361
360
|
run.running_at = run.starting_at
|
|
362
361
|
state_factory.state().run_ids[run.run_id] = RunRecord(run=run) # type: ignore
|
|
363
362
|
|
|
364
|
-
if server_app_run_config is None:
|
|
365
|
-
server_app_run_config = {}
|
|
366
|
-
|
|
367
363
|
# Initialize Grid
|
|
368
364
|
grid = InMemoryGrid(state_factory=state_factory)
|
|
369
365
|
grid.set_run(run_id=run.run_id)
|
|
@@ -373,13 +369,12 @@ def _main_loop(
|
|
|
373
369
|
serverapp_th = run_serverapp_th(
|
|
374
370
|
server_app_attr=server_app_attr,
|
|
375
371
|
server_app=server_app,
|
|
376
|
-
|
|
372
|
+
server_app_context=server_app_context,
|
|
377
373
|
grid=grid,
|
|
378
374
|
app_dir=app_dir,
|
|
379
375
|
f_stop=f_stop,
|
|
380
376
|
has_exception=server_app_thread_has_exception,
|
|
381
377
|
enable_tf_gpu_growth=enable_tf_gpu_growth,
|
|
382
|
-
run_id=run.run_id,
|
|
383
378
|
ctx_queue=output_context_queue,
|
|
384
379
|
)
|
|
385
380
|
|
|
@@ -438,7 +433,7 @@ def _run_simulation(
|
|
|
438
433
|
backend_config: Optional[BackendConfig] = None,
|
|
439
434
|
client_app_attr: Optional[str] = None,
|
|
440
435
|
server_app_attr: Optional[str] = None,
|
|
441
|
-
|
|
436
|
+
server_app_context: Optional[Context] = None,
|
|
442
437
|
app_dir: str = "",
|
|
443
438
|
flwr_dir: Optional[str] = None,
|
|
444
439
|
run: Optional[Run] = None,
|
|
@@ -502,7 +497,7 @@ def _run_simulation(
|
|
|
502
497
|
client_app_attr,
|
|
503
498
|
server_app,
|
|
504
499
|
server_app_attr,
|
|
505
|
-
|
|
500
|
+
server_app_context,
|
|
506
501
|
)
|
|
507
502
|
# Detect if there is an Asyncio event loop already running.
|
|
508
503
|
# If yes, disable logger propagation. In environments
|
|
@@ -17,7 +17,9 @@
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
19
|
from logging import INFO
|
|
20
|
-
from typing import Optional
|
|
20
|
+
from typing import Any, Optional
|
|
21
|
+
|
|
22
|
+
import yaml
|
|
21
23
|
|
|
22
24
|
from flwr.common import EventType, event
|
|
23
25
|
from flwr.common.constant import ExecPluginType
|
|
@@ -26,6 +28,7 @@ from flwr.common.logger import log
|
|
|
26
28
|
from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub
|
|
27
29
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
|
|
28
30
|
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub
|
|
31
|
+
from flwr.supercore.constant import EXEC_PLUGIN_SECTION
|
|
29
32
|
from flwr.supercore.grpc_health import add_args_health
|
|
30
33
|
from flwr.supercore.superexec.plugin import (
|
|
31
34
|
ClientAppExecPlugin,
|
|
@@ -36,6 +39,7 @@ from flwr.supercore.superexec.plugin import (
|
|
|
36
39
|
from flwr.supercore.superexec.run_superexec import run_superexec
|
|
37
40
|
|
|
38
41
|
try:
|
|
42
|
+
from flwr.ee import add_ee_args_superexec
|
|
39
43
|
from flwr.ee.constant import ExecEePluginType
|
|
40
44
|
from flwr.ee.exec_plugin import get_ee_plugin_and_stub_class
|
|
41
45
|
except ImportError:
|
|
@@ -54,6 +58,10 @@ except ImportError:
|
|
|
54
58
|
"""Get the EE plugin class and stub class based on the plugin type."""
|
|
55
59
|
return None
|
|
56
60
|
|
|
61
|
+
# pylint: disable-next=unused-argument
|
|
62
|
+
def add_ee_args_superexec(parser: argparse.ArgumentParser) -> None:
    """Add EE-specific arguments to the parser.

    Fallback used when importing `add_ee_args_superexec` from `flwr.ee`
    fails. It has no body beyond this docstring, so `parser` is left unchanged.
    """
|
|
64
|
+
|
|
57
65
|
|
|
58
66
|
def flower_superexec() -> None:
|
|
59
67
|
"""Run `flower-superexec` command."""
|
|
@@ -70,12 +78,28 @@ def flower_superexec() -> None:
|
|
|
70
78
|
# Trigger telemetry event
|
|
71
79
|
event(EventType.RUN_SUPEREXEC_ENTER, {"plugin_type": args.plugin_type})
|
|
72
80
|
|
|
81
|
+
# Load plugin config from YAML file if provided
|
|
82
|
+
plugin_config = None
|
|
83
|
+
if plugin_config_path := getattr(args, "plugin_config", None):
|
|
84
|
+
try:
|
|
85
|
+
with open(plugin_config_path, encoding="utf-8") as file:
|
|
86
|
+
yaml_config: Optional[dict[str, Any]] = yaml.safe_load(file)
|
|
87
|
+
if yaml_config is None or EXEC_PLUGIN_SECTION not in yaml_config:
|
|
88
|
+
raise ValueError(f"Missing '{EXEC_PLUGIN_SECTION}' section.")
|
|
89
|
+
plugin_config = yaml_config[EXEC_PLUGIN_SECTION]
|
|
90
|
+
except (FileNotFoundError, yaml.YAMLError, ValueError) as e:
|
|
91
|
+
flwr_exit(
|
|
92
|
+
ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG,
|
|
93
|
+
f"Failed to load plugin config from '{plugin_config_path}': {e!r}",
|
|
94
|
+
)
|
|
95
|
+
|
|
73
96
|
# Get the plugin class and stub class based on the plugin type
|
|
74
97
|
plugin_class, stub_class = _get_plugin_and_stub_class(args.plugin_type)
|
|
75
98
|
run_superexec(
|
|
76
99
|
plugin_class=plugin_class,
|
|
77
100
|
stub_class=stub_class, # type: ignore
|
|
78
101
|
appio_api_address=args.appio_api_address,
|
|
102
|
+
plugin_config=plugin_config,
|
|
79
103
|
flwr_dir=args.flwr_dir,
|
|
80
104
|
parent_pid=args.parent_pid,
|
|
81
105
|
health_server_address=args.health_server_address,
|
|
@@ -122,6 +146,7 @@ def _parse_args() -> argparse.ArgumentParser:
|
|
|
122
146
|
help="The PID of the parent process. When set, the process will terminate "
|
|
123
147
|
"when the parent process exits.",
|
|
124
148
|
)
|
|
149
|
+
add_ee_args_superexec(parser)
|
|
125
150
|
add_args_health(parser)
|
|
126
151
|
return parser
|
|
127
152
|
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Constants for Flower infrastructure."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Top-level key in YAML config for exec plugin settings
|
|
19
|
+
EXEC_PLUGIN_SECTION = "exec_plugin"
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from collections.abc import Sequence
|
|
20
|
-
from typing import Callable, Optional
|
|
20
|
+
from typing import Any, Callable, Optional
|
|
21
21
|
|
|
22
22
|
from flwr.common.typing import Run
|
|
23
23
|
|
|
@@ -69,3 +69,13 @@ class ExecPlugin(ABC):
|
|
|
69
69
|
The ID of the run associated with the token, used for tracking or
|
|
70
70
|
logging purposes.
|
|
71
71
|
"""
|
|
72
|
+
|
|
73
|
+
# This method is optional to implement
|
|
74
|
+
def load_config(self, yaml_config: dict[str, Any]) -> None:
    """Load configuration from a YAML dictionary.

    The base implementation has no body beyond this docstring, so it does
    nothing; plugins that accept configuration override it.

    Parameters
    ----------
    yaml_config : dict[str, Any]
        A dictionary representing the YAML configuration.
    """
|
|
@@ -17,10 +17,10 @@
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
19
|
from logging import WARN
|
|
20
|
-
from typing import Optional, Union
|
|
20
|
+
from typing import Any, Optional, Union
|
|
21
21
|
|
|
22
22
|
from flwr.common.config import get_flwr_dir
|
|
23
|
-
from flwr.common.exit import register_signal_handlers
|
|
23
|
+
from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
|
|
24
24
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
@@ -47,6 +47,7 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
|
|
|
47
47
|
type[ClientAppIoStub], type[ServerAppIoStub], type[SimulationIoStub]
|
|
48
48
|
],
|
|
49
49
|
appio_api_address: str,
|
|
50
|
+
plugin_config: Optional[dict[str, Any]] = None,
|
|
50
51
|
flwr_dir: Optional[str] = None,
|
|
51
52
|
parent_pid: Optional[int] = None,
|
|
52
53
|
health_server_address: Optional[str] = None,
|
|
@@ -61,6 +62,9 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
|
|
|
61
62
|
The gRPC stub class for the AppIO API.
|
|
62
63
|
appio_api_address : str
|
|
63
64
|
The address of the AppIO API.
|
|
65
|
+
plugin_config : Optional[dict[str, Any]] (default: None)
|
|
66
|
+
The configuration dictionary for the plugin. If `None`, the plugin will use
|
|
67
|
+
its default configuration.
|
|
64
68
|
flwr_dir : Optional[str] (default: None)
|
|
65
69
|
The Flower directory.
|
|
66
70
|
parent_pid : Optional[int] (default: None)
|
|
@@ -113,6 +117,16 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
|
|
|
113
117
|
get_run=get_run,
|
|
114
118
|
)
|
|
115
119
|
|
|
120
|
+
# Load plugin configuration from file if provided
|
|
121
|
+
try:
|
|
122
|
+
if plugin_config is not None:
|
|
123
|
+
plugin.load_config(plugin_config)
|
|
124
|
+
except (KeyError, ValueError) as e:
|
|
125
|
+
flwr_exit(
|
|
126
|
+
code=ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG,
|
|
127
|
+
message=f"Invalid plugin config: {e!r}",
|
|
128
|
+
)
|
|
129
|
+
|
|
116
130
|
# Start the main loop
|
|
117
131
|
try:
|
|
118
132
|
while True:
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""ArtifactProvider for SuperLink."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .artifact_provider import ArtifactProvider
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"ArtifactProvider",
|
|
22
|
+
]
|