flwr 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +11 -0
- flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
- flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +9 -13
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +19 -8
- flwr/client/grpc_rere_client/connection.py +1 -12
- flwr/client/rest_client/connection.py +3 -0
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +26 -0
- flwr/clientapp/mod/centraldp_mods.py +132 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +23 -4
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +26 -7
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +0 -11
- flwr/common/inflatable_utils.py +1 -1
- flwr/common/logger.py +1 -1
- flwr/common/retry_invoker.py +30 -11
- flwr/common/telemetry.py +4 -0
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +25 -17
- flwr/proto/appio_pb2.pyi +46 -2
- flwr/proto/clientappio_pb2.py +3 -11
- flwr/proto/clientappio_pb2.pyi +0 -47
- flwr/proto/clientappio_pb2_grpc.py +19 -20
- flwr/proto/clientappio_pb2_grpc.pyi +10 -11
- flwr/proto/control_pb2.py +62 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
- flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +68 -0
- flwr/proto/serverappio_pb2_grpc.pyi +26 -0
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +129 -152
- flwr/server/grid/grpc_grid.py +3 -0
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +157 -146
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
- flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/dp_fixed_clipping.py +352 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +38 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
- flwr/serverapp/strategy/fedadagrad.py +162 -0
- flwr/serverapp/strategy/fedadam.py +181 -0
- flwr/serverapp/strategy/fedavg.py +295 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedyogi.py +173 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +251 -0
- flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
- flwr/simulation/app.py +161 -164
- flwr/supercore/app_utils.py +58 -0
- flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
- flwr/supercore/cli/flower_superexec.py +141 -0
- flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/supercore/grpc_health/__init__.py +3 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +4 -4
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +185 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +24 -29
- flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +69 -30
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
- flwr/supernode/cli/flower_supernode.py +3 -0
- flwr/supernode/cli/flwr_clientapp.py +18 -21
- flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
- flwr/supernode/nodestate/nodestate.py +3 -59
- flwr/supernode/runtime/run_clientapp.py +39 -102
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
- flwr/supernode/start_client_internal.py +35 -76
- {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/METADATA +4 -3
- {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/RECORD +127 -98
- {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
- flwr/proto/exec_pb2.py +0 -62
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -191
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -129
- /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
- {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower message-based strategy utilities."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import random
|
|
19
|
+
from collections import OrderedDict
|
|
20
|
+
from logging import INFO
|
|
21
|
+
from time import sleep
|
|
22
|
+
from typing import Optional, cast
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
from flwr.common import (
|
|
27
|
+
Array,
|
|
28
|
+
ArrayRecord,
|
|
29
|
+
ConfigRecord,
|
|
30
|
+
MetricRecord,
|
|
31
|
+
NDArray,
|
|
32
|
+
RecordDict,
|
|
33
|
+
log,
|
|
34
|
+
)
|
|
35
|
+
from flwr.server import Grid
|
|
36
|
+
|
|
37
|
+
from ..exception import InconsistentMessageReplies
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def config_to_str(config: ConfigRecord) -> str:
    """Render a ConfigRecord as a dict-like string with bytes values hidden.

    Each entry is rendered as ``'key': value``. Values of type ``bytes`` are
    replaced by the placeholder ``<bytes>`` so binary payloads never end up
    in log output.
    """
    parts = []
    for key, val in config.items():
        shown = "<bytes>" if isinstance(val, bytes) else val
        parts.append(f"'{key}': {shown}")
    return "{" + ", ".join(parts) + "}"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def log_strategy_start_info(
    num_rounds: int,
    arrays: ArrayRecord,
    train_config: Optional[ConfigRecord],
    evaluate_config: Optional[ConfigRecord],
) -> None:
    """Log, at INFO level, a summary of the setup a strategy starts with.

    Emits one line each for the number of rounds, and for the size of
    ``arrays`` in MB (the total length of each Array's ``data`` divided by
    1024**2). It also emits one line each for the train and evaluate
    ConfigRecords. Config values are rendered through ``config_to_str`` (which
    masks bytes). A falsy config (``None`` or empty) is shown as ``(empty!)``.
    """
    log(INFO, "\t├── Number of rounds: %d", num_rounds)

    size_mb = sum(len(arr.data) for arr in arrays.values()) / (1024**2)
    log(INFO, "\t├── ArrayRecord (%.2f MB)", size_mb)

    for template, cfg in (
        ("\t├── ConfigRecord (train): %s", train_config),
        ("\t├── ConfigRecord (evaluate): %s", evaluate_config),
    ):
        log(INFO, template, config_to_str(cfg) if cfg else "(empty!)")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def aggregate_arrayrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> ArrayRecord:
    """Perform weighted aggregation of all ArrayRecords using a specific key.

    The weight of each record is read from its first (and expected only)
    MetricRecord, under ``weighting_metric_name``. The weights are normalized
    to sum to one. Each array is then accumulated by name as
    ``sum_i(weight_i * array_i)``.

    Parameters
    ----------
    records : list[RecordDict]
        Reply contents to aggregate. They are expected to have passed
        ``validate_message_reply_consistency`` beforehand.
    weighting_metric_name : str
        Key in the MetricRecord whose value is used as the weighting factor.

    Returns
    -------
    ArrayRecord
        A new ArrayRecord holding the weighted average of every array. It is
        empty if ``records`` is empty.

    Raises
    ------
    ZeroDivisionError
        If ``records`` is non-empty and the weights sum to zero, in which case
        no weighted average is defined.
    """
    # Retrieve weighting factor from MetricRecord
    weights: list[float] = []
    for record in records:
        # Get the first (and only) MetricRecord in the record
        metricrecord = next(iter(record.metric_records.values()))
        # Because replies have been checked for consistency,
        # we can safely cast the weighting factor to float
        w = cast(float, metricrecord[weighting_metric_name])
        weights.append(w)

    # Normalize the weights so they sum to one
    total_weight = sum(weights)
    if weights and total_weight == 0:
        # Fail with an actionable message instead of a bare "division by zero"
        raise ZeroDivisionError(
            "Cannot aggregate ArrayRecords: the values of "
            f"`{weighting_metric_name}` across replies sum to zero."
        )
    weight_factors = [w / total_weight for w in weights]

    # Perform weighted aggregation
    aggregated_np_arrays: dict[str, NDArray] = {}

    for record, weight in zip(records, weight_factors):
        for record_item in record.array_records.values():
            # Accumulate a running weighted sum per array name
            for key, value in record_item.items():
                if key not in aggregated_np_arrays:
                    aggregated_np_arrays[key] = value.numpy() * weight
                else:
                    aggregated_np_arrays[key] += value.numpy() * weight

    # For 0-dim arrays, `value.numpy() * weight` can yield a NumPy scalar;
    # `np.asarray` turns it back into an array before wrapping it in `Array`
    return ArrayRecord(
        OrderedDict({k: Array(np.asarray(v)) for k, v in aggregated_np_arrays.items()})
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def aggregate_metricrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> MetricRecord:
    """Perform weighted aggregation of all MetricRecords using a specific key.

    The weight of each record is read from its first (and expected only)
    MetricRecord, under ``weighting_metric_name``. The weights are normalized
    to sum to one. Every other metric is then averaged with those weights:
    scalar metrics directly, and list-valued metrics element-wise. The
    weighting key itself is left out of the result.

    Parameters
    ----------
    records : list[RecordDict]
        Reply contents to aggregate. They are expected to have passed
        ``validate_message_reply_consistency`` beforehand.
    weighting_metric_name : str
        Key in the MetricRecord whose value is used as the weighting factor.

    Returns
    -------
    MetricRecord
        A new MetricRecord with the weighted averages. It is empty if
        ``records`` is empty.

    Raises
    ------
    ZeroDivisionError
        If ``records`` is non-empty and the weights sum to zero, in which case
        no weighted average is defined.
    """
    # Retrieve weighting factor from MetricRecord
    weights: list[float] = []
    for record in records:
        # Get the first (and only) MetricRecord in the record
        metricrecord = next(iter(record.metric_records.values()))
        # Because replies have been checked for consistency,
        # we can safely cast the weighting factor to float
        w = cast(float, metricrecord[weighting_metric_name])
        weights.append(w)

    # Normalize the weights so they sum to one
    total_weight = sum(weights)
    if weights and total_weight == 0:
        # Fail with an actionable message instead of a bare "division by zero"
        raise ZeroDivisionError(
            "Cannot aggregate MetricRecords: the values of "
            f"`{weighting_metric_name}` across replies sum to zero."
        )
    weight_factors = [w / total_weight for w in weights]

    aggregated_metrics = MetricRecord()
    for record, weight in zip(records, weight_factors):
        for record_item in record.metric_records.values():
            # aggregate in-place
            for key, value in record_item.items():
                if key == weighting_metric_name:
                    # We exclude the weighting key from the aggregated MetricRecord
                    continue
                if key not in aggregated_metrics:
                    if isinstance(value, list):
                        aggregated_metrics[key] = [v * weight for v in value]
                    else:
                        aggregated_metrics[key] = value * weight
                else:
                    if isinstance(value, list):
                        current_list = cast(list[float], aggregated_metrics[key])
                        # NOTE: zip() truncates to the shorter list; the
                        # consistency check compares keys only, not list lengths
                        aggregated_metrics[key] = [
                            curr + val * weight
                            for curr, val in zip(current_list, value)
                        ]
                    else:
                        current_value = cast(float, aggregated_metrics[key])
                        aggregated_metrics[key] = current_value + value * weight

    return aggregated_metrics
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def sample_nodes(
    grid: Grid, min_available_nodes: int, sample_size: int
) -> tuple[list[int], list[int]]:
    """Sample the specified number of nodes using the Grid.

    Blocks until at least ``max(min_available_nodes, sample_size)`` nodes are
    connected. While waiting, it polls ``grid.get_node_ids()`` once per second
    and logs the current count on every poll. There is no timeout.

    Parameters
    ----------
    grid : Grid
        The grid object.
    min_available_nodes : int
        The minimum number of available nodes to sample from. If it is smaller
        than ``sample_size``, it is raised to ``sample_size`` so that the
        sample can always be drawn.
    sample_size : int
        The number of nodes to sample (without replacement).

    Returns
    -------
    tuple[list[int], list[int]]
        A tuple containing the sampled node IDs and the list
        of all connected node IDs.

    Raises
    ------
    ValueError
        Propagated from ``random.sample`` if ``sample_size`` is negative.
    """
    # Ensure min_available_nodes is at least as large as sample_size
    min_available_nodes = max(min_available_nodes, sample_size)

    # Wait for min_available_nodes to be online
    while len(all_nodes := list(grid.get_node_ids())) < min_available_nodes:
        log(
            INFO,
            "Waiting for nodes to connect: %d connected (minimum required: %d).",
            len(all_nodes),
            min_available_nodes,
        )
        sleep(1)

    # Draw a uniform random sample without replacement
    sampled_nodes = random.sample(all_nodes, sample_size)

    return sampled_nodes, all_nodes
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# pylint: disable=too-many-return-statements
def validate_message_reply_consistency(
    replies: list[RecordDict], weighted_by_key: str, check_arrayrecord: bool
) -> None:
    """Validate that replies contain exactly one ArrayRecord and one MetricRecord, and
    that the MetricRecord includes a weight factor key.

    These checks ensure that Message-based strategies behave consistently with
    *Ins/*Res-based strategies.

    Parameters
    ----------
    replies : list[RecordDict]
        Contents of the reply messages to validate. An empty list is
        trivially consistent and passes.
    weighted_by_key : str
        Key that must be present, holding a non-list value, in every reply's
        MetricRecord.
    check_arrayrecord : bool
        Whether to also require exactly one ArrayRecord per reply, stored
        under the same name and holding the same array keys in every reply.

    Raises
    ------
    InconsistentMessageReplies
        If any of the checks above fails.
    """
    # Nothing to compare against; also avoids indexing into an empty list below
    if not replies:
        return

    # Checking for ArrayRecord consistency
    if check_arrayrecord:
        if any(len(msg.array_records) != 1 for msg in replies):
            raise InconsistentMessageReplies(
                reason="Expected exactly one ArrayRecord in replies. "
                "Skipping aggregation."
            )

        # Ensure the same keys are present in all ArrayRecords. A reply that
        # lacks `record_key` yields an empty key set via `.get(..., {})`, so it
        # is rejected too (unless the first ArrayRecord is itself empty).
        record_key = next(iter(replies[0].array_records.keys()))
        all_keys = set(replies[0][record_key].keys())
        if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
            raise InconsistentMessageReplies(
                reason="All ArrayRecords must have the same keys for aggregation. "
                "This condition wasn't met. Skipping aggregation."
            )

    # Checking for MetricRecord consistency. The count may be zero or more
    # than one, so the message must not assume which.
    if any(len(msg.metric_records) != 1 for msg in replies):
        raise InconsistentMessageReplies(
            reason="Expected exactly one MetricRecord in replies. "
            "Skipping aggregation."
        )

    # Ensure the same keys are present in all MetricRecords
    record_key = next(iter(replies[0].metric_records.keys()))
    all_keys = set(replies[0][record_key].keys())
    if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
        raise InconsistentMessageReplies(
            reason="All MetricRecords must have the same keys for aggregation. "
            "This condition wasn't met. Skipping aggregation."
        )

    # Verify the weight factor key presence in all MetricRecords
    # (all key sets are equal at this point, so checking the first suffices)
    if weighted_by_key not in all_keys:
        raise InconsistentMessageReplies(
            reason=f"Missing required key `{weighted_by_key}` in the MetricRecord of "
            "reply messages. Cannot average ArrayRecords and MetricRecords. Skipping "
            "aggregation."
        )

    # Check that it is not a list
    if any(isinstance(msg[record_key][weighted_by_key], list) for msg in replies):
        raise InconsistentMessageReplies(
            reason=f"Key `{weighted_by_key}` in the MetricRecord of reply messages "
            "must be a single value (int or float), but a list was found. Skipping "
            "aggregation."
        )
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Tests for message-based strategy utilities."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pytest
|
|
22
|
+
from parameterized import parameterized
|
|
23
|
+
|
|
24
|
+
from flwr.common import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
|
25
|
+
from flwr.serverapp.exception import InconsistentMessageReplies
|
|
26
|
+
|
|
27
|
+
from .strategy_utils import (
|
|
28
|
+
aggregate_arrayrecords,
|
|
29
|
+
aggregate_metricrecords,
|
|
30
|
+
config_to_str,
|
|
31
|
+
validate_message_reply_consistency,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_config_to_str() -> None:
    """Test that bytes values are masked while other values are rendered."""
    # Prepare
    record = ConfigRecord({"a": 123, "b": [1, 2, 3], "c": b"bytes"})

    # Execute
    rendered = config_to_str(record)

    # Assert
    assert rendered == "{'a': 123, 'b': [1, 2, 3], 'c': <bytes>}"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_arrayrecords_aggregation() -> None:
    """Test that ArrayRecords are averaged using the per-reply weights."""
    n_replies, n_arrays = 3, 4
    weights = [0.25, 0.4, 0.35]
    per_reply_arrays = [
        [np.random.randn(7, 3) for _ in range(n_arrays)] for _ in range(n_replies)
    ]

    # Reference result computed directly with NumPy
    expected = [
        np.average(
            [arrays[idx] for arrays in per_reply_arrays], axis=0, weights=weights
        )
        for idx in range(n_arrays)
    ]

    # Build RecordDicts that mimic the replies
    records = [
        RecordDict(
            {
                "arrays": ArrayRecord(arrays),
                "metrics": MetricRecord({"weight": weight}),
            }
        )
        for arrays, weight in zip(per_reply_arrays, weights)
    ]

    aggregated = aggregate_arrayrecords(records, weighting_metric_name="weight")

    pairs = zip(aggregated.to_numpy_ndarrays(), expected)
    assert all(np.allclose(got, want) for got, want in pairs)
    assert aggregated.object_id == ArrayRecord(expected).object_id
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def test_arrayrecords_aggregation_with_ndim_zero() -> None:
    """Test aggregation of ArrayRecords holding 0-dimensional arrays."""
    weights = [0.25, 0.4, 0.35]
    scalars = [np.array(np.random.randn()) for _ in weights]

    # With 0-dim inputs the reference is simply the weighted average of scalars
    expected = np.average(scalars, axis=0, weights=weights)

    # Build RecordDicts that mimic the replies
    records = []
    for scalar, weight in zip(scalars, weights):
        records.append(
            RecordDict(
                {
                    "arrays": ArrayRecord([scalar]),
                    "metrics": MetricRecord({"weight": weight}),
                }
            )
        )

    aggregated = aggregate_arrayrecords(records, weighting_metric_name="weight")

    assert np.isclose(aggregated.to_numpy_ndarrays()[0], expected)
    assert aggregated.object_id == ArrayRecord([np.array(expected)]).object_id
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def test_metricrecords_aggregation() -> None:
    """Test aggregation of MetricRecords."""
    num_replies = 3
    weights = [0.25, 0.4, 0.35]
    metric_records = [
        MetricRecord({"a": 1, "b": 2.0, "c": np.random.randn(3).tolist()})
        for _ in range(num_replies)
    ]

    # Compute expected aggregated MetricRecord.
    # For ease, we convert everything into numpy arrays, then aggregate
    as_np_entries = [
        {
            k: np.array(v) if isinstance(v, (int, float, list)) else v
            for k, v in record.items()
        }
        for record in metric_records
    ]
    avg_list = [
        np.average(
            [list(entries.values())[i] for entries in as_np_entries],
            axis=0,
            weights=weights,
        ).tolist()
        for i in range(len(as_np_entries[0]))
    ]
    expected_record = MetricRecord(dict(zip(as_np_entries[0].keys(), avg_list)))
    expected_record["a"] = float(expected_record["a"])  # type: ignore
    expected_record["b"] = float(expected_record["b"])  # type: ignore

    # Construct RecordDicts (mimicking replies) and inject the weighting factor.
    # The metrics are copied through the public Mapping interface instead of
    # reaching into MetricRecord's private storage (`__dict__["_data"]`).
    records = [
        RecordDict({"metrics": MetricRecord({**record, "weight": weight})})
        for record, weight in zip(metric_records, weights)
    ]

    # Execute aggregate
    aggrd = aggregate_metricrecords(records, weighting_metric_name="weight")
    # Assert
    assert expected_record.object_id == aggrd.object_id
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@parameterized.expand(  # type: ignore
    [
        # Compliant: one ArrayRecord plus one MetricRecord with a scalar weight
        (
            True,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": 0.123}),
                }
            ),
        ),
        # BAD: the weighting value is a list instead of a scalar
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": [0.123]}),
                }
            ),
        ),
        # BAD: the MetricRecord does not contain the weighting key
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"loss": 0.01}),
                }
            ),
        ),
        # BAD: there is no MetricRecord at all
        (
            False,
            RecordDict({"global-model": ArrayRecord([np.random.randn(7, 3)])}),
        ),
        # BAD: two ArrayRecords
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "another-model": ArrayRecord([np.random.randn(7, 3)]),
                }
            ),
        ),
        # BAD: two MetricRecords
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": 0.123}),
                    "more-metrics": MetricRecord({"loss": 0.321}),
                }
            ),
        ),
    ]
)
def test_consistency_of_replies_with_matching_keys(
    is_valid: bool, recorddict: RecordDict
) -> None:
    """Test that a repeated reply is accepted or rejected as expected."""
    # Every reply carries the very same RecordDict, so all keys match
    records = [recorddict] * 3

    if is_valid:
        # A compliant reply must validate without raising
        validate_message_reply_consistency(
            records, weighted_by_key="weight", check_arrayrecord=True
        )
        return

    # Any non-compliant reply must be rejected
    with pytest.raises(InconsistentMessageReplies):
        validate_message_reply_consistency(
            records, weighted_by_key="weight", check_arrayrecord=True
        )
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
@parameterized.expand(  # type: ignore
    [
        # The ArrayRecords are stored under different top-level names
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),
        # Same ArrayRecord name, but the Arrays inside use different keys
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord(
                            OrderedDict({"a": Array(np.random.randn(7, 3))})
                        ),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord(
                            OrderedDict({"b": Array(np.random.randn(7, 3))})
                        ),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),
        # The MetricRecords are stored under different top-level names
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "my-metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),
        # Different MetricRecord names and different metric keys inside them
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "my-metrics": MetricRecord({"my-weights": 0.123}),
                    }
                ),
            ],
        ),
    ]
)
def test_consistency_of_replies_with_different_keys(
    list_records: list[RecordDict],
) -> None:
    """Test that replies whose record or entry keys disagree are rejected."""
    # Every case above is inconsistent, so validation must always raise
    expected_error = InconsistentMessageReplies
    with pytest.raises(expected_error):
        validate_message_reply_consistency(
            list_records, weighted_by_key="weight", check_arrayrecord=True
        )
|