flwr 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/cli/__init__.py +1 -1
- flwr/cli/app.py +21 -2
- flwr/cli/build.py +1 -1
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +53 -17
- flwr/cli/example.py +1 -1
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +1 -1
- flwr/cli/login/__init__.py +1 -1
- flwr/cli/login/login.py +12 -1
- flwr/cli/ls.py +1 -1
- flwr/cli/new/__init__.py +1 -1
- flwr/cli/new/new.py +4 -4
- flwr/cli/new/templates/__init__.py +1 -1
- flwr/cli/new/templates/app/__init__.py +1 -1
- flwr/cli/new/templates/app/code/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +5 -5
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -1
- flwr/cli/run/run.py +6 -10
- flwr/cli/stop.py +1 -1
- flwr/cli/utils.py +11 -12
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +58 -56
- flwr/client/client.py +1 -1
- flwr/client/client_app.py +231 -166
- flwr/client/clientapp/__init__.py +1 -1
- flwr/client/clientapp/app.py +3 -3
- flwr/client/clientapp/clientappio_servicer.py +1 -1
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +1 -1
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/__init__.py +1 -1
- flwr/client/grpc_client/connection.py +37 -34
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
- flwr/client/heartbeat.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +3 -3
- flwr/client/mod/centraldp_mods.py +8 -8
- flwr/client/mod/comms_mods.py +17 -23
- flwr/client/mod/localdp_mod.py +10 -10
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +32 -32
- flwr/client/mod/utils.py +1 -1
- flwr/client/nodestate/__init__.py +1 -1
- flwr/client/nodestate/in_memory_nodestate.py +1 -1
- flwr/client/nodestate/nodestate.py +1 -1
- flwr/client/nodestate/nodestate_factory.py +1 -1
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/run_info_store.py +3 -3
- flwr/client/supernode/__init__.py +1 -1
- flwr/client/supernode/app.py +1 -1
- flwr/client/typing.py +1 -1
- flwr/common/__init__.py +13 -5
- flwr/common/address.py +1 -1
- flwr/common/args.py +1 -1
- flwr/common/auth_plugin/__init__.py +1 -1
- flwr/common/auth_plugin/auth_plugin.py +1 -1
- flwr/common/config.py +5 -5
- flwr/common/constant.py +7 -7
- flwr/common/context.py +5 -5
- flwr/common/date.py +1 -1
- flwr/common/differential_privacy.py +1 -1
- flwr/common/differential_privacy_constants.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit.py +6 -6
- flwr/common/exit_handlers.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +3 -3
- flwr/common/message.py +344 -102
- flwr/common/object_ref.py +1 -1
- flwr/common/parameter.py +1 -1
- flwr/common/pyproject.py +1 -1
- flwr/common/record/__init__.py +9 -5
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +83 -37
- flwr/common/record/conversion_utils.py +2 -2
- flwr/common/record/{metricsrecord.py → metricrecord.py} +90 -44
- flwr/common/record/recorddict.py +337 -0
- flwr/common/record/typeddict.py +1 -1
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/retry_invoker.py +10 -10
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +52 -30
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +67 -72
- flwr/common/telemetry.py +2 -2
- flwr/common/typing.py +9 -9
- flwr/common/version.py +1 -1
- flwr/proto/__init__.py +1 -1
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +4 -2
- flwr/server/app.py +67 -12
- flwr/server/client_manager.py +1 -1
- flwr/server/client_proxy.py +1 -1
- flwr/server/compat/__init__.py +3 -3
- flwr/server/compat/app.py +12 -12
- flwr/server/compat/app_utils.py +17 -17
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/compat/legacy_context.py +1 -1
- flwr/server/criterion.py +1 -1
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/history.py +1 -1
- flwr/server/run_serverapp.py +5 -5
- flwr/server/server.py +1 -1
- flwr/server/server_app.py +98 -71
- flwr/server/server_config.py +1 -1
- flwr/server/serverapp/__init__.py +1 -1
- flwr/server/serverapp/app.py +11 -11
- flwr/server/serverapp_components.py +1 -1
- flwr/server/strategy/__init__.py +1 -1
- flwr/server/strategy/aggregate.py +1 -1
- flwr/server/strategy/bulyan.py +2 -2
- flwr/server/strategy/dp_adaptive_clipping.py +17 -17
- flwr/server/strategy/dp_fixed_clipping.py +17 -17
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedtrimmedavg.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +3 -2
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/strategy/strategy.py +1 -1
- flwr/server/superlink/__init__.py +1 -1
- flwr/server/superlink/ffs/__init__.py +1 -1
- flwr/server/superlink/ffs/disk_ffs.py +1 -1
- flwr/server/superlink/ffs/ffs.py +1 -1
- flwr/server/superlink/ffs/ffs_factory.py +1 -1
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -3
- flwr/server/superlink/fleet/vce/vce_api.py +2 -4
- flwr/server/superlink/linkstate/__init__.py +1 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -9
- flwr/server/superlink/linkstate/linkstate.py +5 -5
- flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +62 -28
- flwr/server/superlink/linkstate/utils.py +94 -28
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/__init__.py +1 -1
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +3 -3
- flwr/server/superlink/utils.py +1 -1
- flwr/server/typing.py +4 -4
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +5 -5
- flwr/server/workflow/__init__.py +1 -1
- flwr/server/workflow/constant.py +1 -1
- flwr/server/workflow/default_workflows.py +49 -58
- flwr/server/workflow/secure_aggregation/__init__.py +1 -1
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +49 -51
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +3 -3
- flwr/simulation/legacy_app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +5 -3
- flwr/simulation/ray_transport/ray_client_proxy.py +35 -33
- flwr/simulation/ray_transport/utils.py +1 -1
- flwr/simulation/run_simulation.py +17 -17
- flwr/simulation/simulationio_connection.py +1 -1
- flwr/superexec/__init__.py +1 -1
- flwr/superexec/app.py +1 -1
- flwr/superexec/deployment.py +5 -5
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +11 -5
- flwr/superexec/exec_servicer.py +3 -3
- flwr/superexec/exec_user_auth_interceptor.py +19 -3
- flwr/superexec/executor.py +4 -4
- flwr/superexec/simulation.py +4 -4
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/METADATA +3 -3
- flwr-1.18.0.dist-info/RECORD +332 -0
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr-1.16.0.dist-info/LICENSE +0 -202
- flwr-1.16.0.dist-info/RECORD +0 -331
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""RecordDict utilities."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
from collections.abc import Mapping
|
|
20
|
+
from typing import Union, cast, get_args
|
|
21
|
+
|
|
22
|
+
from . import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
|
23
|
+
from .typing import (
|
|
24
|
+
Code,
|
|
25
|
+
ConfigRecordValues,
|
|
26
|
+
EvaluateIns,
|
|
27
|
+
EvaluateRes,
|
|
28
|
+
FitIns,
|
|
29
|
+
FitRes,
|
|
30
|
+
GetParametersIns,
|
|
31
|
+
GetParametersRes,
|
|
32
|
+
GetPropertiesIns,
|
|
33
|
+
GetPropertiesRes,
|
|
34
|
+
MetricRecordValues,
|
|
35
|
+
Parameters,
|
|
36
|
+
Scalar,
|
|
37
|
+
Status,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Key of the placeholder Array that `parameters_to_arrayrecord` stores when the
# legacy Parameters hold no tensors. Its data is empty (b""), but its `stype`
# carries `Parameters.tensor_type`, which `arrayrecord_to_parameters` recovers
# while skipping the placeholder's data.
EMPTY_TENSOR_KEY = "_empty"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def arrayrecord_to_parameters(record: ArrayRecord, keep_input: bool) -> Parameters:
    """Convert an ArrayRecord to legacy Parameters.

    Warnings
    --------
    Because `Array`s in `ArrayRecord` encode more information of the
    array-like or tensor-like data (e.g., their datatype, shape) than `Parameters`,
    it might not be possible to reconstruct such data structures from `Parameters`
    objects alone. Additional information or metadata must be provided from
    elsewhere.

    Parameters
    ----------
    record : ArrayRecord
        The record to be converted into Parameters.
    keep_input : bool
        If False, each entry is deleted from `record` immediately after it has
        been processed, so `record` ends up empty. If True, `record` is left
        unchanged.

    Returns
    -------
    parameters : Parameters
        The parameters in the legacy format Parameters. Its `tensor_type` is the
        `stype` of the first entry with a non-empty `stype`, including the
        placeholder entry stored under `EMPTY_TENSOR_KEY` (whose data is skipped).
    """
    parameters = Parameters(tensors=[], tensor_type="")

    # Iterate over a snapshot of the keys because entries may be deleted below
    for key in list(record.keys()):
        array = record[key]

        # The placeholder entry carries no tensor data, only the tensor type
        if key != EMPTY_TENSOR_KEY:
            parameters.tensors.append(array.data)

        if not parameters.tensor_type:
            # Taken from the first array in the record. Recall the warning in the
            # docstring of this function.
            parameters.tensor_type = array.stype

        if not keep_input:
            del record[key]

    return parameters
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def parameters_to_arrayrecord(parameters: Parameters, keep_input: bool) -> ArrayRecord:
    """Convert legacy Parameters into a single ArrayRecord.

    Because there is no concept of names in the legacy Parameters, arbitrary keys will
    be used when constructing the ArrayRecord. Similarly, the shape and data type
    won't be recorded in the Array objects.

    Parameters
    ----------
    parameters : Parameters
        Parameters object to be represented as an ArrayRecord.
    keep_input : bool
        If False, the tensors are removed from `parameters.tensors` (leaving the
        list empty), and `keep_input=False` is forwarded to the ArrayRecord
        constructor. If True, `parameters` is left unchanged.

    Returns
    -------
    ArrayRecord
        The ArrayRecord containing the provided parameters, keyed by the
        stringified tensor position ("0", "1", ...). If `parameters` holds no
        tensors, a single empty placeholder Array is stored under
        `EMPTY_TENSOR_KEY` so that `tensor_type` is preserved.
    """
    tensor_type = parameters.tensor_type
    tensors = parameters.tensors

    ordered_dict: OrderedDict[str, Array] = OrderedDict(
        (str(idx), Array(data=tensor, dtype="", stype=tensor_type, shape=[]))
        for idx, tensor in enumerate(tensors)
    )

    if not keep_input:
        # Empty the input list in a single O(n) step. Repeated `pop(0)` shifts
        # the remaining elements on every call, which is O(n^2) overall.
        tensors.clear()

    if not ordered_dict:
        ordered_dict[EMPTY_TENSOR_KEY] = Array(
            data=b"", dtype="", stype=tensor_type, shape=[]
        )
    return ArrayRecord(ordered_dict, keep_input=keep_input)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _check_mapping_from_recordscalartype_to_scalar(
    record_data: Mapping[str, Union[ConfigRecordValues, MetricRecordValues]]
) -> dict[str, Scalar]:
    """Check that mapping `common.*RecordValues` into `common.Scalar` is possible.

    Parameters
    ----------
    record_data : Mapping[str, Union[ConfigRecordValues, MetricRecordValues]]
        The record contents to validate (e.g., a ConfigRecord).

    Returns
    -------
    dict[str, Scalar]
        `record_data` itself, cast to `dict[str, Scalar]`. No copy is made.

    Raises
    ------
    TypeError
        If any value is not an instance of one of the types in `common.Scalar`.
    """
    # Compute the allowed types once rather than on every iteration
    scalar_types = get_args(Scalar)
    for value in record_data.values():
        if not isinstance(value, scalar_types):
            raise TypeError(
                "There is not a 1:1 mapping between `common.Scalar` types and those "
                "supported in `common.ConfigRecordValues` or "
                "`common.MetricRecordValues`. Consider casting your values to a type "
                "supported by the `common.RecordDict` infrastructure. "
                f"You used type: {type(value)}"
            )
    return cast(dict[str, Scalar], record_data)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _recorddict_to_fit_or_evaluate_ins_components(
    recorddict: RecordDict,
    ins_str: str,
    keep_input: bool,
) -> tuple[Parameters, dict[str, Scalar]]:
    """Extract the Parameters and config shared by FitIns and EvaluateIns.

    Reads the ArrayRecord stored under "<ins_str>.parameters" and the
    ConfigRecord stored under "<ins_str>.config" in `recorddict`.
    """
    params_key = f"{ins_str}.parameters"
    config_key = f"{ins_str}.config"

    parameters = arrayrecord_to_parameters(
        recorddict.array_records[params_key], keep_input=keep_input
    )
    config = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[config_key]
    )
    return parameters, config
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _fit_or_evaluate_ins_to_recorddict(
    ins: Union[FitIns, EvaluateIns], keep_input: bool
) -> RecordDict:
    """Build a RecordDict holding the parameters and config of a Fit/EvaluateIns.

    Entries are stored under the "fitins." prefix for FitIns and under the
    "evaluateins." prefix otherwise.
    """
    prefix = "evaluateins"
    if isinstance(ins, FitIns):
        prefix = "fitins"

    out = RecordDict()
    out.array_records[f"{prefix}.parameters"] = parameters_to_arrayrecord(
        ins.parameters, keep_input
    )
    out.config_records[f"{prefix}.config"] = ConfigRecord(ins.config)  # type: ignore
    return out
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _embed_status_into_recorddict(
    res_str: str, status: Status, recorddict: RecordDict
) -> RecordDict:
    """Store `status` in `recorddict` under "<res_str>.status" and return it.

    The status goes into a ConfigRecord rather than a MetricRecord because its
    `message` is a `str`, which MetricRecords do not support. The code is
    stored as a plain int.
    """
    record = ConfigRecord(
        {"code": int(status.code.value), "message": status.message}
    )
    recorddict.config_records[f"{res_str}.status"] = record
    return recorddict
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _extract_status_from_recorddict(res_str: str, recorddict: RecordDict) -> Status:
    """Rebuild a Status from the ConfigRecord stored under "<res_str>.status"."""
    status_record = recorddict.config_records[f"{res_str}.status"]
    return Status(
        code=Code(cast(int, status_record["code"])),
        message=str(status_record["message"]),
    )
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def recorddict_to_fitins(recorddict: RecordDict, keep_input: bool) -> FitIns:
    """Derive FitIns from the "fitins.*" entries of a RecordDict."""
    params, cfg = _recorddict_to_fit_or_evaluate_ins_components(
        recorddict, ins_str="fitins", keep_input=keep_input
    )
    return FitIns(parameters=params, config=cfg)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def fitins_to_recorddict(fitins: FitIns, keep_input: bool) -> RecordDict:
    """Construct a RecordDict from a FitIns object.

    Delegates to the shared Fit/Evaluate helper, which uses the "fitins."
    prefix for FitIns inputs.
    """
    return _fit_or_evaluate_ins_to_recorddict(ins=fitins, keep_input=keep_input)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def recorddict_to_fitres(recorddict: RecordDict, keep_input: bool) -> FitRes:
    """Derive FitRes from a RecordDict object.

    Reads the "fitres.parameters" ArrayRecord, the "fitres.num_examples"
    MetricRecord, and the "fitres.metrics" and "fitres.status" ConfigRecords.
    """
    prefix = "fitres"

    array_record = recorddict.array_records[f"{prefix}.parameters"]
    parameters = arrayrecord_to_parameters(array_record, keep_input=keep_input)

    num_examples_record = recorddict.metric_records[f"{prefix}.num_examples"]
    num_examples = cast(int, num_examples_record["num_examples"])

    metrics = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[f"{prefix}.metrics"]
    )
    status = _extract_status_from_recorddict(prefix, recorddict)

    return FitRes(
        status=status,
        parameters=parameters,
        num_examples=num_examples,
        metrics=metrics,
    )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def fitres_to_recorddict(fitres: FitRes, keep_input: bool) -> RecordDict:
    """Construct a RecordDict from a FitRes object.

    All entries use the "fitres." prefix: the metrics go into a ConfigRecord,
    the example count into a MetricRecord, the parameters into an
    ArrayRecord, and the status into a ConfigRecord.
    """
    prefix = "fitres"
    out = RecordDict()

    out.config_records[f"{prefix}.metrics"] = ConfigRecord(
        fitres.metrics  # type: ignore
    )
    out.metric_records[f"{prefix}.num_examples"] = MetricRecord(
        {"num_examples": fitres.num_examples}
    )
    out.array_records[f"{prefix}.parameters"] = parameters_to_arrayrecord(
        fitres.parameters, keep_input
    )

    return _embed_status_into_recorddict(prefix, fitres.status, out)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def recorddict_to_evaluateins(recorddict: RecordDict, keep_input: bool) -> EvaluateIns:
    """Derive EvaluateIns from the "evaluateins.*" entries of a RecordDict."""
    params, cfg = _recorddict_to_fit_or_evaluate_ins_components(
        recorddict, ins_str="evaluateins", keep_input=keep_input
    )
    return EvaluateIns(parameters=params, config=cfg)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def evaluateins_to_recorddict(evaluateins: EvaluateIns, keep_input: bool) -> RecordDict:
    """Construct a RecordDict from an EvaluateIns object.

    Delegates to the shared Fit/Evaluate helper, which uses the
    "evaluateins." prefix for non-FitIns inputs.
    """
    return _fit_or_evaluate_ins_to_recorddict(ins=evaluateins, keep_input=keep_input)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def recorddict_to_evaluateres(recorddict: RecordDict) -> EvaluateRes:
    """Derive EvaluateRes from a RecordDict object.

    Reads the "evaluateres.loss" and "evaluateres.num_examples" MetricRecords
    and the "evaluateres.metrics" and "evaluateres.status" ConfigRecords.

    Raises
    ------
    TypeError
        If a value in the "evaluateres.metrics" record is not a `common.Scalar`.
    """
    res_str = "evaluateres"

    # The loss is a floating-point value (`EvaluateRes.loss: float`), so cast it
    # as such; casting to `int` misinforms type checkers.
    loss = cast(float, recorddict.metric_records[f"{res_str}.loss"]["loss"])

    num_examples = cast(
        int, recorddict.metric_records[f"{res_str}.num_examples"]["num_examples"]
    )
    metrics = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[f"{res_str}.metrics"]
    )
    status = _extract_status_from_recorddict(res_str, recorddict)

    return EvaluateRes(
        status=status, loss=loss, num_examples=num_examples, metrics=metrics
    )
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def evaluateres_to_recorddict(evaluateres: EvaluateRes) -> RecordDict:
    """Construct a RecordDict from an EvaluateRes object.

    All entries use the "evaluateres." prefix: `loss` and `num_examples` each
    go into their own MetricRecord, while `metrics` and `status` go into
    ConfigRecords.
    """
    prefix = "evaluateres"
    out = RecordDict()

    out.metric_records[f"{prefix}.loss"] = MetricRecord({"loss": evaluateres.loss})
    out.metric_records[f"{prefix}.num_examples"] = MetricRecord(
        {"num_examples": evaluateres.num_examples}
    )
    out.config_records[f"{prefix}.metrics"] = ConfigRecord(
        evaluateres.metrics,  # type: ignore
    )

    return _embed_status_into_recorddict(prefix, evaluateres.status, out)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def recorddict_to_getparametersins(recorddict: RecordDict) -> GetParametersIns:
    """Derive GetParametersIns from the "getparametersins.config" ConfigRecord."""
    return GetParametersIns(
        config=_check_mapping_from_recordscalartype_to_scalar(
            recorddict.config_records["getparametersins.config"]
        )
    )
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def getparametersins_to_recorddict(getparameters_ins: GetParametersIns) -> RecordDict:
    """Construct a RecordDict holding a GetParametersIns config.

    The config is stored as a ConfigRecord under "getparametersins.config".
    """
    out = RecordDict()
    config_record = ConfigRecord(getparameters_ins.config)  # type: ignore
    out.config_records["getparametersins.config"] = config_record
    return out
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def getparametersres_to_recorddict(
    getparametersres: GetParametersRes, keep_input: bool
) -> RecordDict:
    """Construct a RecordDict from a GetParametersRes object.

    The parameters are stored as an ArrayRecord and the status as a
    ConfigRecord, both under the "getparametersres." prefix.
    """
    prefix = "getparametersres"
    out = RecordDict()
    out.array_records[f"{prefix}.parameters"] = parameters_to_arrayrecord(
        getparametersres.parameters, keep_input=keep_input
    )
    return _embed_status_into_recorddict(prefix, getparametersres.status, out)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def recorddict_to_getparametersres(
    recorddict: RecordDict, keep_input: bool
) -> GetParametersRes:
    """Derive GetParametersRes from a RecordDict object.

    Reads the "getparametersres.parameters" ArrayRecord and the
    "getparametersres.status" ConfigRecord.
    """
    prefix = "getparametersres"
    array_record = recorddict.array_records[f"{prefix}.parameters"]
    parameters = arrayrecord_to_parameters(array_record, keep_input=keep_input)
    return GetParametersRes(
        status=_extract_status_from_recorddict(prefix, recorddict),
        parameters=parameters,
    )
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def recorddict_to_getpropertiesins(recorddict: RecordDict) -> GetPropertiesIns:
    """Derive GetPropertiesIns from the "getpropertiesins.config" ConfigRecord."""
    return GetPropertiesIns(
        config=_check_mapping_from_recordscalartype_to_scalar(
            recorddict.config_records["getpropertiesins.config"]
        )
    )
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def getpropertiesins_to_recorddict(getpropertiesins: GetPropertiesIns) -> RecordDict:
    """Construct a RecordDict from a GetPropertiesIns object.

    The config is stored as a ConfigRecord under "getpropertiesins.config".
    """
    recorddict = RecordDict()
    recorddict.config_records["getpropertiesins.config"] = ConfigRecord(
        getpropertiesins.config,  # type: ignore
    )
    return recorddict
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def recorddict_to_getpropertiesres(recorddict: RecordDict) -> GetPropertiesRes:
    """Derive GetPropertiesRes from a RecordDict object.

    Reads the "getpropertiesres.properties" and "getpropertiesres.status"
    ConfigRecords.
    """
    prefix = "getpropertiesres"
    properties = _check_mapping_from_recordscalartype_to_scalar(
        recorddict.config_records[f"{prefix}.properties"]
    )
    return GetPropertiesRes(
        status=_extract_status_from_recorddict(prefix, recorddict=recorddict),
        properties=properties,
    )
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def getpropertiesres_to_recorddict(getpropertiesres: GetPropertiesRes) -> RecordDict:
    """Construct a RecordDict from a GetPropertiesRes object.

    The properties and the status are each stored as a ConfigRecord under the
    "getpropertiesres." prefix.
    """
    prefix = "getpropertiesres"
    out = RecordDict()
    out.config_records[f"{prefix}.properties"] = ConfigRecord(
        getpropertiesres.properties,  # type: ignore
    )
    return _embed_status_into_recorddict(prefix, getpropertiesres.status, out)
|
flwr/common/retry_invoker.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -166,15 +166,15 @@ class RetryInvoker:
|
|
|
166
166
|
|
|
167
167
|
Examples
|
|
168
168
|
--------
|
|
169
|
-
Initialize a `RetryInvoker` with exponential backoff and invoke a function
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
169
|
+
Initialize a `RetryInvoker` with exponential backoff and invoke a function::
|
|
170
|
+
|
|
171
|
+
invoker = RetryInvoker(
|
|
172
|
+
exponential, # Or use `lambda: exponential(3, 2)` to pass arguments
|
|
173
|
+
grpc.RpcError,
|
|
174
|
+
max_tries=3,
|
|
175
|
+
max_time=None,
|
|
176
|
+
)
|
|
177
|
+
invoker.invoke(my_func, arg1, arg2, kw1=kwarg1)
|
|
178
178
|
"""
|
|
179
179
|
|
|
180
180
|
# pylint: disable-next=too-many-arguments
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,61 +15,83 @@
|
|
|
15
15
|
"""Shamir's secret sharing."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import
|
|
18
|
+
import os
|
|
19
19
|
from concurrent.futures import ThreadPoolExecutor
|
|
20
|
-
from typing import cast
|
|
21
20
|
|
|
22
21
|
from Crypto.Protocol.SecretSharing import Shamir
|
|
23
22
|
from Crypto.Util.Padding import pad, unpad
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def create_shares(secret: bytes, threshold: int, num: int) -> list[bytes]:
    """Split `secret` into `num` shares using Shamir's secret sharing.

    The secret is padded to a multiple of 16 bytes and cut into 16-byte
    chunks, each of which is shared independently. Every returned share is
    laid out as::

        <4-byte little-endian share index><share of chunk 1>...<share of chunk N>

    Parameters
    ----------
    secret : bytes
        The secret to split.
    threshold : int
        Minimum number of shares needed for reconstruction (forwarded to
        `Shamir.split`).
    num : int
        Total number of shares to produce.

    Returns
    -------
    list[bytes]
        `num` shares; element `i` holds the share whose index is `i + 1`.
    """
    block = 16  # Shamir's secret sharing here works on 16-byte secrets
    padded = pad(secret, block)
    chunks = [padded[pos : pos + block] for pos in range(0, len(padded), block)]

    buffers = [bytearray() for _ in range(num)]

    def split_chunk(chunk: bytes) -> list[tuple[int, bytes]]:
        return _shamir_split(threshold, num, chunk)

    workers = min(len(chunks), os.cpu_count() or 1)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # `map` yields results in chunk order, so pieces are appended in order
        for per_chunk in pool.map(split_chunk, chunks):
            for share_idx, piece in per_chunk:
                buf = buffers[share_idx - 1]
                if not buf:
                    # First piece for this share: prefix it with its index
                    buf += share_idx.to_bytes(4, "little", signed=False)
                buf += piece

    return [bytes(buf) for buf in buffers]
|
|
44
54
|
|
|
45
55
|
|
|
46
56
|
def _shamir_split(threshold: int, num: int, chunk: bytes) -> list[tuple[int, bytes]]:
    """Share a single chunk with PyCryptodome's `Shamir.split`.

    Returns the `(index, share_bytes)` pairs produced for `num` participants,
    with `ssss=False` passed explicitly.
    """
    shares = Shamir.split(threshold, num, chunk, ssss=False)
    return shares
|
|
48
62
|
|
|
49
63
|
|
|
50
|
-
# Reconstructing secret with PyCryptodome
|
|
51
64
|
def combine_shares(share_list: list[bytes]) -> bytes:
    """Reconstruct the secret from a list of shares.

    Each share must use the layout produced by `create_shares`: a 4-byte
    little-endian index followed by one 16-byte share per chunk.

    Parameters
    ----------
    share_list : list[bytes]
        The shares to combine. All shares must have the same length.

    Returns
    -------
    bytes
        The unpadded secret.

    Raises
    ------
    ValueError
        If `share_list` is empty; if the shares are malformed (shorter than one
        index plus one chunk, not 4 + 16*k bytes long, or of differing lengths);
        or if the reconstructed data cannot be unpadded.
    """
    # Validate up front. Otherwise an empty list raises IndexError, a too-short
    # share makes ThreadPoolExecutor fail on max_workers=0, and uneven lengths
    # are silently truncated.
    if not share_list:
        raise ValueError("Failed to combine shares: no shares provided")
    share_len = len(share_list[0])
    if share_len < 20 or (share_len - 4) % 16 != 0:
        raise ValueError("Failed to combine shares: malformed share length")
    if any(len(share) != share_len for share in share_list):
        raise ValueError("Failed to combine shares: shares differ in length")

    # Each share contains 4 bytes of index and 16 bytes of share for each chunk
    chunk_num = (share_len - 4) >> 4

    # Regroup: chunk_shares_list[i] collects the (index, share) pairs of chunk i
    chunk_shares_list: list[list[tuple[int, bytes]]] = [[] for _ in range(chunk_num)]
    for share in share_list:
        # The first 4 bytes are the index
        index = int.from_bytes(share[:4], "little", signed=False)
        for i in range(chunk_num):
            start = (i << 4) + 4
            chunk_shares_list[i].append((index, share[start : start + 16]))

    # Combine shares for each chunk in parallel; `map` preserves chunk order
    secret_padded = bytearray(0)
    max_workers = min(chunk_num, os.cpu_count() or 1)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for chunk in executor.map(_shamir_combine, chunk_shares_list):
            secret_padded += chunk

    try:
        secret = unpad(bytes(secret_padded), 16)
    except ValueError:
        # If unpadding fails, it means the shares are not valid
        raise ValueError("Failed to combine shares") from None
    return secret
|
|
72
93
|
|
|
73
94
|
|
|
74
95
|
def _shamir_combine(shares: list[tuple[int, bytes]]) -> bytes:
    """Recover one chunk from its `(index, share_bytes)` pairs.

    Thin wrapper over PyCryptodome's `Shamir.combine` with `ssss=False`.
    """
    chunk = Shamir.combine(shares, ssss=False)
    return chunk
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -42,7 +42,7 @@ class Stage:
|
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
class Key:
|
|
45
|
-
"""Keys for the configs in the
|
|
45
|
+
"""Keys for the configs in the ConfigRecord."""
|
|
46
46
|
|
|
47
47
|
STAGE = "stage"
|
|
48
48
|
SAMPLE_NUMBER = "sample_num"
|