flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +34 -1
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +94 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
- flwr/cli/build.py +166 -53
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +54 -11
- flwr/cli/login/login.py +41 -27
- flwr/cli/ls.py +177 -133
- flwr/cli/new/new.py +175 -40
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +12 -7
- flwr/cli/run/run.py +82 -31
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +27 -9
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +268 -0
- flwr/cli/supernode/register.py +190 -0
- flwr/cli/supernode/unregister.py +140 -0
- flwr/cli/utils.py +464 -81
- flwr/client/__init__.py +2 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +12 -15
- flwr/client/grpc_rere_client/connection.py +68 -41
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +94 -51
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +1 -2
- flwr/{client → clientapp}/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/{client/clientapp → clientapp}/utils.py +4 -4
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +56 -13
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +39 -10
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +48 -31
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +6 -6
- flwr/common/record/arrayrecord.py +18 -21
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +9 -6
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +59 -43
- flwr/compat/client/app.py +39 -38
- flwr/compat/client/grpc_client/connection.py +13 -13
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +72 -40
- flwr/proto/control_pb2.pyi +319 -87
- flwr/proto/control_pb2_grpc.py +339 -28
- flwr/proto/control_pb2_grpc.pyi +209 -37
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +24 -10
- flwr/proto/fab_pb2.pyi +68 -20
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +45 -27
- flwr/proto/fleet_pb2.pyi +186 -70
- flwr/proto/fleet_pb2_grpc.py +277 -66
- flwr/proto/fleet_pb2_grpc.pyi +201 -55
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +16 -4
- flwr/proto/node_pb2.pyi +77 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +173 -127
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
- flwr/server/superlink/fleet/vce/vce_api.py +32 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
- flwr/server/superlink/linkstate/linkstate.py +161 -62
- flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
- flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
- flwr/server/superlink/linkstate/utils.py +9 -60
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +11 -12
- flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
- flwr/simulation/run_simulation.py +44 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +52 -0
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -6
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +27 -8
- flwr/supercore/object_store/sqlite_object_store.py +253 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
- flwr/supercore/sqlite_mixin.py +159 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +88 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +18 -17
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +239 -63
- flwr/supernode/cli/flower_supernode.py +74 -26
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +43 -24
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +175 -51
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.22.0.dist-info/RECORD +0 -428
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from os import urandom
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
|
22
21
|
from flwr.common.constant import (
|
|
@@ -33,6 +32,7 @@ from flwr.common.typing import RunStatus
|
|
|
33
32
|
# pylint: disable=E0611
|
|
34
33
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
35
34
|
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
35
|
+
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
36
36
|
|
|
37
37
|
# pylint: enable=E0611
|
|
38
38
|
VALID_RUN_STATUS_TRANSITIONS = {
|
|
@@ -50,7 +50,8 @@ VALID_RUN_SUB_STATUSES = {
|
|
|
50
50
|
}
|
|
51
51
|
MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
52
52
|
"Error: Message Unavailable - The requested message could not be found in the "
|
|
53
|
-
"database. It may have expired due to its TTL
|
|
53
|
+
"database. It may have expired due to its TTL, been deleted because the "
|
|
54
|
+
"destination SuperNode was removed from the federation, or never existed."
|
|
54
55
|
)
|
|
55
56
|
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
56
57
|
"Error: Reply Message Unavailable - The reply message has expired."
|
|
@@ -62,7 +63,7 @@ NODE_UNAVAILABLE_ERROR_REASON = (
|
|
|
62
63
|
|
|
63
64
|
|
|
64
65
|
def generate_rand_int_from_bytes(
|
|
65
|
-
num_bytes: int, exclude:
|
|
66
|
+
num_bytes: int, exclude: list[int] | None = None
|
|
66
67
|
) -> int:
|
|
67
68
|
"""Generate a random unsigned integer from `num_bytes` bytes.
|
|
68
69
|
|
|
@@ -76,58 +77,6 @@ def generate_rand_int_from_bytes(
|
|
|
76
77
|
return num
|
|
77
78
|
|
|
78
79
|
|
|
79
|
-
def convert_uint64_to_sint64(u: int) -> int:
|
|
80
|
-
"""Convert a uint64 value to a sint64 value with the same bit sequence.
|
|
81
|
-
|
|
82
|
-
Parameters
|
|
83
|
-
----------
|
|
84
|
-
u : int
|
|
85
|
-
The unsigned 64-bit integer to convert.
|
|
86
|
-
|
|
87
|
-
Returns
|
|
88
|
-
-------
|
|
89
|
-
int
|
|
90
|
-
The signed 64-bit integer equivalent.
|
|
91
|
-
|
|
92
|
-
The signed 64-bit integer will have the same bit pattern as the
|
|
93
|
-
unsigned 64-bit integer but may have a different decimal value.
|
|
94
|
-
|
|
95
|
-
For numbers within the range [0, `sint64` max value], the decimal
|
|
96
|
-
value remains the same. However, for numbers greater than the `sint64`
|
|
97
|
-
max value, the decimal value will differ due to the wraparound caused
|
|
98
|
-
by the sign bit.
|
|
99
|
-
"""
|
|
100
|
-
if u >= (1 << 63):
|
|
101
|
-
return u - (1 << 64)
|
|
102
|
-
return u
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def convert_sint64_to_uint64(s: int) -> int:
|
|
106
|
-
"""Convert a sint64 value to a uint64 value with the same bit sequence.
|
|
107
|
-
|
|
108
|
-
Parameters
|
|
109
|
-
----------
|
|
110
|
-
s : int
|
|
111
|
-
The signed 64-bit integer to convert.
|
|
112
|
-
|
|
113
|
-
Returns
|
|
114
|
-
-------
|
|
115
|
-
int
|
|
116
|
-
The unsigned 64-bit integer equivalent.
|
|
117
|
-
|
|
118
|
-
The unsigned 64-bit integer will have the same bit pattern as the
|
|
119
|
-
signed 64-bit integer but may have a different decimal value.
|
|
120
|
-
|
|
121
|
-
For negative `sint64` values, the conversion adds 2^64 to the
|
|
122
|
-
signed value to obtain the equivalent `uint64` value. For non-negative
|
|
123
|
-
`sint64` values, the decimal value remains unchanged in the `uint64`
|
|
124
|
-
representation.
|
|
125
|
-
"""
|
|
126
|
-
if s < 0:
|
|
127
|
-
return s + (1 << 64)
|
|
128
|
-
return s
|
|
129
|
-
|
|
130
|
-
|
|
131
80
|
def convert_uint64_values_in_dict_to_sint64(
|
|
132
81
|
data_dict: dict[str, int], keys: list[str]
|
|
133
82
|
) -> None:
|
|
@@ -142,7 +91,7 @@ def convert_uint64_values_in_dict_to_sint64(
|
|
|
142
91
|
"""
|
|
143
92
|
for key in keys:
|
|
144
93
|
if key in data_dict:
|
|
145
|
-
data_dict[key] =
|
|
94
|
+
data_dict[key] = uint64_to_int64(data_dict[key])
|
|
146
95
|
|
|
147
96
|
|
|
148
97
|
def convert_sint64_values_in_dict_to_uint64(
|
|
@@ -159,7 +108,7 @@ def convert_sint64_values_in_dict_to_uint64(
|
|
|
159
108
|
"""
|
|
160
109
|
for key in keys:
|
|
161
110
|
if key in data_dict:
|
|
162
|
-
data_dict[key] =
|
|
111
|
+
data_dict[key] = int64_to_uint64(data_dict[key])
|
|
163
112
|
|
|
164
113
|
|
|
165
114
|
def context_to_bytes(context: Context) -> bytes:
|
|
@@ -308,7 +257,7 @@ def message_ttl_has_expired(message_metadata: Metadata, current_time: float) ->
|
|
|
308
257
|
def verify_message_ids(
|
|
309
258
|
inquired_message_ids: set[str],
|
|
310
259
|
found_message_ins_dict: dict[str, Message],
|
|
311
|
-
current_time:
|
|
260
|
+
current_time: float | None = None,
|
|
312
261
|
update_set: bool = True,
|
|
313
262
|
) -> dict[str, Message]:
|
|
314
263
|
"""Verify found Messages and generate error Messages for invalid ones.
|
|
@@ -351,7 +300,7 @@ def verify_found_message_replies(
|
|
|
351
300
|
inquired_message_ids: set[str],
|
|
352
301
|
found_message_ins_dict: dict[str, Message],
|
|
353
302
|
found_message_res_list: list[Message],
|
|
354
|
-
current_time:
|
|
303
|
+
current_time: float | None = None,
|
|
355
304
|
update_set: bool = True,
|
|
356
305
|
) -> dict[str, Message]:
|
|
357
306
|
"""Verify found Message replies and generate error Message for invalid ones.
|
|
@@ -396,7 +345,7 @@ def check_node_availability_for_in_message(
|
|
|
396
345
|
inquired_in_message_ids: set[str],
|
|
397
346
|
found_in_message_dict: dict[str, Message],
|
|
398
347
|
node_id_to_online_until: dict[int, float],
|
|
399
|
-
current_time:
|
|
348
|
+
current_time: float | None = None,
|
|
400
349
|
update_set: bool = True,
|
|
401
350
|
) -> dict[str, Message]:
|
|
402
351
|
"""Check node availability for given Message and generate error reply Message if
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import INFO
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
import grpc
|
|
22
21
|
|
|
@@ -38,7 +37,7 @@ def run_serverappio_api_grpc(
|
|
|
38
37
|
state_factory: LinkStateFactory,
|
|
39
38
|
ffs_factory: FfsFactory,
|
|
40
39
|
objectstore_factory: ObjectStoreFactory,
|
|
41
|
-
certificates:
|
|
40
|
+
certificates: tuple[bytes, bytes, bytes] | None,
|
|
42
41
|
) -> grpc.Server:
|
|
43
42
|
"""Run ServerAppIo API (gRPC, request-response)."""
|
|
44
43
|
# Create ServerAppIo API gRPC server
|
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
from logging import DEBUG, ERROR, INFO
|
|
20
|
-
from typing import Optional
|
|
21
20
|
|
|
22
21
|
import grpc
|
|
23
22
|
|
|
@@ -91,7 +90,6 @@ from flwr.server.superlink.utils import abort_if
|
|
|
91
90
|
from flwr.server.utils.validator import validate_message
|
|
92
91
|
from flwr.supercore.ffs import Ffs, FfsFactory
|
|
93
92
|
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
|
|
94
|
-
from flwr.supercore.object_store.utils import store_mapping_and_register_objects
|
|
95
93
|
|
|
96
94
|
|
|
97
95
|
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
@@ -141,6 +139,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
141
139
|
# Attempt to create a token for the provided run ID
|
|
142
140
|
token = state.create_token(request.run_id)
|
|
143
141
|
|
|
142
|
+
# Transition the run to STARTING if token creation was successful
|
|
143
|
+
if token:
|
|
144
|
+
state.update_run_status(
|
|
145
|
+
run_id=request.run_id,
|
|
146
|
+
new_status=RunStatus(Status.STARTING, "", ""),
|
|
147
|
+
)
|
|
148
|
+
|
|
144
149
|
# Return the token
|
|
145
150
|
return RequestTokenResponse(token=token or "")
|
|
146
151
|
|
|
@@ -192,8 +197,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
192
197
|
request_name="PushMessages",
|
|
193
198
|
detail="`messages_list` must not be empty",
|
|
194
199
|
)
|
|
195
|
-
message_ids: list[
|
|
196
|
-
|
|
200
|
+
message_ids: list[str | None] = []
|
|
201
|
+
objects_to_push: set[str] = set()
|
|
202
|
+
for message_proto, object_tree in zip(
|
|
203
|
+
request.messages_list, request.message_object_trees, strict=True
|
|
204
|
+
):
|
|
197
205
|
message = message_from_proto(message_proto=message_proto)
|
|
198
206
|
validation_errors = validate_message(message, is_reply_message=False)
|
|
199
207
|
_raise_if(
|
|
@@ -206,13 +214,12 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
206
214
|
request_name="PushMessages",
|
|
207
215
|
detail="`Message.metadata` has mismatched `run_id`",
|
|
208
216
|
)
|
|
209
|
-
# Store
|
|
210
|
-
|
|
217
|
+
# Store objects
|
|
218
|
+
objects_to_push |= set(store.preregister(request.run_id, object_tree))
|
|
219
|
+
# Store message
|
|
220
|
+
message_id: str | None = state.store_message_ins(message=message)
|
|
211
221
|
message_ids.append(message_id)
|
|
212
222
|
|
|
213
|
-
# Store Message object to descendants mapping and preregister objects
|
|
214
|
-
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
|
215
|
-
|
|
216
223
|
return PushAppMessagesResponse(
|
|
217
224
|
message_ids=[
|
|
218
225
|
str(message_id) if message_id else "" for message_id in message_ids
|
|
@@ -316,7 +323,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
316
323
|
|
|
317
324
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
318
325
|
if result := ffs.get(request.hash_str):
|
|
319
|
-
fab = Fab(request.hash_str, result[0])
|
|
326
|
+
fab = Fab(request.hash_str, result[0], result[1])
|
|
320
327
|
return GetFabResponse(fab=fab_to_proto(fab))
|
|
321
328
|
|
|
322
329
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
@@ -343,10 +350,10 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
343
350
|
fab = None
|
|
344
351
|
if run and run.fab_hash:
|
|
345
352
|
if result := ffs.get(run.fab_hash):
|
|
346
|
-
fab = Fab(run.fab_hash, result[0])
|
|
353
|
+
fab = Fab(run.fab_hash, result[0], result[1])
|
|
347
354
|
if run and fab and serverapp_ctxt:
|
|
348
|
-
# Update run status to
|
|
349
|
-
if state.update_run_status(run_id, RunStatus(Status.
|
|
355
|
+
# Update run status to RUNNING
|
|
356
|
+
if state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")):
|
|
350
357
|
log(INFO, "Starting run %d", run_id)
|
|
351
358
|
return PullAppInputsResponse(
|
|
352
359
|
context=context_to_proto(serverapp_ctxt),
|
|
@@ -355,8 +362,12 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
355
362
|
)
|
|
356
363
|
|
|
357
364
|
# Raise an exception if the Run or Fab is not found,
|
|
358
|
-
# or if the status cannot be updated to
|
|
359
|
-
|
|
365
|
+
# or if the status cannot be updated to RUNNING
|
|
366
|
+
context.abort(
|
|
367
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
368
|
+
f"Failed to start run {run_id}",
|
|
369
|
+
)
|
|
370
|
+
raise RuntimeError("Unreachable code") # for mypy
|
|
360
371
|
|
|
361
372
|
def PushAppOutputs(
|
|
362
373
|
self, request: PushAppOutputsRequest, context: grpc.ServicerContext
|
|
@@ -441,20 +452,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
441
452
|
def SendAppHeartbeat(
|
|
442
453
|
self, request: SendAppHeartbeatRequest, context: grpc.ServicerContext
|
|
443
454
|
) -> SendAppHeartbeatResponse:
|
|
444
|
-
"""Handle a heartbeat from
|
|
455
|
+
"""Handle a heartbeat from an app process."""
|
|
445
456
|
log(DEBUG, "ServerAppIoServicer.SendAppHeartbeat")
|
|
446
457
|
|
|
447
458
|
# Init state
|
|
448
459
|
state = self.state_factory.state()
|
|
449
460
|
|
|
450
461
|
# Acknowledge the heartbeat
|
|
451
|
-
|
|
452
|
-
# starting or running status.
|
|
453
|
-
success = state.acknowledge_app_heartbeat(
|
|
454
|
-
run_id=request.run_id,
|
|
455
|
-
heartbeat_interval=request.heartbeat_interval,
|
|
456
|
-
)
|
|
457
|
-
|
|
462
|
+
success = state.acknowledge_app_heartbeat(request.token)
|
|
458
463
|
return SendAppHeartbeatResponse(success=success)
|
|
459
464
|
|
|
460
465
|
def PushObject(
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import INFO
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
import grpc
|
|
22
21
|
|
|
@@ -36,7 +35,7 @@ def run_simulationio_api_grpc(
|
|
|
36
35
|
address: str,
|
|
37
36
|
state_factory: LinkStateFactory,
|
|
38
37
|
ffs_factory: FfsFactory,
|
|
39
|
-
certificates:
|
|
38
|
+
certificates: tuple[bytes, bytes, bytes] | None,
|
|
40
39
|
) -> grpc.Server:
|
|
41
40
|
"""Run SimulationIo API (gRPC, request-response)."""
|
|
42
41
|
# Create SimulationIo API gRPC server
|
|
@@ -110,6 +110,13 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
110
110
|
# Attempt to create a token for the provided run ID
|
|
111
111
|
token = state.create_token(request.run_id)
|
|
112
112
|
|
|
113
|
+
# Transition the run to STARTING if token creation was successful
|
|
114
|
+
if token:
|
|
115
|
+
state.update_run_status(
|
|
116
|
+
run_id=request.run_id,
|
|
117
|
+
new_status=RunStatus(Status.STARTING, "", ""),
|
|
118
|
+
)
|
|
119
|
+
|
|
113
120
|
# Return the token
|
|
114
121
|
return RequestTokenResponse(token=token or "")
|
|
115
122
|
|
|
@@ -150,10 +157,10 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
150
157
|
fab = None
|
|
151
158
|
if run and run.fab_hash:
|
|
152
159
|
if result := ffs.get(run.fab_hash):
|
|
153
|
-
fab = Fab(run.fab_hash, result[0])
|
|
160
|
+
fab = Fab(run.fab_hash, result[0], result[1])
|
|
154
161
|
if run and fab and serverapp_ctxt:
|
|
155
|
-
# Update run status to
|
|
156
|
-
if state.update_run_status(run_id, RunStatus(Status.
|
|
162
|
+
# Update run status to RUNNING
|
|
163
|
+
if state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")):
|
|
157
164
|
log(INFO, "Starting run %d", run_id)
|
|
158
165
|
return PullAppInputsResponse(
|
|
159
166
|
context=context_to_proto(serverapp_ctxt),
|
|
@@ -162,8 +169,12 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
162
169
|
)
|
|
163
170
|
|
|
164
171
|
# Raise an exception if the Run or Fab is not found,
|
|
165
|
-
# or if the status cannot be updated to
|
|
166
|
-
|
|
172
|
+
# or if the status cannot be updated to RUNNING
|
|
173
|
+
context.abort(
|
|
174
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
175
|
+
f"Failed to start run {run_id}",
|
|
176
|
+
)
|
|
177
|
+
raise RuntimeError("Unreachable code") # for mypy
|
|
167
178
|
|
|
168
179
|
def PushAppOutputs(
|
|
169
180
|
self, request: PushAppOutputsRequest, context: ServicerContext
|
|
@@ -257,20 +268,14 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
257
268
|
def SendAppHeartbeat(
|
|
258
269
|
self, request: SendAppHeartbeatRequest, context: grpc.ServicerContext
|
|
259
270
|
) -> SendAppHeartbeatResponse:
|
|
260
|
-
"""Handle a heartbeat from
|
|
261
|
-
log(DEBUG, "
|
|
271
|
+
"""Handle a heartbeat from an app process."""
|
|
272
|
+
log(DEBUG, "SimulationIoServicer.SendAppHeartbeat")
|
|
262
273
|
|
|
263
274
|
# Init state
|
|
264
275
|
state = self.state_factory.state()
|
|
265
276
|
|
|
266
277
|
# Acknowledge the heartbeat
|
|
267
|
-
|
|
268
|
-
# starting or running status.
|
|
269
|
-
success = state.acknowledge_app_heartbeat(
|
|
270
|
-
run_id=request.run_id,
|
|
271
|
-
heartbeat_interval=request.heartbeat_interval,
|
|
272
|
-
)
|
|
273
|
-
|
|
278
|
+
success = state.acknowledge_app_heartbeat(request.token)
|
|
274
279
|
return SendAppHeartbeatResponse(success=success)
|
|
275
280
|
|
|
276
281
|
def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
|
flwr/server/superlink/utils.py
CHANGED
|
@@ -15,8 +15,6 @@
|
|
|
15
15
|
"""SuperLink utilities."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Optional, Union
|
|
19
|
-
|
|
20
18
|
import grpc
|
|
21
19
|
|
|
22
20
|
from flwr.common.constant import Status, SubStatus
|
|
@@ -36,8 +34,8 @@ def check_abort(
|
|
|
36
34
|
run_id: int,
|
|
37
35
|
abort_status_list: list[str],
|
|
38
36
|
state: LinkState,
|
|
39
|
-
store:
|
|
40
|
-
) ->
|
|
37
|
+
store: ObjectStore | None = None,
|
|
38
|
+
) -> str | None:
|
|
41
39
|
"""Check if the status of the provided `run_id` is in `abort_status_list`."""
|
|
42
40
|
run_status: RunStatus = state.get_run_status({run_id})[run_id]
|
|
43
41
|
|
|
@@ -54,7 +52,7 @@ def check_abort(
|
|
|
54
52
|
return None
|
|
55
53
|
|
|
56
54
|
|
|
57
|
-
def abort_grpc_context(msg:
|
|
55
|
+
def abort_grpc_context(msg: str | None, context: grpc.ServicerContext) -> None:
|
|
58
56
|
"""Abort context with statuscode PERMISSION_DENIED if `msg` is not None."""
|
|
59
57
|
if msg is not None:
|
|
60
58
|
context.abort(grpc.StatusCode.PERMISSION_DENIED, msg)
|
|
@@ -64,7 +62,7 @@ def abort_if(
|
|
|
64
62
|
run_id: int,
|
|
65
63
|
abort_status_list: list[str],
|
|
66
64
|
state: LinkState,
|
|
67
|
-
store:
|
|
65
|
+
store: ObjectStore | None,
|
|
68
66
|
context: grpc.ServicerContext,
|
|
69
67
|
) -> None:
|
|
70
68
|
"""Abort context if status of the provided `run_id` is in `abort_status_list`."""
|
flwr/server/typing.py
CHANGED
flwr/server/utils/tensorboard.py
CHANGED
|
@@ -16,20 +16,16 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
|
+
from collections.abc import Callable
|
|
19
20
|
from datetime import datetime
|
|
20
21
|
from logging import WARN
|
|
21
|
-
from typing import
|
|
22
|
+
from typing import cast
|
|
22
23
|
|
|
23
24
|
from flwr.common import EvaluateRes, Scalar
|
|
24
25
|
from flwr.common.logger import log
|
|
25
26
|
from flwr.server.client_proxy import ClientProxy
|
|
26
27
|
from flwr.server.strategy import Strategy
|
|
27
28
|
|
|
28
|
-
try:
|
|
29
|
-
import tensorflow as TF
|
|
30
|
-
except ModuleNotFoundError:
|
|
31
|
-
TF = None
|
|
32
|
-
|
|
33
29
|
MISSING_EXTRA_TF = """
|
|
34
30
|
Extra dependency required for using tensorboard are missing.
|
|
35
31
|
The program will continue without tensorboard.
|
|
@@ -59,6 +55,17 @@ def tensorboard(logdir: str) -> Callable[[Strategy], Strategy]:
|
|
|
59
55
|
# Variant 2
|
|
60
56
|
strategy = tensorboard(logdir=LOGDIR)(FedAvg)()
|
|
61
57
|
"""
|
|
58
|
+
log(
|
|
59
|
+
WARN,
|
|
60
|
+
"The `tensorboard` function is deprecated and will be removed "
|
|
61
|
+
"in a future release.",
|
|
62
|
+
)
|
|
63
|
+
# Lazy import of TensorFlow to avoid slow import times
|
|
64
|
+
try:
|
|
65
|
+
import tensorflow as TF # pylint: disable=import-outside-toplevel
|
|
66
|
+
except ModuleNotFoundError:
|
|
67
|
+
TF = None # pylint: disable=invalid-name
|
|
68
|
+
|
|
62
69
|
print(
|
|
63
70
|
"\n\t\033[32mStart TensorBoard with the following parameters"
|
|
64
71
|
f"\n\t$ tensorboard --logdir {logdir}\033[39m\n"
|
|
@@ -93,8 +100,8 @@ def tensorboard(logdir: str) -> Callable[[Strategy], Strategy]:
|
|
|
93
100
|
self,
|
|
94
101
|
server_round: int,
|
|
95
102
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
96
|
-
failures: list[
|
|
97
|
-
) -> tuple[
|
|
103
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
104
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
98
105
|
"""Hooks into aggregate_evaluate for TensorBoard logging purpose."""
|
|
99
106
|
# Execute decorated function and extract results for logging
|
|
100
107
|
# They will be returned at the end of this function but also
|
flwr/server/utils/validator.py
CHANGED
|
@@ -15,10 +15,9 @@
|
|
|
15
15
|
"""Validators."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import time
|
|
19
|
-
|
|
20
18
|
from flwr.common import Message
|
|
21
19
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
20
|
+
from flwr.common.date import now
|
|
22
21
|
|
|
23
22
|
|
|
24
23
|
# pylint: disable-next=too-many-branches
|
|
@@ -44,7 +43,7 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
|
44
43
|
validation_errors.append("`metadata.ttl` must be higher than zero")
|
|
45
44
|
|
|
46
45
|
# Verify TTL and created_at time
|
|
47
|
-
current_time =
|
|
46
|
+
current_time = now().timestamp()
|
|
48
47
|
if metadata.created_at + metadata.ttl <= current_time:
|
|
49
48
|
validation_errors.append("Message TTL has expired")
|
|
50
49
|
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import io
|
|
19
19
|
import timeit
|
|
20
20
|
from logging import INFO, WARN
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recorddict_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
@@ -47,8 +47,8 @@ class DefaultWorkflow:
|
|
|
47
47
|
|
|
48
48
|
def __init__(
|
|
49
49
|
self,
|
|
50
|
-
fit_workflow:
|
|
51
|
-
evaluate_workflow:
|
|
50
|
+
fit_workflow: Workflow | None = None,
|
|
51
|
+
evaluate_workflow: Workflow | None = None,
|
|
52
52
|
) -> None:
|
|
53
53
|
if fit_workflow is None:
|
|
54
54
|
fit_workflow = default_fit_workflow
|
|
@@ -275,7 +275,7 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
|
275
275
|
|
|
276
276
|
# Aggregate training results
|
|
277
277
|
results: list[tuple[ClientProxy, FitRes]] = []
|
|
278
|
-
failures: list[
|
|
278
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException] = []
|
|
279
279
|
for msg in messages:
|
|
280
280
|
if msg.has_content():
|
|
281
281
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
@@ -357,7 +357,7 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
|
357
357
|
|
|
358
358
|
# Aggregate the evaluation results
|
|
359
359
|
results: list[tuple[ClientProxy, EvaluateRes]] = []
|
|
360
|
-
failures: list[
|
|
360
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException] = []
|
|
361
361
|
for msg in messages:
|
|
362
362
|
if msg.has_content():
|
|
363
363
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
@@ -15,8 +15,6 @@
|
|
|
15
15
|
"""Workflow for the SecAgg protocol."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Optional, Union
|
|
19
|
-
|
|
20
18
|
from .secaggplus_workflow import SecAggPlusWorkflow
|
|
21
19
|
|
|
22
20
|
|
|
@@ -94,13 +92,13 @@ class SecAggWorkflow(SecAggPlusWorkflow):
|
|
|
94
92
|
|
|
95
93
|
def __init__( # pylint: disable=R0913
|
|
96
94
|
self,
|
|
97
|
-
reconstruction_threshold:
|
|
95
|
+
reconstruction_threshold: int | float,
|
|
98
96
|
*,
|
|
99
97
|
max_weight: float = 1000.0,
|
|
100
98
|
clipping_range: float = 8.0,
|
|
101
99
|
quantization_range: int = 4194304,
|
|
102
100
|
modulus_range: int = 4294967296,
|
|
103
|
-
timeout:
|
|
101
|
+
timeout: float | None = None,
|
|
104
102
|
) -> None:
|
|
105
103
|
super().__init__(
|
|
106
104
|
num_shares=1.0,
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import random
|
|
19
19
|
from dataclasses import dataclass, field
|
|
20
20
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recorddict_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
@@ -35,8 +35,6 @@ from flwr.common import (
|
|
|
35
35
|
)
|
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import combine_shares
|
|
37
37
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
38
|
-
bytes_to_private_key,
|
|
39
|
-
bytes_to_public_key,
|
|
40
38
|
generate_shared_key,
|
|
41
39
|
)
|
|
42
40
|
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
|
@@ -56,6 +54,10 @@ from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
|
|
|
56
54
|
from flwr.server.client_proxy import ClientProxy
|
|
57
55
|
from flwr.server.compat.legacy_context import LegacyContext
|
|
58
56
|
from flwr.server.grid import Grid
|
|
57
|
+
from flwr.supercore.primitives.asymmetric import (
|
|
58
|
+
bytes_to_private_key,
|
|
59
|
+
bytes_to_public_key,
|
|
60
|
+
)
|
|
59
61
|
|
|
60
62
|
from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
|
|
61
63
|
from ..constant import Key as WorkflowKey
|
|
@@ -167,14 +169,14 @@ class SecAggPlusWorkflow:
|
|
|
167
169
|
|
|
168
170
|
def __init__( # pylint: disable=R0913
|
|
169
171
|
self,
|
|
170
|
-
num_shares:
|
|
171
|
-
reconstruction_threshold:
|
|
172
|
+
num_shares: int | float,
|
|
173
|
+
reconstruction_threshold: int | float,
|
|
172
174
|
*,
|
|
173
175
|
max_weight: float = 1000.0,
|
|
174
176
|
clipping_range: float = 8.0,
|
|
175
177
|
quantization_range: int = 4194304,
|
|
176
178
|
modulus_range: int = 4294967296,
|
|
177
|
-
timeout:
|
|
179
|
+
timeout: float | None = None,
|
|
178
180
|
) -> None:
|
|
179
181
|
self.num_shares = num_shares
|
|
180
182
|
self.reconstruction_threshold = reconstruction_threshold
|
|
@@ -209,7 +211,7 @@ class SecAggPlusWorkflow:
|
|
|
209
211
|
|
|
210
212
|
def _check_init_params(self) -> None: # pylint: disable=R0912
|
|
211
213
|
# Check `num_shares`
|
|
212
|
-
if not isinstance(self.num_shares, (int
|
|
214
|
+
if not isinstance(self.num_shares, (int | float)):
|
|
213
215
|
raise TypeError("`num_shares` must be of type int or float.")
|
|
214
216
|
if isinstance(self.num_shares, int):
|
|
215
217
|
if self.num_shares == 1:
|
|
@@ -227,7 +229,7 @@ class SecAggPlusWorkflow:
|
|
|
227
229
|
raise ValueError("`num_shares` as a float must be greater than 0.")
|
|
228
230
|
|
|
229
231
|
# Check `reconstruction_threshold`
|
|
230
|
-
if not isinstance(self.reconstruction_threshold, (int
|
|
232
|
+
if not isinstance(self.reconstruction_threshold, (int | float)):
|
|
231
233
|
raise TypeError("`reconstruction_threshold` must be of type int or float.")
|
|
232
234
|
if isinstance(self.reconstruction_threshold, int):
|
|
233
235
|
if self.reconstruction_threshold == 1:
|
|
@@ -465,7 +467,7 @@ class SecAggPlusWorkflow:
|
|
|
465
467
|
dsts += dst_lst
|
|
466
468
|
ciphertexts += ctxt_lst
|
|
467
469
|
|
|
468
|
-
for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
|
|
470
|
+
for src, dst, ciphertext in zip(srcs, dsts, ciphertexts, strict=True):
|
|
469
471
|
if dst in fwd_ciphertexts:
|
|
470
472
|
fwd_ciphertexts[dst].append(ciphertext)
|
|
471
473
|
fwd_srcs[dst].append(src)
|
|
@@ -602,7 +604,7 @@ class SecAggPlusWorkflow:
|
|
|
602
604
|
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
|
603
605
|
nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
|
|
604
606
|
shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
|
|
605
|
-
for owner_nid, share in zip(nids, shares):
|
|
607
|
+
for owner_nid, share in zip(nids, shares, strict=True):
|
|
606
608
|
collected_shares_dict[owner_nid].append(share)
|
|
607
609
|
|
|
608
610
|
# Remove masks for every active client after collect_masked_vectors stage
|