flwr 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  5. flwr/cli/cli_user_auth_interceptor.py +1 -1
  6. flwr/cli/config_utils.py +3 -3
  7. flwr/cli/constant.py +25 -8
  8. flwr/cli/log.py +9 -9
  9. flwr/cli/login/login.py +3 -3
  10. flwr/cli/ls.py +5 -5
  11. flwr/cli/new/new.py +11 -0
  12. flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
  13. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
  14. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
  15. flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
  16. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  19. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  20. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  23. flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
  24. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  26. flwr/cli/run/run.py +9 -13
  27. flwr/cli/stop.py +7 -4
  28. flwr/cli/utils.py +19 -8
  29. flwr/client/grpc_rere_client/connection.py +1 -12
  30. flwr/client/rest_client/connection.py +3 -0
  31. flwr/clientapp/__init__.py +10 -0
  32. flwr/clientapp/mod/__init__.py +26 -0
  33. flwr/clientapp/mod/centraldp_mods.py +132 -0
  34. flwr/common/args.py +20 -6
  35. flwr/common/auth_plugin/__init__.py +4 -4
  36. flwr/common/auth_plugin/auth_plugin.py +7 -7
  37. flwr/common/constant.py +23 -4
  38. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  39. flwr/common/exit/__init__.py +4 -0
  40. flwr/common/exit/exit.py +8 -1
  41. flwr/common/exit/exit_code.py +26 -7
  42. flwr/common/exit/exit_handler.py +62 -0
  43. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  44. flwr/common/grpc.py +0 -11
  45. flwr/common/inflatable_utils.py +1 -1
  46. flwr/common/logger.py +1 -1
  47. flwr/common/retry_invoker.py +30 -11
  48. flwr/common/telemetry.py +4 -0
  49. flwr/compat/server/app.py +2 -2
  50. flwr/proto/appio_pb2.py +25 -17
  51. flwr/proto/appio_pb2.pyi +46 -2
  52. flwr/proto/clientappio_pb2.py +3 -11
  53. flwr/proto/clientappio_pb2.pyi +0 -47
  54. flwr/proto/clientappio_pb2_grpc.py +19 -20
  55. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  56. flwr/proto/control_pb2.py +62 -0
  57. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
  58. flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
  59. flwr/proto/serverappio_pb2.py +2 -2
  60. flwr/proto/serverappio_pb2_grpc.py +68 -0
  61. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  62. flwr/proto/simulationio_pb2.py +4 -11
  63. flwr/proto/simulationio_pb2.pyi +0 -58
  64. flwr/proto/simulationio_pb2_grpc.py +129 -27
  65. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  66. flwr/server/app.py +129 -152
  67. flwr/server/grid/grpc_grid.py +3 -0
  68. flwr/server/grid/inmemory_grid.py +1 -0
  69. flwr/server/serverapp/app.py +157 -146
  70. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  71. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  72. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  73. flwr/server/superlink/linkstate/linkstate.py +2 -1
  74. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  75. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  76. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  77. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  78. flwr/serverapp/__init__.py +12 -0
  79. flwr/serverapp/dp_fixed_clipping.py +352 -0
  80. flwr/serverapp/exception.py +38 -0
  81. flwr/serverapp/strategy/__init__.py +38 -0
  82. flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
  83. flwr/serverapp/strategy/fedadagrad.py +162 -0
  84. flwr/serverapp/strategy/fedadam.py +181 -0
  85. flwr/serverapp/strategy/fedavg.py +295 -0
  86. flwr/serverapp/strategy/fedopt.py +218 -0
  87. flwr/serverapp/strategy/fedyogi.py +173 -0
  88. flwr/serverapp/strategy/result.py +105 -0
  89. flwr/serverapp/strategy/strategy.py +285 -0
  90. flwr/serverapp/strategy/strategy_utils.py +251 -0
  91. flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
  92. flwr/simulation/app.py +161 -164
  93. flwr/supercore/app_utils.py +58 -0
  94. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  95. flwr/supercore/cli/flower_superexec.py +141 -0
  96. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  97. flwr/supercore/corestate/corestate.py +81 -0
  98. flwr/supercore/grpc_health/__init__.py +3 -0
  99. flwr/supercore/grpc_health/health_server.py +53 -0
  100. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  101. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  102. flwr/supercore/superexec/plugin/__init__.py +28 -0
  103. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  104. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  105. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +4 -4
  106. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  107. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  108. flwr/supercore/superexec/run_superexec.py +185 -0
  109. flwr/superlink/servicer/__init__.py +15 -0
  110. flwr/superlink/servicer/control/__init__.py +22 -0
  111. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  112. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +24 -29
  113. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  114. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +69 -30
  115. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  116. flwr/supernode/cli/flower_supernode.py +3 -0
  117. flwr/supernode/cli/flwr_clientapp.py +18 -21
  118. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  119. flwr/supernode/nodestate/nodestate.py +3 -59
  120. flwr/supernode/runtime/run_clientapp.py +39 -102
  121. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  122. flwr/supernode/start_client_internal.py +35 -76
  123. {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/METADATA +4 -3
  124. {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/RECORD +127 -98
  125. {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
  126. flwr/proto/exec_pb2.py +0 -62
  127. flwr/superexec/app.py +0 -45
  128. flwr/superexec/deployment.py +0 -191
  129. flwr/superexec/executor.py +0 -100
  130. flwr/superexec/simulation.py +0 -129
  131. flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
  132. {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
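
The two hunks reproduced below are the only file contents included in this diff. Judging by their module docstrings and the matching +251/+304 line counts in the list above, they are the new flwr/serverapp/strategy/strategy_utils.py and flwr/serverapp/strategy/strategy_utils_tests.py. The file list also shows the wider reorganization in 1.21.0: a message-API strategy package under flwr/serverapp/strategy/, new client-side mods under flwr/clientapp/mod/, a standalone flower-superexec entry point under flwr/supercore/, and the former flwr/superexec Exec servicer replaced by a Control servicer under flwr/superlink/servicer/control/. As orientation, the sketch below shows imports suggested by the new layout; the names re-exported by flwr/serverapp/strategy/__init__.py are not visible in this diff, so the strategy-class import is an assumption inferred from the module names, while the exception import is confirmed by the test hunk below.

# Illustration only: import paths inferred from the 1.21.0 file layout listed above.
# FedAvg (and FedAdam, FedYogi, ...) are assumed to be re-exported by
# flwr/serverapp/strategy/__init__.py; the __init__ contents are not part of this diff.
from flwr.serverapp.strategy import FedAvg  # assumed re-export
# Confirmed by the imports in strategy_utils_tests.py below:
from flwr.serverapp.exception import InconsistentMessageReplies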
@@ -0,0 +1,251 @@
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Flower message-based strategy utilities."""
+
+
+ import random
+ from collections import OrderedDict
+ from logging import INFO
+ from time import sleep
+ from typing import Optional, cast
+
+ import numpy as np
+
+ from flwr.common import (
+     Array,
+     ArrayRecord,
+     ConfigRecord,
+     MetricRecord,
+     NDArray,
+     RecordDict,
+     log,
+ )
+ from flwr.server import Grid
+
+ from ..exception import InconsistentMessageReplies
+
+
+ def config_to_str(config: ConfigRecord) -> str:
+     """Convert a ConfigRecord to a string representation masking bytes."""
+     content = ", ".join(
+         f"'{k}': {'<bytes>' if isinstance(v, bytes) else v}" for k, v in config.items()
+     )
+     return f"{{{content}}}"
+
+
+ def log_strategy_start_info(
+     num_rounds: int,
+     arrays: ArrayRecord,
+     train_config: Optional[ConfigRecord],
+     evaluate_config: Optional[ConfigRecord],
+ ) -> None:
+     """Log information about the strategy start."""
+     log(INFO, "\t├── Number of rounds: %d", num_rounds)
+     log(
+         INFO,
+         "\t├── ArrayRecord (%.2f MB)",
+         sum(len(array.data) for array in arrays.values()) / (1024**2),
+     )
+     log(
+         INFO,
+         "\t├── ConfigRecord (train): %s",
+         config_to_str(train_config) if train_config else "(empty!)",
+     )
+     log(
+         INFO,
+         "\t├── ConfigRecord (evaluate): %s",
+         config_to_str(evaluate_config) if evaluate_config else "(empty!)",
+     )
+
+
+ def aggregate_arrayrecords(
+     records: list[RecordDict], weighting_metric_name: str
+ ) -> ArrayRecord:
+     """Perform weighted aggregation all ArrayRecords using a specific key."""
+     # Retrieve weighting factor from MetricRecord
+     weights: list[float] = []
+     for record in records:
+         # Get the first (and only) MetricRecord in the record
+         metricrecord = next(iter(record.metric_records.values()))
+         # Because replies have been checked for consistency,
+         # we can safely cast the weighting factor to float
+         w = cast(float, metricrecord[weighting_metric_name])
+         weights.append(w)
+
+     # Average
+     total_weight = sum(weights)
+     weight_factors = [w / total_weight for w in weights]
+
+     # Perform weighted aggregation
+     aggregated_np_arrays: dict[str, NDArray] = {}
+
+     for record, weight in zip(records, weight_factors):
+         for record_item in record.array_records.values():
+             # aggregate in-place
+             for key, value in record_item.items():
+                 if key not in aggregated_np_arrays:
+                     aggregated_np_arrays[key] = value.numpy() * weight
+                 else:
+                     aggregated_np_arrays[key] += value.numpy() * weight
+
+     return ArrayRecord(
+         OrderedDict({k: Array(np.asarray(v)) for k, v in aggregated_np_arrays.items()})
+     )
+
+
+ def aggregate_metricrecords(
+     records: list[RecordDict], weighting_metric_name: str
+ ) -> MetricRecord:
+     """Perform weighted aggregation all MetricRecords using a specific key."""
+     # Retrieve weighting factor from MetricRecord
+     weights: list[float] = []
+     for record in records:
+         # Get the first (and only) MetricRecord in the record
+         metricrecord = next(iter(record.metric_records.values()))
+         # Because replies have been checked for consistency,
+         # we can safely cast the weighting factor to float
+         w = cast(float, metricrecord[weighting_metric_name])
+         weights.append(w)
+
+     # Average
+     total_weight = sum(weights)
+     weight_factors = [w / total_weight for w in weights]
+
+     aggregated_metrics = MetricRecord()
+     for record, weight in zip(records, weight_factors):
+         for record_item in record.metric_records.values():
+             # aggregate in-place
+             for key, value in record_item.items():
+                 if key == weighting_metric_name:
+                     # We exclude the weighting key from the aggregated MetricRecord
+                     continue
+                 if key not in aggregated_metrics:
+                     if isinstance(value, list):
+                         aggregated_metrics[key] = [v * weight for v in value]
+                     else:
+                         aggregated_metrics[key] = value * weight
+                 else:
+                     if isinstance(value, list):
+                         current_list = cast(list[float], aggregated_metrics[key])
+                         aggregated_metrics[key] = [
+                             curr + val * weight
+                             for curr, val in zip(current_list, value)
+                         ]
+                     else:
+                         current_value = cast(float, aggregated_metrics[key])
+                         aggregated_metrics[key] = current_value + value * weight
+
+     return aggregated_metrics
+
+
+ def sample_nodes(
+     grid: Grid, min_available_nodes: int, sample_size: int
+ ) -> tuple[list[int], list[int]]:
+     """Sample the specified number of nodes using the Grid.
+
+     Parameters
+     ----------
+     grid : Grid
+         The grid object.
+     min_available_nodes : int
+         The minimum number of available nodes to sample from.
+     sample_size : int
+         The number of nodes to sample.
+
+     Returns
+     -------
+     tuple[list[int], list[int]]
+         A tuple containing the sampled node IDs and the list
+         of all connected node IDs.
+     """
+     sampled_nodes = []
+
+     # Ensure min_available_nodes is at least as large as sample_size
+     min_available_nodes = max(min_available_nodes, sample_size)
+
+     # wait for min_available_nodes to be online
+     while len(all_nodes := list(grid.get_node_ids())) < min_available_nodes:
+         log(
+             INFO,
+             "Waiting for nodes to connect: %d connected (minimum required: %d).",
+             len(all_nodes),
+             min_available_nodes,
+         )
+         sleep(1)
+
+     # Sample nodes
+     sampled_nodes = random.sample(all_nodes, sample_size)
+
+     return sampled_nodes, all_nodes
+
+
+ # pylint: disable=too-many-return-statements
+ def validate_message_reply_consistency(
+     replies: list[RecordDict], weighted_by_key: str, check_arrayrecord: bool
+ ) -> None:
+     """Validate that replies contain exactly one ArrayRecord and one MetricRecord, and
+     that the MetricRecord includes a weight factor key.
+
+     These checks ensure that Message-based strategies behave consistently with
+     *Ins/*Res-based strategies.
+     """
+     # Checking for ArrayRecord consistency
+     if check_arrayrecord:
+         if any(len(msg.array_records) != 1 for msg in replies):
+             raise InconsistentMessageReplies(
+                 reason="Expected exactly one ArrayRecord in replies. "
+                 "Skipping aggregation."
+             )
+
+         # Ensure all key are present in all ArrayRecords
+         record_key = next(iter(replies[0].array_records.keys()))
+         all_keys = set(replies[0][record_key].keys())
+         if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
+             raise InconsistentMessageReplies(
+                 reason="All ArrayRecords must have the same keys for aggregation. "
+                 "This condition wasn't met. Skipping aggregation."
+             )
+
+     # Checking for MetricRecord consistency
+     if any(len(msg.metric_records) != 1 for msg in replies):
+         raise InconsistentMessageReplies(
+             reason="Expected exactly one MetricRecord in replies, but found more. "
+             "Skipping aggregation."
+         )
+
+     # Ensure all key are present in all MetricRecords
+     record_key = next(iter(replies[0].metric_records.keys()))
+     all_keys = set(replies[0][record_key].keys())
+     if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
+         raise InconsistentMessageReplies(
+             reason="All MetricRecords must have the same keys for aggregation. "
+             "This condition wasn't met. Skipping aggregation."
+         )
+
+     # Verify the weight factor key presence in all MetricRecords
+     if weighted_by_key not in all_keys:
+         raise InconsistentMessageReplies(
+             reason=f"Missing required key `{weighted_by_key}` in the MetricRecord of "
+             "reply messages. Cannot average ArrayRecords and MetricRecords. Skipping "
+             "aggregation."
+         )
+
+     # Check that it is not a list
+     if any(isinstance(msg[record_key][weighted_by_key], list) for msg in replies):
+         raise InconsistentMessageReplies(
+             reason=f"Key `{weighted_by_key}` in the MetricRecord of reply messages "
+             "must be a single value (int or float), but a list was found. Skipping "
+             "aggregation."
+         )
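
The hunk above is the internal helper module that the new message-based strategies build on: replies are first validated for structural consistency, then ArrayRecords and MetricRecords are averaged with weights taken from a designated metric key (the weighting key itself is dropped from the aggregated MetricRecord). Below is a minimal sketch of how these helpers compose, assuming flwr 1.21.0 is installed; the record names "arrays"/"metrics" and the "num-examples" weighting key are illustrative choices mirroring the tests in the next hunk, and in normal use the strategies call these functions for you.

import numpy as np

from flwr.common import ArrayRecord, MetricRecord, RecordDict
from flwr.serverapp.strategy.strategy_utils import (
    aggregate_arrayrecords,
    aggregate_metricrecords,
    validate_message_reply_consistency,
)

# Two mock replies, each with exactly one ArrayRecord and one MetricRecord.
# "num-examples" acts as the weighting factor; the key name is arbitrary here.
replies = [
    RecordDict(
        {
            "arrays": ArrayRecord([np.full((2, 2), float(i + 1))]),
            "metrics": MetricRecord({"loss": float(i), "num-examples": 10.0 * (i + 1)}),
        }
    )
    for i in range(2)
]

# Raises InconsistentMessageReplies if the replies don't share the same structure
validate_message_reply_consistency(
    replies, weighted_by_key="num-examples", check_arrayrecord=True
)

# Weighted averages over the replies; weights are normalized from "num-examples",
# and the weighting key is excluded from the aggregated MetricRecord
agg_arrays = aggregate_arrayrecords(replies, weighting_metric_name="num-examples")
agg_metrics = aggregate_metricrecords(replies, weighting_metric_name="num-examples")

Validation runs before aggregation so the aggregation helpers can safely assume one ArrayRecord and one MetricRecord per reply, with a scalar weighting value present in every MetricRecord.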
@@ -0,0 +1,304 @@
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Tests for message-based strategy utilities."""
+
+
+ from collections import OrderedDict
+
+ import numpy as np
+ import pytest
+ from parameterized import parameterized
+
+ from flwr.common import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
+ from flwr.serverapp.exception import InconsistentMessageReplies
+
+ from .strategy_utils import (
+     aggregate_arrayrecords,
+     aggregate_metricrecords,
+     config_to_str,
+     validate_message_reply_consistency,
+ )
+
+
+ def test_config_to_str() -> None:
+     """Test that items of types bytes are masked out."""
+     config = ConfigRecord({"a": 123, "b": [1, 2, 3], "c": b"bytes"})
+     expected_str = "{'a': 123, 'b': [1, 2, 3], 'c': <bytes>}"
+     assert config_to_str(config) == expected_str
+
+
+ def test_arrayrecords_aggregation() -> None:
+     """Test aggregation of ArrayRecords."""
+     num_replies = 3
+     num_arrays = 4
+     weights = [0.25, 0.4, 0.35]
+     np_arrays = [
+         [np.random.randn(7, 3) for _ in range(num_arrays)] for _ in range(num_replies)
+     ]
+
+     avg_list = [
+         np.average([lst[i] for lst in np_arrays], axis=0, weights=weights)
+         for i in range(num_arrays)
+     ]
+
+     # Construct RecordDicts (mimicing replies)
+     records = [
+         RecordDict(
+             {
+                 "arrays": ArrayRecord(np_arrays[i]),
+                 "metrics": MetricRecord({"weight": weights[i]}),
+             }
+         )
+         for i in range(num_replies)
+     ]
+     # Execute aggregate
+     aggrd = aggregate_arrayrecords(records, weighting_metric_name="weight")
+
+     # Assert consistency
+     assert all(np.allclose(a, b) for a, b in zip(aggrd.to_numpy_ndarrays(), avg_list))
+     assert aggrd.object_id == ArrayRecord(avg_list).object_id
+
+
+ def test_arrayrecords_aggregation_with_ndim_zero() -> None:
+     """Test aggregation of ArrayRecords with 0-dim arrays."""
+     num_replies = 3
+     weights = [0.25, 0.4, 0.35]
+     np_arrays = [np.array(np.random.randn()) for _ in range(num_replies)]
+
+     # For 0-dimensional arrays, we just compute the weighted average directly
+     avg_list = [np.average(np_arrays, axis=0, weights=weights)]
+
+     # Construct RecordDicts (mimicing replies)
+     records = [
+         RecordDict(
+             {
+                 "arrays": ArrayRecord([np_arrays[i]]),
+                 "metrics": MetricRecord({"weight": weights[i]}),
+             }
+         )
+         for i in range(num_replies)
+     ]
+     # Execute aggregate
+     aggrd = aggregate_arrayrecords(records, weighting_metric_name="weight")
+
+     # Assert consistency
+     assert np.isclose(aggrd.to_numpy_ndarrays()[0], avg_list[0])
+     assert aggrd.object_id == ArrayRecord([np.array(avg_list[0])]).object_id
+
+
+ def test_metricrecords_aggregation() -> None:
+     """Test aggregation of MetricRecords."""
+     num_replies = 3
+     weights = [0.25, 0.4, 0.35]
+     metric_records = [
+         MetricRecord({"a": 1, "b": 2.0, "c": np.random.randn(3).tolist()})
+         for _ in range(num_replies)
+     ]
+
+     # Compute expected aggregated MetricRecord.
+     # For ease, we convert everything into numpy arrays, then aggregate
+     as_np_entries = [
+         {
+             k: np.array(v) if isinstance(v, (int, float, list)) else v
+             for k, v in record.items()
+         }
+         for record in metric_records
+     ]
+     avg_list = [
+         np.average(
+             [list(entries.values())[i] for entries in as_np_entries],
+             axis=0,
+             weights=weights,
+         ).tolist()
+         for i in range(len(as_np_entries[0]))
+     ]
+     expected_record = MetricRecord(dict(zip(as_np_entries[0].keys(), avg_list)))
+     expected_record["a"] = float(expected_record["a"])  # type: ignore
+     expected_record["b"] = float(expected_record["b"])  # type: ignore
+
+     # Construct RecordDicts (mimicing replies)
+     # Inject weighting factor
+     records = [
+         RecordDict(
+             {
+                 "metrics": MetricRecord(
+                     record.__dict__["_data"] | {"weight": weights[i]}
+                 ),
+             }
+         )
+         for i, record in enumerate(metric_records)
+     ]
+
+     # Execute aggregate
+     aggrd = aggregate_metricrecords(records, weighting_metric_name="weight")
+     # Assert
+     assert expected_record.object_id == aggrd.object_id
+
+
+ @parameterized.expand(  # type: ignore
+     [
+         (
+             True,
+             RecordDict(
+                 {
+                     "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                     "metrics": MetricRecord({"weight": 0.123}),
+                 }
+             ),
+         ),  # Compliant
+         (
+             False,
+             RecordDict(
+                 {
+                     "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                     "metrics": MetricRecord({"weight": [0.123]}),
+                 }
+             ),
+         ),  # Weighting key is not a scalar (BAD)
+         (
+             False,
+             RecordDict(
+                 {
+                     "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                     "metrics": MetricRecord({"loss": 0.01}),
+                 }
+             ),
+         ),  # No weighting key in MetricRecord (BAD)
+         (
+             False,
+             RecordDict({"global-model": ArrayRecord([np.random.randn(7, 3)])}),
+         ),  # No MetricsRecord (BAD)
+         (
+             False,
+             RecordDict(
+                 {
+                     "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                     "another-model": ArrayRecord([np.random.randn(7, 3)]),
+                 }
+             ),
+         ),  # Two ArrayRecords (BAD)
+         (
+             False,
+             RecordDict(
+                 {
+                     "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                     "metrics": MetricRecord({"weight": 0.123}),
+                     "more-metrics": MetricRecord({"loss": 0.321}),
+                 }
+             ),
+         ),  # Two MetricRecords (BAD)
+     ]
+ )
+ def test_consistency_of_replies_with_matching_keys(
+     is_valid: bool, recorddict: RecordDict
+ ) -> None:
+     """Test consistency in replies."""
+     # Create dummy records
+     records = [recorddict for _ in range(3)]
+
+     if not is_valid:
+         # Should raise InconsistentMessageReplies exception
+         with pytest.raises(InconsistentMessageReplies):
+             validate_message_reply_consistency(
+                 records, weighted_by_key="weight", check_arrayrecord=True
+             )
+     else:
+         # Should not raise an exception
+         validate_message_reply_consistency(
+             records, weighted_by_key="weight", check_arrayrecord=True
+         )
+
+
+ @parameterized.expand(  # type: ignore
+     [
+         (
+             [
+                 RecordDict(
+                     {
+                         "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                         "metrics": MetricRecord({"weight": 0.123}),
+                     }
+                 ),
+                 RecordDict(
+                     {
+                         "model": ArrayRecord([np.random.randn(7, 3)]),
+                         "metrics": MetricRecord({"weight": 0.123}),
+                     }
+                 ),
+             ],
+         ),  # top-level keys don't match for ArrayRecords
+         (
+             [
+                 RecordDict(
+                     {
+                         "global-model": ArrayRecord(
+                             OrderedDict({"a": Array(np.random.randn(7, 3))})
+                         ),
+                         "metrics": MetricRecord({"weight": 0.123}),
+                     }
+                 ),
+                 RecordDict(
+                     {
+                         "global-model": ArrayRecord(
+                             OrderedDict({"b": Array(np.random.randn(7, 3))})
+                         ),
+                         "metrics": MetricRecord({"weight": 0.123}),
+                     }
+                 ),
+             ],
+         ),  # top-level keys match for ArrayRecords but not those for Arrays
+         (
+             [
+                 RecordDict(
+                     {
+                         "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                         "metrics": MetricRecord({"weight": 0.123}),
+                     }
+                 ),
+                 RecordDict(
+                     {
+                         "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                         "my-metrics": MetricRecord({"weight": 0.123}),
+                     }
+                 ),
+             ],
+         ),  # top-level keys don't match for MetricRecords
+         (
+             [
+                 RecordDict(
+                     {
+                         "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                         "metrics": MetricRecord({"weight": 0.123}),
+                     }
+                 ),
+                 RecordDict(
+                     {
+                         "global-model": ArrayRecord([np.random.randn(7, 3)]),
+                         "my-metrics": MetricRecord({"my-weights": 0.123}),
+                     }
+                 ),
+             ],
+         ),  # top-level keys match for MetricRecords but not inner ones
+     ]
+ )
+ def test_consistency_of_replies_with_different_keys(
+     list_records: list[RecordDict],
+ ) -> None:
+     """Test consistency in replies when records don't have matching keys."""
+     # All test cases expect InconsistentMessageReplies exception to be raised
+     with pytest.raises(InconsistentMessageReplies):
+         validate_message_reply_consistency(
+             list_records, weighted_by_key="weight", check_arrayrecord=True
+         )
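
The test module above depends on pytest and parameterized. Assuming both are installed next to flwr 1.21.0, one way to exercise just this module against the installed package is to invoke pytest programmatically (a sketch; pytest's --pyargs option resolves the dotted module name to the installed file):

# A minimal sketch: run only the new strategy-utils tests from the installed wheel.
# Assumes pytest and parameterized are installed alongside flwr 1.21.0.
import pytest

raise SystemExit(
    pytest.main(["-q", "--pyargs", "flwr.serverapp.strategy.strategy_utils_tests"])
)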