flwr 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/__init__.py +1 -1
- flwr/cli/app.py +21 -2
- flwr/cli/build.py +83 -58
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +53 -17
- flwr/cli/example.py +1 -1
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +4 -4
- flwr/cli/login/__init__.py +1 -1
- flwr/cli/login/login.py +15 -8
- flwr/cli/ls.py +16 -37
- flwr/cli/new/__init__.py +1 -1
- flwr/cli/new/new.py +4 -4
- flwr/cli/new/templates/__init__.py +1 -1
- flwr/cli/new/templates/app/__init__.py +1 -1
- flwr/cli/new/templates/app/code/__init__.py +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +4 -4
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -1
- flwr/cli/run/run.py +11 -19
- flwr/cli/stop.py +3 -3
- flwr/cli/utils.py +42 -17
- flwr/client/__init__.py +3 -3
- flwr/client/client.py +1 -1
- flwr/client/client_app.py +140 -138
- flwr/client/clientapp/__init__.py +1 -8
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +1 -1
- flwr/client/grpc_adapter_client/connection.py +5 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +131 -61
- flwr/client/grpc_rere_client/grpc_adapter.py +35 -7
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/centraldp_mods.py +1 -1
- flwr/client/mod/comms_mods.py +39 -20
- flwr/client/mod/localdp_mod.py +6 -6
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +174 -68
- flwr/client/run_info_store.py +1 -1
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +3 -3
- flwr/common/address.py +1 -1
- flwr/common/args.py +1 -1
- flwr/common/auth_plugin/__init__.py +3 -1
- flwr/common/auth_plugin/auth_plugin.py +30 -4
- flwr/common/config.py +1 -1
- flwr/common/constant.py +37 -8
- flwr/common/context.py +1 -1
- flwr/common/date.py +1 -1
- flwr/common/differential_privacy.py +1 -1
- flwr/common/differential_privacy_constants.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit.py +6 -6
- flwr/common/exit_handlers.py +31 -1
- flwr/common/grpc.py +1 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/logger.py +1 -1
- flwr/common/message.py +137 -252
- flwr/common/object_ref.py +1 -1
- flwr/common/parameter.py +1 -1
- flwr/common/pyproject.py +1 -1
- flwr/common/record/__init__.py +3 -2
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +121 -243
- flwr/common/record/configrecord.py +71 -16
- flwr/common/record/conversion_utils.py +2 -2
- flwr/common/record/metricrecord.py +71 -20
- flwr/common/record/recorddict.py +207 -90
- flwr/common/record/typeddict.py +1 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +15 -11
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +52 -30
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +60 -184
- flwr/common/serde_utils.py +175 -0
- flwr/common/telemetry.py +2 -2
- flwr/common/typing.py +6 -4
- flwr/common/version.py +1 -1
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +71 -211
- flwr/{client → compat/client}/grpc_client/__init__.py +1 -1
- flwr/{client → compat/client}/grpc_client/connection.py +13 -13
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/__init__.py +1 -1
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +2 -2
- flwr/server/app.py +69 -187
- flwr/server/client_manager.py +1 -1
- flwr/server/client_proxy.py +1 -1
- flwr/server/compat/__init__.py +1 -1
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +51 -29
- flwr/server/compat/legacy_context.py +1 -1
- flwr/server/criterion.py +1 -1
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/history.py +1 -1
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +1 -1
- flwr/server/server_app.py +65 -58
- flwr/server/server_config.py +1 -1
- flwr/server/serverapp/__init__.py +1 -1
- flwr/server/serverapp/app.py +19 -1
- flwr/server/serverapp_components.py +1 -1
- flwr/server/strategy/__init__.py +1 -1
- flwr/server/strategy/aggregate.py +1 -1
- flwr/server/strategy/bulyan.py +2 -2
- flwr/server/strategy/dp_adaptive_clipping.py +17 -17
- flwr/server/strategy/dp_fixed_clipping.py +17 -17
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedtrimmedavg.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +3 -2
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/strategy/strategy.py +1 -1
- flwr/server/superlink/__init__.py +1 -1
- flwr/server/superlink/ffs/__init__.py +3 -1
- flwr/server/superlink/ffs/disk_ffs.py +1 -1
- flwr/server/superlink/ffs/ffs.py +1 -1
- flwr/server/superlink/ffs/ffs_factory.py +1 -1
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +14 -4
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +136 -19
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -12
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +7 -4
- flwr/server/superlink/linkstate/__init__.py +1 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +139 -44
- flwr/server/superlink/linkstate/linkstate.py +54 -21
- flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +150 -56
- flwr/server/superlink/linkstate/utils.py +34 -30
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/__init__.py +1 -1
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +45 -3
- flwr/server/typing.py +1 -1
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +3 -3
- flwr/server/workflow/__init__.py +1 -1
- flwr/server/workflow/constant.py +1 -1
- flwr/server/workflow/default_workflows.py +1 -1
- flwr/server/workflow/secure_aggregation/__init__.py +1 -1
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +18 -1
- flwr/simulation/legacy_app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/ray_transport/utils.py +1 -1
- flwr/simulation/run_simulation.py +2 -2
- flwr/simulation/simulationio_connection.py +1 -1
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/__init__.py +1 -1
- flwr/superexec/app.py +1 -1
- flwr/superexec/deployment.py +7 -3
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +8 -4
- flwr/superexec/exec_servicer.py +126 -24
- flwr/superexec/exec_user_auth_interceptor.py +38 -9
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -2
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +1 -8
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +8 -15
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +4 -13
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/{client → supernode}/nodestate/__init__.py +1 -1
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/{client → supernode}/nodestate/nodestate_factory.py +1 -1
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +26 -57
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/METADATA +6 -5
- flwr-1.19.0.dist-info/RECORD +365 -0
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- flwr-1.17.0.dist-info/LICENSE +0 -202
- flwr-1.17.0.dist-info/RECORD +0 -333
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,61 +15,83 @@
|
|
|
15
15
|
"""Shamir's secret sharing."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import
|
|
18
|
+
import os
|
|
19
19
|
from concurrent.futures import ThreadPoolExecutor
|
|
20
|
-
from typing import cast
|
|
21
20
|
|
|
22
21
|
from Crypto.Protocol.SecretSharing import Shamir
|
|
23
22
|
from Crypto.Util.Padding import pad, unpad
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def create_shares(secret: bytes, threshold: int, num: int) -> list[bytes]:
|
|
27
|
-
"""Return list of shares (bytes).
|
|
26
|
+
"""Return a list of shares (bytes).
|
|
27
|
+
|
|
28
|
+
Shares are created from the provided secret using Shamir's secret sharing.
|
|
29
|
+
"""
|
|
30
|
+
# Shamir's secret sharing requires the secret to be a multiple of 16 bytes
|
|
31
|
+
# (AES block size). Pad the secret to the next multiple of 16 bytes.
|
|
28
32
|
secret_padded = pad(secret, 16)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
share_list: list[
|
|
33
|
+
chunks = [secret_padded[i : i + 16] for i in range(0, len(secret_padded), 16)]
|
|
34
|
+
|
|
35
|
+
# The share list should contain shares of the secret, and each share consists of:
|
|
36
|
+
# <4 bytes of index><share of chunk1><share of chunk2>...<share of chunkN>
|
|
37
|
+
share_list: list[bytearray] = [bytearray() for _ in range(num)]
|
|
34
38
|
|
|
35
|
-
|
|
39
|
+
# Create shares for each chunk in parallel
|
|
40
|
+
max_workers = min(len(chunks), os.cpu_count() or 1)
|
|
41
|
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
36
42
|
for chunk_shares in executor.map(
|
|
37
|
-
lambda
|
|
43
|
+
lambda chunk: _shamir_split(threshold, num, chunk), chunks
|
|
38
44
|
):
|
|
39
45
|
for idx, share in chunk_shares:
|
|
40
|
-
#
|
|
41
|
-
share_list[idx - 1]
|
|
46
|
+
# Initialize the share with the index if it is empty
|
|
47
|
+
if not share_list[idx - 1]:
|
|
48
|
+
share_list[idx - 1] += idx.to_bytes(4, "little", signed=False)
|
|
42
49
|
|
|
43
|
-
|
|
50
|
+
# Append the share to the bytes
|
|
51
|
+
share_list[idx - 1] += share
|
|
52
|
+
|
|
53
|
+
return [bytes(share) for share in share_list]
|
|
44
54
|
|
|
45
55
|
|
|
46
56
|
def _shamir_split(threshold: int, num: int, chunk: bytes) -> list[tuple[int, bytes]]:
|
|
57
|
+
"""Create shares for a chunk using Shamir's secret sharing.
|
|
58
|
+
|
|
59
|
+
Each share is a tuple (index, share_bytes), where share_bytes is 16 bytes long.
|
|
60
|
+
"""
|
|
47
61
|
return Shamir.split(threshold, num, chunk, ssss=False)
|
|
48
62
|
|
|
49
63
|
|
|
50
|
-
# Reconstructing secret with PyCryptodome
|
|
51
64
|
def combine_shares(share_list: list[bytes]) -> bytes:
|
|
52
|
-
"""Reconstruct secret from shares."""
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
]
|
|
65
|
+
"""Reconstruct the secret from a list of shares."""
|
|
66
|
+
# Compute the number of chunks
|
|
67
|
+
# Each share contains 4 bytes of index and 16 bytes of share for each chunk
|
|
68
|
+
chunk_num = (len(share_list[0]) - 4) >> 4
|
|
56
69
|
|
|
57
|
-
chunk_num = len(unpickled_share_list[0])
|
|
58
70
|
secret_padded = bytearray(0)
|
|
59
|
-
chunk_shares_list: list[list[tuple[int, bytes]]] = []
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
71
|
+
chunk_shares_list: list[list[tuple[int, bytes]]] = [[] for _ in range(chunk_num)]
|
|
72
|
+
|
|
73
|
+
# Split shares into chunks
|
|
74
|
+
for share in share_list:
|
|
75
|
+
# The first 4 bytes are the index
|
|
76
|
+
index = int.from_bytes(share[:4], "little", signed=False)
|
|
77
|
+
for i in range(chunk_num):
|
|
78
|
+
start = (i << 4) + 4
|
|
79
|
+
chunk_shares_list[i].append((index, share[start : start + 16]))
|
|
80
|
+
|
|
81
|
+
# Combine shares for each chunk in parallel
|
|
82
|
+
max_workers = min(chunk_num, os.cpu_count() or 1)
|
|
83
|
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
67
84
|
for chunk in executor.map(_shamir_combine, chunk_shares_list):
|
|
68
85
|
secret_padded += chunk
|
|
69
86
|
|
|
70
|
-
|
|
71
|
-
|
|
87
|
+
try:
|
|
88
|
+
secret = unpad(bytes(secret_padded), 16)
|
|
89
|
+
except ValueError:
|
|
90
|
+
# If unpadding fails, it means the shares are not valid
|
|
91
|
+
raise ValueError("Failed to combine shares") from None
|
|
92
|
+
return secret
|
|
72
93
|
|
|
73
94
|
|
|
74
95
|
def _shamir_combine(shares: list[tuple[int, bytes]]) -> bytes:
|
|
96
|
+
"""Reconstruct a chunk from shares using Shamir's secret sharing."""
|
|
75
97
|
return Shamir.combine(shares, ssss=False)
|
flwr/common/serde.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,28 +16,20 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from collections import OrderedDict
|
|
19
|
-
from
|
|
20
|
-
from typing import Any, TypeVar, cast
|
|
21
|
-
|
|
22
|
-
from google.protobuf.message import Message as GrpcMessage
|
|
19
|
+
from typing import Any, cast
|
|
23
20
|
|
|
24
21
|
# pylint: disable=E0611
|
|
25
22
|
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
|
26
|
-
from flwr.proto.error_pb2 import Error as ProtoError
|
|
27
23
|
from flwr.proto.fab_pb2 import Fab as ProtoFab
|
|
28
24
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
29
25
|
from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
30
|
-
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
|
31
26
|
from flwr.proto.recorddict_pb2 import Array as ProtoArray
|
|
32
27
|
from flwr.proto.recorddict_pb2 import ArrayRecord as ProtoArrayRecord
|
|
33
|
-
from flwr.proto.recorddict_pb2 import BoolList, BytesList
|
|
34
28
|
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
35
29
|
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
|
|
36
|
-
from flwr.proto.recorddict_pb2 import DoubleList
|
|
37
30
|
from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
|
|
38
31
|
from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
|
|
39
32
|
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
40
|
-
from flwr.proto.recorddict_pb2 import SintList, StringList, UintList
|
|
41
33
|
from flwr.proto.run_pb2 import Run as ProtoRun
|
|
42
34
|
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
|
|
43
35
|
from flwr.proto.transport_pb2 import (
|
|
@@ -60,8 +52,16 @@ from . import (
|
|
|
60
52
|
RecordDict,
|
|
61
53
|
typing,
|
|
62
54
|
)
|
|
63
|
-
from .
|
|
64
|
-
from .
|
|
55
|
+
from .constant import INT64_MAX_VALUE
|
|
56
|
+
from .message import Message, make_message
|
|
57
|
+
from .serde_utils import (
|
|
58
|
+
error_from_proto,
|
|
59
|
+
error_to_proto,
|
|
60
|
+
metadata_from_proto,
|
|
61
|
+
metadata_to_proto,
|
|
62
|
+
record_value_dict_from_proto,
|
|
63
|
+
record_value_dict_to_proto,
|
|
64
|
+
)
|
|
65
65
|
|
|
66
66
|
# === Parameters message ===
|
|
67
67
|
|
|
@@ -339,7 +339,6 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
|
|
|
339
339
|
|
|
340
340
|
|
|
341
341
|
# === Scalar messages ===
|
|
342
|
-
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
|
343
342
|
|
|
344
343
|
|
|
345
344
|
def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
|
|
@@ -377,107 +376,21 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
|
|
|
377
376
|
# === Record messages ===
|
|
378
377
|
|
|
379
378
|
|
|
380
|
-
_type_to_field: dict[type, str] = {
|
|
381
|
-
float: "double",
|
|
382
|
-
int: "sint64",
|
|
383
|
-
bool: "bool",
|
|
384
|
-
str: "string",
|
|
385
|
-
bytes: "bytes",
|
|
386
|
-
}
|
|
387
|
-
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
|
|
388
|
-
float: (DoubleList, "double_list"),
|
|
389
|
-
int: (SintList, "sint_list"),
|
|
390
|
-
bool: (BoolList, "bool_list"),
|
|
391
|
-
str: (StringList, "string_list"),
|
|
392
|
-
bytes: (BytesList, "bytes_list"),
|
|
393
|
-
}
|
|
394
|
-
T = TypeVar("T")
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
def _is_uint64(value: Any) -> bool:
|
|
398
|
-
"""Check if a value is uint64."""
|
|
399
|
-
return isinstance(value, int) and value > INT64_MAX_VALUE
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
def _record_value_to_proto(
|
|
403
|
-
value: Any, allowed_types: list[type], proto_class: type[T]
|
|
404
|
-
) -> T:
|
|
405
|
-
"""Serialize `*RecordValue` to ProtoBuf.
|
|
406
|
-
|
|
407
|
-
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
|
408
|
-
"""
|
|
409
|
-
arg = {}
|
|
410
|
-
for t in allowed_types:
|
|
411
|
-
# Single element
|
|
412
|
-
# Note: `isinstance(False, int) == True`.
|
|
413
|
-
if isinstance(value, t):
|
|
414
|
-
fld = _type_to_field[t]
|
|
415
|
-
if t is int and _is_uint64(value):
|
|
416
|
-
fld = "uint64"
|
|
417
|
-
arg[fld] = value
|
|
418
|
-
return proto_class(**arg)
|
|
419
|
-
# List
|
|
420
|
-
if isinstance(value, list) and all(isinstance(item, t) for item in value):
|
|
421
|
-
list_class, fld = _list_type_to_class_and_field[t]
|
|
422
|
-
# Use UintList if any element is of type `uint64`.
|
|
423
|
-
if t is int and any(_is_uint64(v) for v in value):
|
|
424
|
-
list_class, fld = UintList, "uint_list"
|
|
425
|
-
arg[fld] = list_class(vals=value)
|
|
426
|
-
return proto_class(**arg)
|
|
427
|
-
# Invalid types
|
|
428
|
-
raise TypeError(
|
|
429
|
-
f"The type of the following value is not allowed "
|
|
430
|
-
f"in '{proto_class.__name__}':\n{value}"
|
|
431
|
-
)
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
|
|
435
|
-
"""Deserialize `*RecordValue` from ProtoBuf."""
|
|
436
|
-
value_field = cast(str, value_proto.WhichOneof("value"))
|
|
437
|
-
if value_field.endswith("list"):
|
|
438
|
-
value = list(getattr(value_proto, value_field).vals)
|
|
439
|
-
else:
|
|
440
|
-
value = getattr(value_proto, value_field)
|
|
441
|
-
return value
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
def _record_value_dict_to_proto(
|
|
445
|
-
value_dict: TypedDict[str, Any],
|
|
446
|
-
allowed_types: list[type],
|
|
447
|
-
value_proto_class: type[T],
|
|
448
|
-
) -> dict[str, T]:
|
|
449
|
-
"""Serialize the record value dict to ProtoBuf.
|
|
450
|
-
|
|
451
|
-
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
|
452
|
-
"""
|
|
453
|
-
# Move bool to the front
|
|
454
|
-
if bool in allowed_types and allowed_types[0] != bool:
|
|
455
|
-
allowed_types.remove(bool)
|
|
456
|
-
allowed_types.insert(0, bool)
|
|
457
|
-
|
|
458
|
-
def proto(_v: Any) -> T:
|
|
459
|
-
return _record_value_to_proto(_v, allowed_types, value_proto_class)
|
|
460
|
-
|
|
461
|
-
return {k: proto(v) for k, v in value_dict.items()}
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
def _record_value_dict_from_proto(
|
|
465
|
-
value_dict_proto: MutableMapping[str, Any]
|
|
466
|
-
) -> dict[str, Any]:
|
|
467
|
-
"""Deserialize the record value dict from ProtoBuf."""
|
|
468
|
-
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
|
469
|
-
|
|
470
|
-
|
|
471
379
|
def array_to_proto(array: Array) -> ProtoArray:
|
|
472
380
|
"""Serialize Array to ProtoBuf."""
|
|
473
|
-
return ProtoArray(
|
|
381
|
+
return ProtoArray(
|
|
382
|
+
dtype=array.dtype,
|
|
383
|
+
shape=array.shape,
|
|
384
|
+
stype=array.stype,
|
|
385
|
+
data=array.data,
|
|
386
|
+
)
|
|
474
387
|
|
|
475
388
|
|
|
476
389
|
def array_from_proto(array_proto: ProtoArray) -> Array:
|
|
477
390
|
"""Deserialize Array from ProtoBuf."""
|
|
478
391
|
return Array(
|
|
479
392
|
dtype=array_proto.dtype,
|
|
480
|
-
shape=
|
|
393
|
+
shape=tuple(array_proto.shape),
|
|
481
394
|
stype=array_proto.stype,
|
|
482
395
|
data=array_proto.data,
|
|
483
396
|
)
|
|
@@ -486,8 +399,10 @@ def array_from_proto(array_proto: ProtoArray) -> Array:
|
|
|
486
399
|
def array_record_to_proto(record: ArrayRecord) -> ProtoArrayRecord:
|
|
487
400
|
"""Serialize ArrayRecord to ProtoBuf."""
|
|
488
401
|
return ProtoArrayRecord(
|
|
489
|
-
|
|
490
|
-
|
|
402
|
+
items=[
|
|
403
|
+
ProtoArrayRecord.Item(key=k, value=array_to_proto(v))
|
|
404
|
+
for k, v in record.items()
|
|
405
|
+
]
|
|
491
406
|
)
|
|
492
407
|
|
|
493
408
|
|
|
@@ -497,7 +412,7 @@ def array_record_from_proto(
|
|
|
497
412
|
"""Deserialize ArrayRecord from ProtoBuf."""
|
|
498
413
|
return ArrayRecord(
|
|
499
414
|
array_dict=OrderedDict(
|
|
500
|
-
|
|
415
|
+
{item.key: array_from_proto(item.value) for item in record_proto.items}
|
|
501
416
|
),
|
|
502
417
|
keep_input=False,
|
|
503
418
|
)
|
|
@@ -505,17 +420,19 @@ def array_record_from_proto(
|
|
|
505
420
|
|
|
506
421
|
def metric_record_to_proto(record: MetricRecord) -> ProtoMetricRecord:
|
|
507
422
|
"""Serialize MetricRecord to ProtoBuf."""
|
|
423
|
+
protos = record_value_dict_to_proto(record, [float, int], ProtoMetricRecordValue)
|
|
508
424
|
return ProtoMetricRecord(
|
|
509
|
-
|
|
425
|
+
items=[ProtoMetricRecord.Item(key=k, value=v) for k, v in protos.items()]
|
|
510
426
|
)
|
|
511
427
|
|
|
512
428
|
|
|
513
429
|
def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
|
|
514
430
|
"""Deserialize MetricRecord from ProtoBuf."""
|
|
431
|
+
protos = {item.key: item.value for item in record_proto.items}
|
|
515
432
|
return MetricRecord(
|
|
516
433
|
metric_dict=cast(
|
|
517
434
|
dict[str, typing.MetricRecordValues],
|
|
518
|
-
|
|
435
|
+
record_value_dict_from_proto(protos),
|
|
519
436
|
),
|
|
520
437
|
keep_input=False,
|
|
521
438
|
)
|
|
@@ -523,68 +440,60 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
|
|
|
523
440
|
|
|
524
441
|
def config_record_to_proto(record: ConfigRecord) -> ProtoConfigRecord:
|
|
525
442
|
"""Serialize ConfigRecord to ProtoBuf."""
|
|
443
|
+
protos = record_value_dict_to_proto(
|
|
444
|
+
record,
|
|
445
|
+
[bool, int, float, str, bytes],
|
|
446
|
+
ProtoConfigRecordValue,
|
|
447
|
+
)
|
|
526
448
|
return ProtoConfigRecord(
|
|
527
|
-
|
|
528
|
-
record,
|
|
529
|
-
[bool, int, float, str, bytes],
|
|
530
|
-
ProtoConfigRecordValue,
|
|
531
|
-
)
|
|
449
|
+
items=[ProtoConfigRecord.Item(key=k, value=v) for k, v in protos.items()]
|
|
532
450
|
)
|
|
533
451
|
|
|
534
452
|
|
|
535
453
|
def config_record_from_proto(record_proto: ProtoConfigRecord) -> ConfigRecord:
|
|
536
454
|
"""Deserialize ConfigRecord from ProtoBuf."""
|
|
455
|
+
protos = {item.key: item.value for item in record_proto.items}
|
|
537
456
|
return ConfigRecord(
|
|
538
457
|
config_dict=cast(
|
|
539
458
|
dict[str, typing.ConfigRecordValues],
|
|
540
|
-
|
|
459
|
+
record_value_dict_from_proto(protos),
|
|
541
460
|
),
|
|
542
461
|
keep_input=False,
|
|
543
462
|
)
|
|
544
463
|
|
|
545
464
|
|
|
546
|
-
# === Error message ===
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
def error_to_proto(error: Error) -> ProtoError:
|
|
550
|
-
"""Serialize Error to ProtoBuf."""
|
|
551
|
-
reason = error.reason if error.reason else ""
|
|
552
|
-
return ProtoError(code=error.code, reason=reason)
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
def error_from_proto(error_proto: ProtoError) -> Error:
|
|
556
|
-
"""Deserialize Error from ProtoBuf."""
|
|
557
|
-
reason = error_proto.reason if len(error_proto.reason) > 0 else None
|
|
558
|
-
return Error(code=error_proto.code, reason=reason)
|
|
559
|
-
|
|
560
|
-
|
|
561
465
|
# === RecordDict message ===
|
|
562
466
|
|
|
563
467
|
|
|
564
468
|
def recorddict_to_proto(recorddict: RecordDict) -> ProtoRecordDict:
|
|
565
469
|
"""Serialize RecordDict to ProtoBuf."""
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
470
|
+
item_cls = ProtoRecordDict.Item
|
|
471
|
+
items: list[ProtoRecordDict.Item] = []
|
|
472
|
+
for k, v in recorddict.items():
|
|
473
|
+
if isinstance(v, ArrayRecord):
|
|
474
|
+
items += [item_cls(key=k, array_record=array_record_to_proto(v))]
|
|
475
|
+
elif isinstance(v, MetricRecord):
|
|
476
|
+
items += [item_cls(key=k, metric_record=metric_record_to_proto(v))]
|
|
477
|
+
elif isinstance(v, ConfigRecord):
|
|
478
|
+
items += [item_cls(key=k, config_record=config_record_to_proto(v))]
|
|
479
|
+
else:
|
|
480
|
+
raise ValueError(f"Unsupported record type: {type(v)}")
|
|
481
|
+
return ProtoRecordDict(items=items)
|
|
577
482
|
|
|
578
483
|
|
|
579
484
|
def recorddict_from_proto(recorddict_proto: ProtoRecordDict) -> RecordDict:
|
|
580
485
|
"""Deserialize RecordDict from ProtoBuf."""
|
|
581
486
|
ret = RecordDict()
|
|
582
|
-
for
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
487
|
+
for item in recorddict_proto.items:
|
|
488
|
+
field = item.WhichOneof("value")
|
|
489
|
+
if field == "array_record":
|
|
490
|
+
ret[item.key] = array_record_from_proto(item.array_record)
|
|
491
|
+
elif field == "metric_record":
|
|
492
|
+
ret[item.key] = metric_record_from_proto(item.metric_record)
|
|
493
|
+
elif field == "config_record":
|
|
494
|
+
ret[item.key] = config_record_from_proto(item.config_record)
|
|
495
|
+
else:
|
|
496
|
+
raise ValueError(f"Unsupported record type: {field}")
|
|
588
497
|
return ret
|
|
589
498
|
|
|
590
499
|
|
|
@@ -646,41 +555,6 @@ def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
|
|
|
646
555
|
return cast(typing.UserConfigValue, scalar)
|
|
647
556
|
|
|
648
557
|
|
|
649
|
-
# === Metadata messages ===
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
|
|
653
|
-
"""Serialize `Metadata` to ProtoBuf."""
|
|
654
|
-
proto = ProtoMetadata( # pylint: disable=E1101
|
|
655
|
-
run_id=metadata.run_id,
|
|
656
|
-
message_id=metadata.message_id,
|
|
657
|
-
src_node_id=metadata.src_node_id,
|
|
658
|
-
dst_node_id=metadata.dst_node_id,
|
|
659
|
-
reply_to_message_id=metadata.reply_to_message_id,
|
|
660
|
-
group_id=metadata.group_id,
|
|
661
|
-
ttl=metadata.ttl,
|
|
662
|
-
message_type=metadata.message_type,
|
|
663
|
-
created_at=metadata.created_at,
|
|
664
|
-
)
|
|
665
|
-
return proto
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
|
|
669
|
-
"""Deserialize `Metadata` from ProtoBuf."""
|
|
670
|
-
metadata = Metadata(
|
|
671
|
-
run_id=metadata_proto.run_id,
|
|
672
|
-
message_id=metadata_proto.message_id,
|
|
673
|
-
src_node_id=metadata_proto.src_node_id,
|
|
674
|
-
dst_node_id=metadata_proto.dst_node_id,
|
|
675
|
-
reply_to_message_id=metadata_proto.reply_to_message_id,
|
|
676
|
-
group_id=metadata_proto.group_id,
|
|
677
|
-
created_at=metadata_proto.created_at,
|
|
678
|
-
ttl=metadata_proto.ttl,
|
|
679
|
-
message_type=metadata_proto.message_type,
|
|
680
|
-
)
|
|
681
|
-
return metadata
|
|
682
|
-
|
|
683
|
-
|
|
684
558
|
# === Message messages ===
|
|
685
559
|
|
|
686
560
|
|
|
@@ -756,6 +630,7 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
|
|
|
756
630
|
running_at=run.running_at,
|
|
757
631
|
finished_at=run.finished_at,
|
|
758
632
|
status=run_status_to_proto(run.status),
|
|
633
|
+
flwr_aid=run.flwr_aid,
|
|
759
634
|
)
|
|
760
635
|
return proto
|
|
761
636
|
|
|
@@ -773,6 +648,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
|
|
|
773
648
|
running_at=run_proto.running_at,
|
|
774
649
|
finished_at=run_proto.finished_at,
|
|
775
650
|
status=run_status_from_proto(run_proto.status),
|
|
651
|
+
flwr_aid=run_proto.flwr_aid,
|
|
776
652
|
)
|
|
777
653
|
return run
|
|
778
654
|
|