flwr 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (248):
  1. flwr/__init__.py +1 -1
  2. flwr/cli/__init__.py +1 -1
  3. flwr/cli/app.py +21 -2
  4. flwr/cli/build.py +1 -1
  5. flwr/cli/cli_user_auth_interceptor.py +1 -1
  6. flwr/cli/config_utils.py +53 -17
  7. flwr/cli/example.py +1 -1
  8. flwr/cli/install.py +1 -1
  9. flwr/cli/log.py +1 -1
  10. flwr/cli/login/__init__.py +1 -1
  11. flwr/cli/login/login.py +12 -1
  12. flwr/cli/ls.py +1 -1
  13. flwr/cli/new/__init__.py +1 -1
  14. flwr/cli/new/new.py +4 -4
  15. flwr/cli/new/templates/__init__.py +1 -1
  16. flwr/cli/new/templates/app/__init__.py +1 -1
  17. flwr/cli/new/templates/app/code/__init__.py +1 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
  19. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +5 -5
  20. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  23. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  30. flwr/cli/run/__init__.py +1 -1
  31. flwr/cli/run/run.py +6 -10
  32. flwr/cli/stop.py +1 -1
  33. flwr/cli/utils.py +11 -12
  34. flwr/client/__init__.py +1 -1
  35. flwr/client/app.py +58 -56
  36. flwr/client/client.py +1 -1
  37. flwr/client/client_app.py +231 -166
  38. flwr/client/clientapp/__init__.py +1 -1
  39. flwr/client/clientapp/app.py +3 -3
  40. flwr/client/clientapp/clientappio_servicer.py +1 -1
  41. flwr/client/clientapp/utils.py +1 -1
  42. flwr/client/dpfedavg_numpy_client.py +1 -1
  43. flwr/client/grpc_adapter_client/__init__.py +1 -1
  44. flwr/client/grpc_adapter_client/connection.py +1 -1
  45. flwr/client/grpc_client/__init__.py +1 -1
  46. flwr/client/grpc_client/connection.py +37 -34
  47. flwr/client/grpc_rere_client/__init__.py +1 -1
  48. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  49. flwr/client/grpc_rere_client/connection.py +1 -1
  50. flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
  51. flwr/client/heartbeat.py +1 -1
  52. flwr/client/message_handler/__init__.py +1 -1
  53. flwr/client/message_handler/message_handler.py +28 -28
  54. flwr/client/mod/__init__.py +3 -3
  55. flwr/client/mod/centraldp_mods.py +8 -8
  56. flwr/client/mod/comms_mods.py +17 -23
  57. flwr/client/mod/localdp_mod.py +10 -10
  58. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  59. flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
  60. flwr/client/mod/secure_aggregation/secaggplus_mod.py +32 -32
  61. flwr/client/mod/utils.py +1 -1
  62. flwr/client/nodestate/__init__.py +1 -1
  63. flwr/client/nodestate/in_memory_nodestate.py +1 -1
  64. flwr/client/nodestate/nodestate.py +1 -1
  65. flwr/client/nodestate/nodestate_factory.py +1 -1
  66. flwr/client/numpy_client.py +1 -1
  67. flwr/client/rest_client/__init__.py +1 -1
  68. flwr/client/rest_client/connection.py +1 -1
  69. flwr/client/run_info_store.py +3 -3
  70. flwr/client/supernode/__init__.py +1 -1
  71. flwr/client/supernode/app.py +1 -1
  72. flwr/client/typing.py +1 -1
  73. flwr/common/__init__.py +13 -5
  74. flwr/common/address.py +1 -1
  75. flwr/common/args.py +1 -1
  76. flwr/common/auth_plugin/__init__.py +1 -1
  77. flwr/common/auth_plugin/auth_plugin.py +1 -1
  78. flwr/common/config.py +5 -5
  79. flwr/common/constant.py +7 -7
  80. flwr/common/context.py +5 -5
  81. flwr/common/date.py +1 -1
  82. flwr/common/differential_privacy.py +1 -1
  83. flwr/common/differential_privacy_constants.py +1 -1
  84. flwr/common/dp.py +1 -1
  85. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  86. flwr/common/exit/exit.py +6 -6
  87. flwr/common/exit_handlers.py +1 -1
  88. flwr/common/grpc.py +1 -1
  89. flwr/common/logger.py +3 -3
  90. flwr/common/message.py +344 -102
  91. flwr/common/object_ref.py +1 -1
  92. flwr/common/parameter.py +1 -1
  93. flwr/common/pyproject.py +1 -1
  94. flwr/common/record/__init__.py +9 -5
  95. flwr/common/record/arrayrecord.py +626 -0
  96. flwr/common/record/{configsrecord.py → configrecord.py} +83 -37
  97. flwr/common/record/conversion_utils.py +2 -2
  98. flwr/common/record/{metricsrecord.py → metricrecord.py} +90 -44
  99. flwr/common/record/recorddict.py +337 -0
  100. flwr/common/record/typeddict.py +1 -1
  101. flwr/common/recorddict_compat.py +410 -0
  102. flwr/common/retry_invoker.py +10 -10
  103. flwr/common/secure_aggregation/__init__.py +1 -1
  104. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  105. flwr/common/secure_aggregation/crypto/shamir.py +52 -30
  106. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  107. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  108. flwr/common/secure_aggregation/quantization.py +1 -1
  109. flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
  110. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  111. flwr/common/serde.py +67 -72
  112. flwr/common/telemetry.py +2 -2
  113. flwr/common/typing.py +9 -9
  114. flwr/common/version.py +1 -1
  115. flwr/proto/__init__.py +1 -1
  116. flwr/proto/exec_pb2.py +3 -3
  117. flwr/proto/exec_pb2.pyi +3 -3
  118. flwr/proto/message_pb2.py +12 -12
  119. flwr/proto/message_pb2.pyi +9 -9
  120. flwr/proto/recorddict_pb2.py +70 -0
  121. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  122. flwr/proto/run_pb2.py +31 -31
  123. flwr/proto/run_pb2.pyi +3 -3
  124. flwr/server/__init__.py +4 -2
  125. flwr/server/app.py +67 -12
  126. flwr/server/client_manager.py +1 -1
  127. flwr/server/client_proxy.py +1 -1
  128. flwr/server/compat/__init__.py +3 -3
  129. flwr/server/compat/app.py +12 -12
  130. flwr/server/compat/app_utils.py +17 -17
  131. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
  132. flwr/server/compat/legacy_context.py +1 -1
  133. flwr/server/criterion.py +1 -1
  134. flwr/server/fleet_event_log_interceptor.py +94 -0
  135. flwr/server/{driver → grid}/__init__.py +8 -7
  136. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  137. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
  138. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
  139. flwr/server/history.py +1 -1
  140. flwr/server/run_serverapp.py +5 -5
  141. flwr/server/server.py +1 -1
  142. flwr/server/server_app.py +98 -71
  143. flwr/server/server_config.py +1 -1
  144. flwr/server/serverapp/__init__.py +1 -1
  145. flwr/server/serverapp/app.py +11 -11
  146. flwr/server/serverapp_components.py +1 -1
  147. flwr/server/strategy/__init__.py +1 -1
  148. flwr/server/strategy/aggregate.py +1 -1
  149. flwr/server/strategy/bulyan.py +2 -2
  150. flwr/server/strategy/dp_adaptive_clipping.py +17 -17
  151. flwr/server/strategy/dp_fixed_clipping.py +17 -17
  152. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  153. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  154. flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
  155. flwr/server/strategy/fedadagrad.py +1 -1
  156. flwr/server/strategy/fedadam.py +1 -1
  157. flwr/server/strategy/fedavg.py +1 -1
  158. flwr/server/strategy/fedavg_android.py +1 -1
  159. flwr/server/strategy/fedavgm.py +1 -1
  160. flwr/server/strategy/fedmedian.py +1 -1
  161. flwr/server/strategy/fedopt.py +1 -1
  162. flwr/server/strategy/fedprox.py +1 -1
  163. flwr/server/strategy/fedtrimmedavg.py +1 -1
  164. flwr/server/strategy/fedxgb_bagging.py +1 -1
  165. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  166. flwr/server/strategy/fedxgb_nn_avg.py +3 -2
  167. flwr/server/strategy/fedyogi.py +1 -1
  168. flwr/server/strategy/krum.py +1 -1
  169. flwr/server/strategy/qfedavg.py +1 -1
  170. flwr/server/strategy/strategy.py +1 -1
  171. flwr/server/superlink/__init__.py +1 -1
  172. flwr/server/superlink/ffs/__init__.py +1 -1
  173. flwr/server/superlink/ffs/disk_ffs.py +1 -1
  174. flwr/server/superlink/ffs/ffs.py +1 -1
  175. flwr/server/superlink/ffs/ffs_factory.py +1 -1
  176. flwr/server/superlink/fleet/__init__.py +1 -1
  177. flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
  178. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
  179. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  180. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  181. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  182. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  183. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
  184. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  185. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  186. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
  187. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  188. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  189. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  190. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  191. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  192. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  193. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  194. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -3
  195. flwr/server/superlink/fleet/vce/vce_api.py +2 -4
  196. flwr/server/superlink/linkstate/__init__.py +1 -1
  197. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -9
  198. flwr/server/superlink/linkstate/linkstate.py +5 -5
  199. flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
  200. flwr/server/superlink/linkstate/sqlite_linkstate.py +62 -28
  201. flwr/server/superlink/linkstate/utils.py +94 -28
  202. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  203. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  204. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
  205. flwr/server/superlink/simulation/__init__.py +1 -1
  206. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  207. flwr/server/superlink/simulation/simulationio_servicer.py +3 -3
  208. flwr/server/superlink/utils.py +1 -1
  209. flwr/server/typing.py +4 -4
  210. flwr/server/utils/__init__.py +1 -1
  211. flwr/server/utils/tensorboard.py +1 -1
  212. flwr/server/utils/validator.py +5 -5
  213. flwr/server/workflow/__init__.py +1 -1
  214. flwr/server/workflow/constant.py +1 -1
  215. flwr/server/workflow/default_workflows.py +49 -58
  216. flwr/server/workflow/secure_aggregation/__init__.py +1 -1
  217. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
  218. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +49 -51
  219. flwr/simulation/__init__.py +1 -1
  220. flwr/simulation/app.py +3 -3
  221. flwr/simulation/legacy_app.py +1 -1
  222. flwr/simulation/ray_transport/__init__.py +1 -1
  223. flwr/simulation/ray_transport/ray_actor.py +5 -3
  224. flwr/simulation/ray_transport/ray_client_proxy.py +35 -33
  225. flwr/simulation/ray_transport/utils.py +1 -1
  226. flwr/simulation/run_simulation.py +17 -17
  227. flwr/simulation/simulationio_connection.py +1 -1
  228. flwr/superexec/__init__.py +1 -1
  229. flwr/superexec/app.py +1 -1
  230. flwr/superexec/deployment.py +5 -5
  231. flwr/superexec/exec_event_log_interceptor.py +135 -0
  232. flwr/superexec/exec_grpc.py +11 -5
  233. flwr/superexec/exec_servicer.py +3 -3
  234. flwr/superexec/exec_user_auth_interceptor.py +19 -3
  235. flwr/superexec/executor.py +4 -4
  236. flwr/superexec/simulation.py +4 -4
  237. {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/METADATA +3 -3
  238. flwr-1.18.0.dist-info/RECORD +332 -0
  239. flwr/common/record/parametersrecord.py +0 -339
  240. flwr/common/record/recordset.py +0 -209
  241. flwr/common/recordset_compat.py +0 -418
  242. flwr/proto/recordset_pb2.py +0 -70
  243. flwr-1.16.0.dist-info/LICENSE +0 -202
  244. flwr-1.16.0.dist-info/RECORD +0 -331
  245. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  246. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  247. {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/WHEEL +0 -0
  248. {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,410 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """RecordDict utilities."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from collections.abc import Mapping
20
+ from typing import Union, cast, get_args
21
+
22
+ from . import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
23
+ from .typing import (
24
+ Code,
25
+ ConfigRecordValues,
26
+ EvaluateIns,
27
+ EvaluateRes,
28
+ FitIns,
29
+ FitRes,
30
+ GetParametersIns,
31
+ GetParametersRes,
32
+ GetPropertiesIns,
33
+ GetPropertiesRes,
34
+ MetricRecordValues,
35
+ Parameters,
36
+ Scalar,
37
+ Status,
38
+ )
39
+
40
# Key of the placeholder Array that `parameters_to_arrayrecord` inserts when the
# legacy Parameters hold no tensors (so the tensor type is still carried along).
# `arrayrecord_to_parameters` does not emit a tensor for entries under this key.
EMPTY_TENSOR_KEY: str = "_empty"
41
+
42
+
43
def arrayrecord_to_parameters(record: ArrayRecord, keep_input: bool) -> Parameters:
    """Convert an ArrayRecord to legacy Parameters.

    Warnings
    --------
    Because `Array`s in `ArrayRecord` encode more information of the
    array-like or tensor-like data (e.g. their datatype, shape) than `Parameters`,
    it might not be possible to reconstruct such data structures from `Parameters`
    objects alone. Additional information or metadata must be provided from
    elsewhere.

    Parameters
    ----------
    record : ArrayRecord
        The record to be converted into Parameters.
    keep_input : bool
        A boolean indicating whether entries should be kept in the input record.
        If False, each entry is deleted from `record` immediately after its data
        has been added to the returned Parameters.

    Returns
    -------
    parameters : Parameters
        The parameters in the legacy format Parameters. An entry stored under
        `EMPTY_TENSOR_KEY` contributes no tensor, but its `stype` may still be
        used as the resulting `tensor_type`.
    """
    parameters = Parameters(tensors=[], tensor_type="")

    # Iterate over a snapshot of the keys because entries may be deleted from
    # `record` inside the loop.
    for key in list(record.keys()):
        if key != EMPTY_TENSOR_KEY:
            parameters.tensors.append(record[key].data)

        if not parameters.tensor_type:
            # Take the tensor type from the first array with a non-empty `stype`.
            # Recall the warning in the docstring of this function.
            parameters.tensor_type = record[key].stype

        if not keep_input:
            del record[key]

    return parameters
81
+
82
+
83
def parameters_to_arrayrecord(parameters: Parameters, keep_input: bool) -> ArrayRecord:
    """Convert legacy Parameters into a single ArrayRecord.

    Because there is no concept of names in the legacy Parameters, arbitrary keys will
    be used when constructing the ArrayRecord (the stringified position of each
    tensor: "0", "1", ...). Similarly, the shape and data type won't be recorded in
    the Array objects.

    Parameters
    ----------
    parameters : Parameters
        Parameters object to be represented as an ArrayRecord.
    keep_input : bool
        A boolean indicating whether parameters should be kept in the input
        Parameters object (i.e. a list of serialized NumPy arrays). If False, each
        tensor is removed from `parameters.tensors` immediately after adding it to
        the record, leaving the list empty at the end.

    Returns
    -------
    ArrayRecord
        The ArrayRecord containing the provided parameters. If `parameters` holds
        no tensors, the record contains a single empty placeholder Array under
        `EMPTY_TENSOR_KEY` so that the tensor type is preserved.
    """
    tensor_type = parameters.tensor_type
    tensors = parameters.tensors
    num_arrays = len(tensors)

    if not keep_input:
        # Reverse once so every tensor can be released with an O(1) `pop()` from
        # the end of the list instead of an O(n) `pop(0)` from the front. The
        # tensors are still consumed in their original order.
        tensors.reverse()

    ordered_dict: OrderedDict[str, Array] = OrderedDict()
    for idx in range(num_arrays):
        tensor = tensors[idx] if keep_input else tensors.pop()
        ordered_dict[str(idx)] = Array(
            data=tensor, dtype="", stype=tensor_type, shape=[]
        )

    if num_arrays == 0:
        # Keep track of the tensor type even when there are no tensors.
        ordered_dict[EMPTY_TENSOR_KEY] = Array(
            data=b"", dtype="", stype=tensor_type, shape=[]
        )
    return ArrayRecord(ordered_dict, keep_input=keep_input)
122
+
123
+
124
def _check_mapping_from_recordscalartype_to_scalar(
    record_data: Mapping[str, Union[ConfigRecordValues, MetricRecordValues]],
) -> dict[str, Scalar]:
    """Check that mapping `common.*RecordValues` into `common.Scalar` is possible.

    Parameters
    ----------
    record_data : Mapping[str, Union[ConfigRecordValues, MetricRecordValues]]
        The record contents to validate.

    Returns
    -------
    dict[str, Scalar]
        The same `record_data` object (not a copy), typed as a Scalar dict.

    Raises
    ------
    TypeError
        If any value is not an instance of one of the types in `common.Scalar`.
    """
    # Loop-invariant: the tuple of concrete types making up `Scalar`.
    scalar_types = get_args(Scalar)
    for value in record_data.values():
        if not isinstance(value, scalar_types):
            raise TypeError(
                "There is not a 1:1 mapping between `common.Scalar` types and those "
                "supported in `common.ConfigRecordValues` or "
                "`common.MetricRecordValues`. Consider casting your values to a type "
                "supported by the `common.RecordDict` infrastructure. "
                f"You used type: {type(value)}"
            )
    return cast(dict[str, Scalar], record_data)
138
+
139
+
140
def _recorddict_to_fit_or_evaluate_ins_components(
    recorddict: RecordDict,
    ins_str: str,
    keep_input: bool,
) -> tuple[Parameters, dict[str, Scalar]]:
    """Extract the Parameters and config dict shared by FitIns and EvaluateIns.

    The parameters come from the ArrayRecord stored under "<ins_str>.parameters"
    and the config from the ConfigRecord stored under "<ins_str>.config".
    """
    params_key = f"{ins_str}.parameters"
    config_key = f"{ins_str}.config"

    parameters = arrayrecord_to_parameters(
        recorddict.array_records[params_key], keep_input=keep_input
    )
    config_dict = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[config_key]
    )
    return parameters, config_dict
157
+
158
+
159
def _fit_or_evaluate_ins_to_recorddict(
    ins: Union[FitIns, EvaluateIns], keep_input: bool
) -> RecordDict:
    """Pack a FitIns or EvaluateIns into a new RecordDict.

    Keys are prefixed with "fitins" for a FitIns and with "evaluateins" otherwise.
    The parameters go into an ArrayRecord ("<prefix>.parameters") and the config
    into a ConfigRecord ("<prefix>.config").
    """
    prefix = "evaluateins"
    if isinstance(ins, FitIns):
        prefix = "fitins"

    out = RecordDict()
    out.array_records[f"{prefix}.parameters"] = parameters_to_arrayrecord(
        ins.parameters, keep_input
    )
    out.config_records[f"{prefix}.config"] = ConfigRecord(
        ins.config  # type: ignore
    )
    return out
173
+
174
+
175
def _embed_status_into_recorddict(
    res_str: str, status: Status, recorddict: RecordDict
) -> RecordDict:
    """Store `status` in `recorddict` under the key "<res_str>.status".

    A ConfigRecord is used (rather than a MetricRecord) because the status
    message is a string, which MetricRecord values cannot hold. The given
    RecordDict is modified in place and also returned.
    """
    payload: dict[str, ConfigRecordValues] = {
        "code": int(status.code.value),
        "message": status.message,
    }
    recorddict.config_records[f"{res_str}.status"] = ConfigRecord(payload)
    return recorddict
186
+
187
+
188
def _extract_status_from_recorddict(res_str: str, recorddict: RecordDict) -> Status:
    """Rebuild a Status from the ConfigRecord stored under "<res_str>.status"."""
    status_record = recorddict.config_records[f"{res_str}.status"]
    return Status(
        code=Code(cast(int, status_record["code"])),
        message=str(status_record["message"]),
    )
192
+
193
+
194
def recorddict_to_fitins(recorddict: RecordDict, keep_input: bool) -> FitIns:
    """Derive FitIns from a RecordDict object.

    Reads the "fitins.parameters" ArrayRecord and the "fitins.config"
    ConfigRecord. When `keep_input` is False, parameter entries are removed from
    the ArrayRecord as they are converted.
    """
    params, cfg = _recorddict_to_fit_or_evaluate_ins_components(
        recorddict, ins_str="fitins", keep_input=keep_input
    )
    return FitIns(parameters=params, config=cfg)
203
+
204
+
205
def fitins_to_recorddict(fitins: FitIns, keep_input: bool) -> RecordDict:
    """Construct a RecordDict from a FitIns object (keys prefixed "fitins.")."""
    return _fit_or_evaluate_ins_to_recorddict(ins=fitins, keep_input=keep_input)
208
+
209
+
210
def recorddict_to_fitres(recorddict: RecordDict, keep_input: bool) -> FitRes:
    """Derive FitRes from a RecordDict object.

    Expects the "fitres.parameters" ArrayRecord, the "fitres.num_examples"
    MetricRecord, and the "fitres.metrics" and "fitres.status" ConfigRecords.
    When `keep_input` is False, parameter entries are removed from the
    ArrayRecord as they are converted.
    """
    prefix = "fitres"

    parameters = arrayrecord_to_parameters(
        recorddict.array_records[f"{prefix}.parameters"], keep_input=keep_input
    )
    num_examples_record = recorddict.metric_records[f"{prefix}.num_examples"]
    num_examples = cast(int, num_examples_record["num_examples"])
    metrics = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[f"{prefix}.metrics"]
    )

    return FitRes(
        status=_extract_status_from_recorddict(prefix, recorddict),
        parameters=parameters,
        num_examples=num_examples,
        metrics=metrics,
    )
228
+
229
+
230
def fitres_to_recorddict(fitres: FitRes, keep_input: bool) -> RecordDict:
    """Construct a RecordDict from a FitRes object.

    Layout of the returned RecordDict:

    - ``config_records["fitres.metrics"]``: the result's metrics
    - ``metric_records["fitres.num_examples"]``: ``{"num_examples": ...}``
    - ``array_records["fitres.parameters"]``: the result's parameters
    - ``config_records["fitres.status"]``: the result's status
    """
    prefix = "fitres"
    out = RecordDict()

    out.config_records[f"{prefix}.metrics"] = ConfigRecord(
        fitres.metrics  # type: ignore
    )
    out.metric_records[f"{prefix}.num_examples"] = MetricRecord(
        {"num_examples": fitres.num_examples}
    )
    out.array_records[f"{prefix}.parameters"] = parameters_to_arrayrecord(
        fitres.parameters, keep_input
    )

    return _embed_status_into_recorddict(prefix, fitres.status, out)
251
+
252
+
253
def recorddict_to_evaluateins(recorddict: RecordDict, keep_input: bool) -> EvaluateIns:
    """Derive EvaluateIns from a RecordDict object.

    Reads the "evaluateins.parameters" ArrayRecord and the "evaluateins.config"
    ConfigRecord. When `keep_input` is False, parameter entries are removed from
    the ArrayRecord as they are converted.
    """
    params, cfg = _recorddict_to_fit_or_evaluate_ins_components(
        recorddict, ins_str="evaluateins", keep_input=keep_input
    )
    return EvaluateIns(parameters=params, config=cfg)
262
+
263
+
264
def evaluateins_to_recorddict(evaluateins: EvaluateIns, keep_input: bool) -> RecordDict:
    """Construct a RecordDict from an EvaluateIns object (keys prefixed "evaluateins.")."""
    return _fit_or_evaluate_ins_to_recorddict(ins=evaluateins, keep_input=keep_input)
267
+
268
+
269
def recorddict_to_evaluateres(recorddict: RecordDict) -> EvaluateRes:
    """Derive EvaluateRes from a RecordDict object.

    Parameters
    ----------
    recorddict : RecordDict
        A RecordDict holding the "evaluateres.loss" and "evaluateres.num_examples"
        MetricRecords and the "evaluateres.metrics" and "evaluateres.status"
        ConfigRecords.

    Returns
    -------
    EvaluateRes
        The legacy evaluation result reconstructed from the record.
    """
    res_str = "evaluateres"

    # Narrow the loss to `float` (the legacy `EvaluateRes.loss` field is a float
    # loss value); narrowing it to `int` misinformed type checkers.
    loss = cast(float, recorddict.metric_records[f"{res_str}.loss"]["loss"])

    num_examples = cast(
        int, recorddict.metric_records[f"{res_str}.num_examples"]["num_examples"]
    )
    config_record = recorddict.config_records[f"{res_str}.metrics"]
    metrics = _check_mapping_from_recordscalartype_to_scalar(config_record)
    status = _extract_status_from_recorddict(res_str, recorddict)

    return EvaluateRes(
        status=status, loss=loss, num_examples=num_examples, metrics=metrics
    )
287
+
288
+
289
def evaluateres_to_recorddict(evaluateres: EvaluateRes) -> RecordDict:
    """Construct a RecordDict from an EvaluateRes object.

    Layout of the returned RecordDict:

    - ``metric_records["evaluateres.loss"]``: ``{"loss": ...}``
    - ``metric_records["evaluateres.num_examples"]``: ``{"num_examples": ...}``
    - ``config_records["evaluateres.metrics"]``: the result's metrics
    - ``config_records["evaluateres.status"]``: the result's status
    """
    prefix = "evaluateres"
    out = RecordDict()

    out.metric_records[f"{prefix}.loss"] = MetricRecord({"loss": evaluateres.loss})
    out.metric_records[f"{prefix}.num_examples"] = MetricRecord(
        {"num_examples": evaluateres.num_examples}
    )
    out.config_records[f"{prefix}.metrics"] = ConfigRecord(
        evaluateres.metrics  # type: ignore
    )

    return _embed_status_into_recorddict(prefix, evaluateres.status, out)
315
+
316
+
317
def recorddict_to_getparametersins(recorddict: RecordDict) -> GetParametersIns:
    """Derive GetParametersIns from the "getparametersins.config" ConfigRecord."""
    return GetParametersIns(
        config=_check_mapping_from_recordscalartype_to_scalar(
            recorddict.config_records["getparametersins.config"]
        )
    )
324
+
325
+
326
def getparametersins_to_recorddict(getparameters_ins: GetParametersIns) -> RecordDict:
    """Construct a RecordDict from a GetParametersIns object.

    The config is stored as a ConfigRecord under "getparametersins.config".
    """
    out = RecordDict()
    out.config_records["getparametersins.config"] = ConfigRecord(
        getparameters_ins.config  # type: ignore
    )
    return out
334
+
335
+
336
def getparametersres_to_recorddict(
    getparametersres: GetParametersRes, keep_input: bool
) -> RecordDict:
    """Construct a RecordDict from a GetParametersRes object.

    The parameters go into the "getparametersres.parameters" ArrayRecord and the
    status into the "getparametersres.status" ConfigRecord.
    """
    out = RecordDict()
    prefix = "getparametersres"

    out.array_records[f"{prefix}.parameters"] = parameters_to_arrayrecord(
        getparametersres.parameters, keep_input=keep_input
    )

    return _embed_status_into_recorddict(prefix, getparametersres.status, out)
353
+
354
+
355
def recorddict_to_getparametersres(
    recorddict: RecordDict, keep_input: bool
) -> GetParametersRes:
    """Derive GetParametersRes from a RecordDict object.

    Reads the "getparametersres.parameters" ArrayRecord and the
    "getparametersres.status" ConfigRecord. When `keep_input` is False, parameter
    entries are removed from the ArrayRecord as they are converted.
    """
    prefix = "getparametersres"
    array_record = recorddict.array_records[f"{prefix}.parameters"]
    parameters = arrayrecord_to_parameters(array_record, keep_input=keep_input)

    return GetParametersRes(
        status=_extract_status_from_recorddict(prefix, recorddict),
        parameters=parameters,
    )
366
+
367
+
368
def recorddict_to_getpropertiesins(recorddict: RecordDict) -> GetPropertiesIns:
    """Derive GetPropertiesIns from the "getpropertiesins.config" ConfigRecord."""
    return GetPropertiesIns(
        config=_check_mapping_from_recordscalartype_to_scalar(
            recorddict.config_records["getpropertiesins.config"]
        )
    )
375
+
376
+
377
def getpropertiesins_to_recorddict(getpropertiesins: GetPropertiesIns) -> RecordDict:
    """Construct a RecordDict from a GetPropertiesIns object.

    Parameters
    ----------
    getpropertiesins : GetPropertiesIns
        The instruction whose config is to be stored.

    Returns
    -------
    RecordDict
        A new RecordDict whose "getpropertiesins.config" ConfigRecord holds the
        instruction's config.
    """
    recorddict = RecordDict()
    recorddict.config_records["getpropertiesins.config"] = ConfigRecord(
        getpropertiesins.config,  # type: ignore
    )
    return recorddict
384
+
385
+
386
def recorddict_to_getpropertiesres(recorddict: RecordDict) -> GetPropertiesRes:
    """Derive GetPropertiesRes from a RecordDict object.

    Reads the "getpropertiesres.properties" and "getpropertiesres.status"
    ConfigRecords.
    """
    prefix = "getpropertiesres"
    properties = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[f"{prefix}.properties"]
    )

    return GetPropertiesRes(
        status=_extract_status_from_recorddict(prefix, recorddict=recorddict),
        properties=properties,
    )
396
+
397
+
398
def getpropertiesres_to_recorddict(getpropertiesres: GetPropertiesRes) -> RecordDict:
    """Construct a RecordDict from a GetPropertiesRes object.

    The properties go into the "getpropertiesres.properties" ConfigRecord and the
    status into the "getpropertiesres.status" ConfigRecord.
    """
    prefix = "getpropertiesres"
    out = RecordDict()

    out.config_records[f"{prefix}.properties"] = ConfigRecord(
        getpropertiesres.properties  # type: ignore
    )

    return _embed_status_into_recorddict(prefix, getpropertiesres.status, out)
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -166,15 +166,15 @@ class RetryInvoker:
166
166
 
167
167
  Examples
168
168
  --------
169
- Initialize a `RetryInvoker` with exponential backoff and invoke a function:
170
-
171
- >>> invoker = RetryInvoker(
172
- ... exponential, # Or use `lambda: exponential(3, 2)` to pass arguments
173
- ... grpc.RpcError,
174
- ... max_tries=3,
175
- ... max_time=None,
176
- ... )
177
- >>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1)
169
+ Initialize a `RetryInvoker` with exponential backoff and invoke a function::
170
+
171
+ invoker = RetryInvoker(
172
+ exponential, # Or use `lambda: exponential(3, 2)` to pass arguments
173
+ grpc.RpcError,
174
+ max_tries=3,
175
+ max_time=None,
176
+ )
177
+ invoker.invoke(my_func, arg1, arg2, kw1=kwarg1)
178
178
  """
179
179
 
180
180
  # pylint: disable-next=too-many-arguments
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,61 +15,83 @@
15
15
  """Shamir's secret sharing."""
16
16
 
17
17
 
18
- import pickle
18
+ import os
19
19
  from concurrent.futures import ThreadPoolExecutor
20
- from typing import cast
21
20
 
22
21
  from Crypto.Protocol.SecretSharing import Shamir
23
22
  from Crypto.Util.Padding import pad, unpad
24
23
 
25
24
 
26
25
  def create_shares(secret: bytes, threshold: int, num: int) -> list[bytes]:
27
- """Return list of shares (bytes)."""
26
+ """Return a list of shares (bytes).
27
+
28
+ Shares are created from the provided secret using Shamir's secret sharing.
29
+ """
30
+ # Shamir's secret sharing requires the secret to be a multiple of 16 bytes
31
+ # (AES block size). Pad the secret to the next multiple of 16 bytes.
28
32
  secret_padded = pad(secret, 16)
29
- secret_padded_chunk = [
30
- (threshold, num, secret_padded[i : i + 16])
31
- for i in range(0, len(secret_padded), 16)
32
- ]
33
- share_list: list[list[tuple[int, bytes]]] = [[] for _ in range(num)]
33
+ chunks = [secret_padded[i : i + 16] for i in range(0, len(secret_padded), 16)]
34
+
35
+ # The share list should contain shares of the secret, and each share consists of:
36
+ # <4 bytes of index><share of chunk1><share of chunk2>...<share of chunkN>
37
+ share_list: list[bytearray] = [bytearray() for _ in range(num)]
34
38
 
35
- with ThreadPoolExecutor(max_workers=10) as executor:
39
+ # Create shares for each chunk in parallel
40
+ max_workers = min(len(chunks), os.cpu_count() or 1)
41
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
36
42
  for chunk_shares in executor.map(
37
- lambda arg: _shamir_split(*arg), secret_padded_chunk
43
+ lambda chunk: _shamir_split(threshold, num, chunk), chunks
38
44
  ):
39
45
  for idx, share in chunk_shares:
40
- # Index in `chunk_shares` starts from 1
41
- share_list[idx - 1].append((idx, share))
46
+ # Initialize the share with the index if it is empty
47
+ if not share_list[idx - 1]:
48
+ share_list[idx - 1] += idx.to_bytes(4, "little", signed=False)
42
49
 
43
- return [pickle.dumps(shares) for shares in share_list]
50
+ # Append the share to the bytes
51
+ share_list[idx - 1] += share
52
+
53
+ return [bytes(share) for share in share_list]
44
54
 
45
55
 
46
56
  def _shamir_split(threshold: int, num: int, chunk: bytes) -> list[tuple[int, bytes]]:
57
+ """Create shares for a chunk using Shamir's secret sharing.
58
+
59
+ Each share is a tuple (index, share_bytes), where share_bytes is 16 bytes long.
60
+ """
47
61
  return Shamir.split(threshold, num, chunk, ssss=False)
48
62
 
49
63
 
50
- # Reconstructing secret with PyCryptodome
51
64
  def combine_shares(share_list: list[bytes]) -> bytes:
52
- """Reconstruct secret from shares."""
53
- unpickled_share_list: list[list[tuple[int, bytes]]] = [
54
- cast(list[tuple[int, bytes]], pickle.loads(share)) for share in share_list
55
- ]
65
+ """Reconstruct the secret from a list of shares."""
66
+ # Compute the number of chunks
67
+ # Each share contains 4 bytes of index and 16 bytes of share for each chunk
68
+ chunk_num = (len(share_list[0]) - 4) >> 4
56
69
 
57
- chunk_num = len(unpickled_share_list[0])
58
70
  secret_padded = bytearray(0)
59
- chunk_shares_list: list[list[tuple[int, bytes]]] = []
60
- for i in range(chunk_num):
61
- chunk_shares: list[tuple[int, bytes]] = []
62
- for share in unpickled_share_list:
63
- chunk_shares.append(share[i])
64
- chunk_shares_list.append(chunk_shares)
65
-
66
- with ThreadPoolExecutor(max_workers=10) as executor:
71
+ chunk_shares_list: list[list[tuple[int, bytes]]] = [[] for _ in range(chunk_num)]
72
+
73
+ # Split shares into chunks
74
+ for share in share_list:
75
+ # The first 4 bytes are the index
76
+ index = int.from_bytes(share[:4], "little", signed=False)
77
+ for i in range(chunk_num):
78
+ start = (i << 4) + 4
79
+ chunk_shares_list[i].append((index, share[start : start + 16]))
80
+
81
+ # Combine shares for each chunk in parallel
82
+ max_workers = min(chunk_num, os.cpu_count() or 1)
83
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
67
84
  for chunk in executor.map(_shamir_combine, chunk_shares_list):
68
85
  secret_padded += chunk
69
86
 
70
- secret = unpad(secret_padded, 16)
71
- return bytes(secret)
87
+ try:
88
+ secret = unpad(bytes(secret_padded), 16)
89
+ except ValueError:
90
+ # If unpadding fails, it means the shares are not valid
91
+ raise ValueError("Failed to combine shares") from None
92
+ return secret
72
93
 
73
94
 
74
95
  def _shamir_combine(shares: list[tuple[int, bytes]]) -> bytes:
96
+ """Reconstruct a chunk from shares using Shamir's secret sharing."""
75
97
  return Shamir.combine(shares, ssss=False)
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@ class Stage:
42
42
 
43
43
 
44
44
  class Key:
45
- """Keys for the configs in the ConfigsRecord."""
45
+ """Keys for the configs in the ConfigRecord."""
46
46
 
47
47
  STAGE = "stage"
48
48
  SAMPLE_NUMBER = "sample_num"
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.