flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +17 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +196 -42
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +109 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/clientapp/mod/__init__.py +4 -1
- flwr/clientapp/mod/centraldp_mods.py +156 -40
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +56 -13
- flwr/common/exit/exit_code.py +24 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -31
- flwr/proto/control_pb2.pyi +95 -5
- flwr/proto/control_pb2_grpc.py +136 -0
- flwr/proto/control_pb2_grpc.pyi +52 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +152 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/serverapp/strategy/__init__.py +26 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
- flwr/serverapp/strategy/fedadagrad.py +0 -3
- flwr/serverapp/strategy/fedadam.py +0 -3
- flwr/serverapp/strategy/fedavg.py +89 -64
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +0 -3
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +28 -32
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +41 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +16 -11
- flwr/superlink/servicer/control/control_servicer.py +207 -58
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/serverapp/dp_fixed_clipping.py +0 -352
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
- /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,304 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
"""Tests for message-based strategy utilities."""
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
from collections import OrderedDict
|
|
19
|
-
|
|
20
|
-
import numpy as np
|
|
21
|
-
import pytest
|
|
22
|
-
from parameterized import parameterized
|
|
23
|
-
|
|
24
|
-
from flwr.common import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
|
25
|
-
from flwr.serverapp.exception import InconsistentMessageReplies
|
|
26
|
-
|
|
27
|
-
from .strategy_utils import (
|
|
28
|
-
aggregate_arrayrecords,
|
|
29
|
-
aggregate_metricrecords,
|
|
30
|
-
config_to_str,
|
|
31
|
-
validate_message_reply_consistency,
|
|
32
|
-
)
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def test_config_to_str() -> None:
    """Verify that ``config_to_str`` renders ``bytes`` values as ``<bytes>``."""
    record = ConfigRecord({"a": 123, "b": [1, 2, 3], "c": b"bytes"})
    rendered = config_to_str(record)
    # Non-bytes values are shown verbatim; the bytes payload is masked
    assert rendered == "{'a': 123, 'b': [1, 2, 3], 'c': <bytes>}"
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def test_arrayrecords_aggregation() -> None:
    """Check that ArrayRecords are combined into their weighted average."""
    n_replies, n_arrays = 3, 4
    weights = [0.25, 0.4, 0.35]

    # One list of `n_arrays` random (7, 3) arrays per simulated reply
    per_reply = []
    for _ in range(n_replies):
        per_reply.append([np.random.randn(7, 3) for _ in range(n_arrays)])

    # Reference result: weighted mean of the idx-th array across all replies
    expected = []
    for idx in range(n_arrays):
        stacked = [arrays[idx] for arrays in per_reply]
        expected.append(np.average(stacked, axis=0, weights=weights))

    # Wrap each reply's arrays and weight in a RecordDict (mimicking replies)
    replies = [
        RecordDict(
            {
                "arrays": ArrayRecord(arrays),
                "metrics": MetricRecord({"weight": w}),
            }
        )
        for arrays, w in zip(per_reply, weights)
    ]

    aggregated = aggregate_arrayrecords(replies, weighting_metric_name="weight")

    # Values must match numerically and the record must hash identically
    for got, want in zip(aggregated.to_numpy_ndarrays(), expected):
        assert np.allclose(got, want)
    assert aggregated.object_id == ArrayRecord(expected).object_id
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
def test_arrayrecords_aggregation_with_ndim_zero() -> None:
    """Check weighted aggregation of ArrayRecords that hold 0-dim arrays."""
    weights = [0.25, 0.4, 0.35]
    # One 0-dimensional random array per simulated reply
    scalars = [np.array(np.random.randn()) for _ in weights]

    # For 0-dim arrays the expected value is simply the weighted mean
    expected = np.average(scalars, axis=0, weights=weights)

    # Wrap each value and its weight in a RecordDict (mimicking replies)
    replies = [
        RecordDict(
            {
                "arrays": ArrayRecord([value]),
                "metrics": MetricRecord({"weight": w}),
            }
        )
        for value, w in zip(scalars, weights)
    ]

    aggregated = aggregate_arrayrecords(replies, weighting_metric_name="weight")

    assert np.isclose(aggregated.to_numpy_ndarrays()[0], expected)
    assert aggregated.object_id == ArrayRecord([np.array(expected)]).object_id
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def test_metricrecords_aggregation() -> None:
    """Test weighted aggregation of MetricRecords.

    Each simulated reply carries a MetricRecord with an int entry ("a"), a
    float entry ("b") and a list-of-floats entry ("c"), plus a "weight" entry
    used as the weighting factor. The expected record is the per-key weighted
    average and does not contain the "weight" entry.
    """
    num_replies = 3
    weights = [0.25, 0.4, 0.35]
    metric_records = [
        MetricRecord({"a": 1, "b": 2.0, "c": np.random.randn(3).tolist()})
        for _ in range(num_replies)
    ]

    # Compute the expected aggregated MetricRecord key by key: convert each
    # value to a numpy array, then take the weighted average across replies
    keys = list(metric_records[0].keys())
    expected_values = {
        key: np.average(
            [np.array(record[key]) for record in metric_records],
            axis=0,
            weights=weights,
        ).tolist()
        for key in keys
    }
    expected_record = MetricRecord(expected_values)
    # Scalar entries are expected as floats, even when the inputs are ints
    expected_record["a"] = float(expected_record["a"])  # type: ignore
    expected_record["b"] = float(expected_record["b"])  # type: ignore

    # Construct RecordDicts (mimicking replies) and inject the weighting factor.
    # `dict(record)` uses the public mapping interface instead of reaching into
    # the record's private `_data` attribute.
    records = [
        RecordDict({"metrics": MetricRecord(dict(record) | {"weight": weight})})
        for record, weight in zip(metric_records, weights)
    ]

    # Execute aggregate
    aggrd = aggregate_metricrecords(records, weighting_metric_name="weight")
    # Assert
    assert expected_record.object_id == aggrd.object_id
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
@parameterized.expand(  # type: ignore
    [
        # Compliant: one ArrayRecord plus one MetricRecord with a scalar weight
        (
            True,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": 0.123}),
                }
            ),
        ),
        # BAD: the weighting entry is a list rather than a scalar
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": [0.123]}),
                }
            ),
        ),
        # BAD: the MetricRecord lacks the weighting entry
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"loss": 0.01}),
                }
            ),
        ),
        # BAD: no MetricRecord at all
        (
            False,
            RecordDict({"global-model": ArrayRecord([np.random.randn(7, 3)])}),
        ),
        # BAD: two ArrayRecords
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "another-model": ArrayRecord([np.random.randn(7, 3)]),
                }
            ),
        ),
        # BAD: two MetricRecords
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": 0.123}),
                    "more-metrics": MetricRecord({"loss": 0.321}),
                }
            ),
        ),
    ]
)
def test_consistency_of_replies_with_matching_keys(
    is_valid: bool, recorddict: RecordDict
) -> None:
    """Validate identical replies; only malformed ones must be rejected."""
    # Repeating the same RecordDict guarantees keys match across replies
    replies = [recorddict] * 3

    if is_valid:
        # A compliant set of replies passes validation silently
        validate_message_reply_consistency(
            replies, weighted_by_key="weight", check_arrayrecord=True
        )
        return

    # A malformed set of replies must raise InconsistentMessageReplies
    with pytest.raises(InconsistentMessageReplies):
        validate_message_reply_consistency(
            replies, weighted_by_key="weight", check_arrayrecord=True
        )
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
@parameterized.expand(  # type: ignore
    [
        # ArrayRecord keys differ across replies ("global-model" vs "model")
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),
        # ArrayRecord keys match, but the inner Array keys differ ("a" vs "b")
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord(
                            OrderedDict({"a": Array(np.random.randn(7, 3))})
                        ),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord(
                            OrderedDict({"b": Array(np.random.randn(7, 3))})
                        ),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),
        # MetricRecord keys differ across replies ("metrics" vs "my-metrics")
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "my-metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),
        # MetricRecord keys differ ("metrics" vs "my-metrics") and so do the
        # inner entries ("weight" vs "my-weights").
        # NOTE(review): because the top-level keys also differ, this case does
        # not isolate an inner-key-only mismatch; confirm whether that was the
        # intent.
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "my-metrics": MetricRecord({"my-weights": 0.123}),
                    }
                ),
            ],
        ),
    ]
)
def test_consistency_of_replies_with_different_keys(
    list_records: list[RecordDict],
) -> None:
    """Validation must reject replies whose record keys do not line up."""
    # Every parameterized case is inconsistent, so validation must raise
    expect_failure = pytest.raises(InconsistentMessageReplies)
    with expect_failure:
        validate_message_reply_consistency(
            list_records, weighted_by_key="weight", check_arrayrecord=True
        )
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|