flwr 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. flwr/__init__.py +1 -1
  2. flwr/app/__init__.py +15 -0
  3. flwr/app/error.py +68 -0
  4. flwr/app/metadata.py +223 -0
  5. flwr/cli/__init__.py +1 -1
  6. flwr/cli/app.py +21 -2
  7. flwr/cli/build.py +83 -58
  8. flwr/cli/cli_user_auth_interceptor.py +1 -1
  9. flwr/cli/config_utils.py +53 -17
  10. flwr/cli/example.py +1 -1
  11. flwr/cli/install.py +1 -1
  12. flwr/cli/log.py +4 -4
  13. flwr/cli/login/__init__.py +1 -1
  14. flwr/cli/login/login.py +15 -8
  15. flwr/cli/ls.py +16 -37
  16. flwr/cli/new/__init__.py +1 -1
  17. flwr/cli/new/new.py +4 -4
  18. flwr/cli/new/templates/__init__.py +1 -1
  19. flwr/cli/new/templates/app/__init__.py +1 -1
  20. flwr/cli/new/templates/app/code/__init__.py +1 -1
  21. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  22. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +4 -4
  24. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  25. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  26. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
  28. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  29. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  30. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  33. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  34. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  35. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  36. flwr/cli/run/__init__.py +1 -1
  37. flwr/cli/run/run.py +11 -19
  38. flwr/cli/stop.py +3 -3
  39. flwr/cli/utils.py +42 -17
  40. flwr/client/__init__.py +3 -3
  41. flwr/client/client.py +1 -1
  42. flwr/client/client_app.py +140 -138
  43. flwr/client/clientapp/__init__.py +1 -8
  44. flwr/client/clientapp/utils.py +1 -1
  45. flwr/client/dpfedavg_numpy_client.py +1 -1
  46. flwr/client/grpc_adapter_client/__init__.py +1 -1
  47. flwr/client/grpc_adapter_client/connection.py +5 -5
  48. flwr/client/grpc_rere_client/__init__.py +1 -1
  49. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  50. flwr/client/grpc_rere_client/connection.py +131 -61
  51. flwr/client/grpc_rere_client/grpc_adapter.py +35 -7
  52. flwr/client/message_handler/__init__.py +1 -1
  53. flwr/client/message_handler/message_handler.py +2 -2
  54. flwr/client/mod/__init__.py +1 -1
  55. flwr/client/mod/centraldp_mods.py +1 -1
  56. flwr/client/mod/comms_mods.py +39 -20
  57. flwr/client/mod/localdp_mod.py +6 -6
  58. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  59. flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
  60. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  61. flwr/client/mod/utils.py +1 -1
  62. flwr/client/numpy_client.py +1 -1
  63. flwr/client/rest_client/__init__.py +1 -1
  64. flwr/client/rest_client/connection.py +174 -68
  65. flwr/client/run_info_store.py +1 -1
  66. flwr/client/typing.py +1 -1
  67. flwr/clientapp/__init__.py +15 -0
  68. flwr/common/__init__.py +3 -3
  69. flwr/common/address.py +1 -1
  70. flwr/common/args.py +1 -1
  71. flwr/common/auth_plugin/__init__.py +3 -1
  72. flwr/common/auth_plugin/auth_plugin.py +30 -4
  73. flwr/common/config.py +1 -1
  74. flwr/common/constant.py +37 -8
  75. flwr/common/context.py +1 -1
  76. flwr/common/date.py +1 -1
  77. flwr/common/differential_privacy.py +1 -1
  78. flwr/common/differential_privacy_constants.py +1 -1
  79. flwr/common/dp.py +1 -1
  80. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  81. flwr/common/exit/exit.py +6 -6
  82. flwr/common/exit_handlers.py +31 -1
  83. flwr/common/grpc.py +1 -1
  84. flwr/common/heartbeat.py +165 -0
  85. flwr/common/inflatable.py +290 -0
  86. flwr/common/inflatable_grpc_utils.py +99 -0
  87. flwr/common/inflatable_rest_utils.py +99 -0
  88. flwr/common/inflatable_utils.py +341 -0
  89. flwr/common/logger.py +1 -1
  90. flwr/common/message.py +137 -252
  91. flwr/common/object_ref.py +1 -1
  92. flwr/common/parameter.py +1 -1
  93. flwr/common/pyproject.py +1 -1
  94. flwr/common/record/__init__.py +3 -2
  95. flwr/common/record/array.py +323 -0
  96. flwr/common/record/arrayrecord.py +121 -243
  97. flwr/common/record/configrecord.py +71 -16
  98. flwr/common/record/conversion_utils.py +2 -2
  99. flwr/common/record/metricrecord.py +71 -20
  100. flwr/common/record/recorddict.py +207 -90
  101. flwr/common/record/typeddict.py +1 -1
  102. flwr/common/recorddict_compat.py +2 -2
  103. flwr/common/retry_invoker.py +15 -11
  104. flwr/common/secure_aggregation/__init__.py +1 -1
  105. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  106. flwr/common/secure_aggregation/crypto/shamir.py +52 -30
  107. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  108. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  109. flwr/common/secure_aggregation/quantization.py +1 -1
  110. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  111. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  112. flwr/common/serde.py +60 -184
  113. flwr/common/serde_utils.py +175 -0
  114. flwr/common/telemetry.py +2 -2
  115. flwr/common/typing.py +6 -4
  116. flwr/common/version.py +1 -1
  117. flwr/compat/__init__.py +15 -0
  118. flwr/compat/client/__init__.py +15 -0
  119. flwr/{client → compat/client}/app.py +71 -211
  120. flwr/{client → compat/client}/grpc_client/__init__.py +1 -1
  121. flwr/{client → compat/client}/grpc_client/connection.py +13 -13
  122. flwr/compat/common/__init__.py +15 -0
  123. flwr/compat/server/__init__.py +15 -0
  124. flwr/compat/server/app.py +174 -0
  125. flwr/compat/simulation/__init__.py +15 -0
  126. flwr/proto/__init__.py +1 -1
  127. flwr/proto/fleet_pb2.py +32 -27
  128. flwr/proto/fleet_pb2.pyi +49 -35
  129. flwr/proto/fleet_pb2_grpc.py +117 -13
  130. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  131. flwr/proto/heartbeat_pb2.py +33 -0
  132. flwr/proto/heartbeat_pb2.pyi +66 -0
  133. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  134. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  135. flwr/proto/message_pb2.py +28 -11
  136. flwr/proto/message_pb2.pyi +125 -0
  137. flwr/proto/recorddict_pb2.py +16 -28
  138. flwr/proto/recorddict_pb2.pyi +46 -64
  139. flwr/proto/run_pb2.py +24 -32
  140. flwr/proto/run_pb2.pyi +4 -52
  141. flwr/proto/serverappio_pb2.py +32 -23
  142. flwr/proto/serverappio_pb2.pyi +45 -3
  143. flwr/proto/serverappio_pb2_grpc.py +138 -34
  144. flwr/proto/serverappio_pb2_grpc.pyi +54 -13
  145. flwr/proto/simulationio_pb2.py +12 -11
  146. flwr/proto/simulationio_pb2_grpc.py +35 -0
  147. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  148. flwr/server/__init__.py +2 -2
  149. flwr/server/app.py +69 -187
  150. flwr/server/client_manager.py +1 -1
  151. flwr/server/client_proxy.py +1 -1
  152. flwr/server/compat/__init__.py +1 -1
  153. flwr/server/compat/app.py +1 -1
  154. flwr/server/compat/app_utils.py +51 -29
  155. flwr/server/compat/legacy_context.py +1 -1
  156. flwr/server/criterion.py +1 -1
  157. flwr/server/fleet_event_log_interceptor.py +2 -2
  158. flwr/server/grid/grid.py +3 -3
  159. flwr/server/grid/grpc_grid.py +104 -34
  160. flwr/server/grid/inmemory_grid.py +5 -4
  161. flwr/server/history.py +1 -1
  162. flwr/server/run_serverapp.py +1 -1
  163. flwr/server/server.py +1 -1
  164. flwr/server/server_app.py +65 -58
  165. flwr/server/server_config.py +1 -1
  166. flwr/server/serverapp/__init__.py +1 -1
  167. flwr/server/serverapp/app.py +19 -1
  168. flwr/server/serverapp_components.py +1 -1
  169. flwr/server/strategy/__init__.py +1 -1
  170. flwr/server/strategy/aggregate.py +1 -1
  171. flwr/server/strategy/bulyan.py +2 -2
  172. flwr/server/strategy/dp_adaptive_clipping.py +17 -17
  173. flwr/server/strategy/dp_fixed_clipping.py +17 -17
  174. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  175. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  176. flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
  177. flwr/server/strategy/fedadagrad.py +1 -1
  178. flwr/server/strategy/fedadam.py +1 -1
  179. flwr/server/strategy/fedavg.py +1 -1
  180. flwr/server/strategy/fedavg_android.py +1 -1
  181. flwr/server/strategy/fedavgm.py +1 -1
  182. flwr/server/strategy/fedmedian.py +1 -1
  183. flwr/server/strategy/fedopt.py +1 -1
  184. flwr/server/strategy/fedprox.py +1 -1
  185. flwr/server/strategy/fedtrimmedavg.py +1 -1
  186. flwr/server/strategy/fedxgb_bagging.py +1 -1
  187. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  188. flwr/server/strategy/fedxgb_nn_avg.py +3 -2
  189. flwr/server/strategy/fedyogi.py +1 -1
  190. flwr/server/strategy/krum.py +1 -1
  191. flwr/server/strategy/qfedavg.py +1 -1
  192. flwr/server/strategy/strategy.py +1 -1
  193. flwr/server/superlink/__init__.py +1 -1
  194. flwr/server/superlink/ffs/__init__.py +3 -1
  195. flwr/server/superlink/ffs/disk_ffs.py +1 -1
  196. flwr/server/superlink/ffs/ffs.py +1 -1
  197. flwr/server/superlink/ffs/ffs_factory.py +1 -1
  198. flwr/server/superlink/fleet/__init__.py +1 -1
  199. flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
  200. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +14 -4
  201. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  202. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  203. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  204. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  205. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
  206. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  207. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  208. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
  209. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  210. flwr/server/superlink/fleet/message_handler/message_handler.py +136 -19
  211. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  212. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -12
  213. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  214. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  215. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  216. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  217. flwr/server/superlink/fleet/vce/vce_api.py +7 -4
  218. flwr/server/superlink/linkstate/__init__.py +1 -1
  219. flwr/server/superlink/linkstate/in_memory_linkstate.py +139 -44
  220. flwr/server/superlink/linkstate/linkstate.py +54 -21
  221. flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
  222. flwr/server/superlink/linkstate/sqlite_linkstate.py +150 -56
  223. flwr/server/superlink/linkstate/utils.py +34 -30
  224. flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
  225. flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
  226. flwr/server/superlink/simulation/__init__.py +1 -1
  227. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  228. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  229. flwr/server/superlink/utils.py +45 -3
  230. flwr/server/typing.py +1 -1
  231. flwr/server/utils/__init__.py +1 -1
  232. flwr/server/utils/tensorboard.py +1 -1
  233. flwr/server/utils/validator.py +3 -3
  234. flwr/server/workflow/__init__.py +1 -1
  235. flwr/server/workflow/constant.py +1 -1
  236. flwr/server/workflow/default_workflows.py +1 -1
  237. flwr/server/workflow/secure_aggregation/__init__.py +1 -1
  238. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
  239. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
  240. flwr/serverapp/__init__.py +15 -0
  241. flwr/simulation/__init__.py +1 -1
  242. flwr/simulation/app.py +18 -1
  243. flwr/simulation/legacy_app.py +1 -1
  244. flwr/simulation/ray_transport/__init__.py +1 -1
  245. flwr/simulation/ray_transport/ray_actor.py +1 -1
  246. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  247. flwr/simulation/ray_transport/utils.py +1 -1
  248. flwr/simulation/run_simulation.py +2 -2
  249. flwr/simulation/simulationio_connection.py +1 -1
  250. flwr/supercore/__init__.py +15 -0
  251. flwr/supercore/object_store/__init__.py +24 -0
  252. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  253. flwr/supercore/object_store/object_store.py +192 -0
  254. flwr/supercore/object_store/object_store_factory.py +44 -0
  255. flwr/superexec/__init__.py +1 -1
  256. flwr/superexec/app.py +1 -1
  257. flwr/superexec/deployment.py +7 -3
  258. flwr/superexec/exec_event_log_interceptor.py +4 -4
  259. flwr/superexec/exec_grpc.py +8 -4
  260. flwr/superexec/exec_servicer.py +126 -24
  261. flwr/superexec/exec_user_auth_interceptor.py +38 -9
  262. flwr/superexec/executor.py +5 -1
  263. flwr/superexec/simulation.py +8 -2
  264. flwr/superlink/__init__.py +15 -0
  265. flwr/{client/supernode → supernode}/__init__.py +1 -8
  266. flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +8 -15
  267. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +4 -13
  268. flwr/supernode/cli/flwr_clientapp.py +81 -0
  269. flwr/{client → supernode}/nodestate/__init__.py +1 -1
  270. flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
  271. flwr/supernode/nodestate/nodestate.py +212 -0
  272. flwr/{client → supernode}/nodestate/nodestate_factory.py +1 -1
  273. flwr/supernode/runtime/__init__.py +15 -0
  274. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +26 -57
  275. flwr/supernode/servicer/__init__.py +15 -0
  276. flwr/supernode/servicer/clientappio/__init__.py +24 -0
  277. flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +1 -1
  278. flwr/supernode/start_client_internal.py +491 -0
  279. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/METADATA +6 -5
  280. flwr-1.19.0.dist-info/RECORD +365 -0
  281. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
  282. {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
  283. flwr/client/heartbeat.py +0 -74
  284. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  285. flwr-1.17.0.dist-info/LICENSE +0 -202
  286. flwr-1.17.0.dist-info/RECORD +0 -333
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,61 +15,83 @@
15
15
  """Shamir's secret sharing."""
16
16
 
17
17
 
18
- import pickle
18
+ import os
19
19
  from concurrent.futures import ThreadPoolExecutor
20
- from typing import cast
21
20
 
22
21
  from Crypto.Protocol.SecretSharing import Shamir
23
22
  from Crypto.Util.Padding import pad, unpad
24
23
 
25
24
 
26
25
  def create_shares(secret: bytes, threshold: int, num: int) -> list[bytes]:
27
- """Return list of shares (bytes)."""
26
+ """Return a list of shares (bytes).
27
+
28
+ Shares are created from the provided secret using Shamir's secret sharing.
29
+ """
30
+ # Shamir's secret sharing requires the secret to be a multiple of 16 bytes
31
+ # (AES block size). Pad the secret to the next multiple of 16 bytes.
28
32
  secret_padded = pad(secret, 16)
29
- secret_padded_chunk = [
30
- (threshold, num, secret_padded[i : i + 16])
31
- for i in range(0, len(secret_padded), 16)
32
- ]
33
- share_list: list[list[tuple[int, bytes]]] = [[] for _ in range(num)]
33
+ chunks = [secret_padded[i : i + 16] for i in range(0, len(secret_padded), 16)]
34
+
35
+ # The share list should contain shares of the secret, and each share consists of:
36
+ # <4 bytes of index><share of chunk1><share of chunk2>...<share of chunkN>
37
+ share_list: list[bytearray] = [bytearray() for _ in range(num)]
34
38
 
35
- with ThreadPoolExecutor(max_workers=10) as executor:
39
+ # Create shares for each chunk in parallel
40
+ max_workers = min(len(chunks), os.cpu_count() or 1)
41
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
36
42
  for chunk_shares in executor.map(
37
- lambda arg: _shamir_split(*arg), secret_padded_chunk
43
+ lambda chunk: _shamir_split(threshold, num, chunk), chunks
38
44
  ):
39
45
  for idx, share in chunk_shares:
40
- # Index in `chunk_shares` starts from 1
41
- share_list[idx - 1].append((idx, share))
46
+ # Initialize the share with the index if it is empty
47
+ if not share_list[idx - 1]:
48
+ share_list[idx - 1] += idx.to_bytes(4, "little", signed=False)
42
49
 
43
- return [pickle.dumps(shares) for shares in share_list]
50
+ # Append the share to the bytes
51
+ share_list[idx - 1] += share
52
+
53
+ return [bytes(share) for share in share_list]
44
54
 
45
55
 
46
56
  def _shamir_split(threshold: int, num: int, chunk: bytes) -> list[tuple[int, bytes]]:
57
+ """Create shares for a chunk using Shamir's secret sharing.
58
+
59
+ Each share is a tuple (index, share_bytes), where share_bytes is 16 bytes long.
60
+ """
47
61
  return Shamir.split(threshold, num, chunk, ssss=False)
48
62
 
49
63
 
50
- # Reconstructing secret with PyCryptodome
51
64
  def combine_shares(share_list: list[bytes]) -> bytes:
52
- """Reconstruct secret from shares."""
53
- unpickled_share_list: list[list[tuple[int, bytes]]] = [
54
- cast(list[tuple[int, bytes]], pickle.loads(share)) for share in share_list
55
- ]
65
+ """Reconstruct the secret from a list of shares."""
66
+ # Compute the number of chunks
67
+ # Each share contains 4 bytes of index and 16 bytes of share for each chunk
68
+ chunk_num = (len(share_list[0]) - 4) >> 4
56
69
 
57
- chunk_num = len(unpickled_share_list[0])
58
70
  secret_padded = bytearray(0)
59
- chunk_shares_list: list[list[tuple[int, bytes]]] = []
60
- for i in range(chunk_num):
61
- chunk_shares: list[tuple[int, bytes]] = []
62
- for share in unpickled_share_list:
63
- chunk_shares.append(share[i])
64
- chunk_shares_list.append(chunk_shares)
65
-
66
- with ThreadPoolExecutor(max_workers=10) as executor:
71
+ chunk_shares_list: list[list[tuple[int, bytes]]] = [[] for _ in range(chunk_num)]
72
+
73
+ # Split shares into chunks
74
+ for share in share_list:
75
+ # The first 4 bytes are the index
76
+ index = int.from_bytes(share[:4], "little", signed=False)
77
+ for i in range(chunk_num):
78
+ start = (i << 4) + 4
79
+ chunk_shares_list[i].append((index, share[start : start + 16]))
80
+
81
+ # Combine shares for each chunk in parallel
82
+ max_workers = min(chunk_num, os.cpu_count() or 1)
83
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
67
84
  for chunk in executor.map(_shamir_combine, chunk_shares_list):
68
85
  secret_padded += chunk
69
86
 
70
- secret = unpad(secret_padded, 16)
71
- return bytes(secret)
87
+ try:
88
+ secret = unpad(bytes(secret_padded), 16)
89
+ except ValueError:
90
+ # If unpadding fails, it means the shares are not valid
91
+ raise ValueError("Failed to combine shares") from None
92
+ return secret
72
93
 
73
94
 
74
95
  def _shamir_combine(shares: list[tuple[int, bytes]]) -> bytes:
96
+ """Reconstruct a chunk from shares using Shamir's secret sharing."""
75
97
  return Shamir.combine(shares, ssss=False)
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/serde.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,28 +16,20 @@
16
16
 
17
17
 
18
18
  from collections import OrderedDict
19
- from collections.abc import MutableMapping
20
- from typing import Any, TypeVar, cast
21
-
22
- from google.protobuf.message import Message as GrpcMessage
19
+ from typing import Any, cast
23
20
 
24
21
  # pylint: disable=E0611
25
22
  from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
26
- from flwr.proto.error_pb2 import Error as ProtoError
27
23
  from flwr.proto.fab_pb2 import Fab as ProtoFab
28
24
  from flwr.proto.message_pb2 import Context as ProtoContext
29
25
  from flwr.proto.message_pb2 import Message as ProtoMessage
30
- from flwr.proto.message_pb2 import Metadata as ProtoMetadata
31
26
  from flwr.proto.recorddict_pb2 import Array as ProtoArray
32
27
  from flwr.proto.recorddict_pb2 import ArrayRecord as ProtoArrayRecord
33
- from flwr.proto.recorddict_pb2 import BoolList, BytesList
34
28
  from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
35
29
  from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
36
- from flwr.proto.recorddict_pb2 import DoubleList
37
30
  from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
38
31
  from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
39
32
  from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
40
- from flwr.proto.recorddict_pb2 import SintList, StringList, UintList
41
33
  from flwr.proto.run_pb2 import Run as ProtoRun
42
34
  from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
43
35
  from flwr.proto.transport_pb2 import (
@@ -60,8 +52,16 @@ from . import (
60
52
  RecordDict,
61
53
  typing,
62
54
  )
63
- from .message import Error, Message, Metadata, make_message
64
- from .record.typeddict import TypedDict
55
+ from .constant import INT64_MAX_VALUE
56
+ from .message import Message, make_message
57
+ from .serde_utils import (
58
+ error_from_proto,
59
+ error_to_proto,
60
+ metadata_from_proto,
61
+ metadata_to_proto,
62
+ record_value_dict_from_proto,
63
+ record_value_dict_to_proto,
64
+ )
65
65
 
66
66
  # === Parameters message ===
67
67
 
@@ -339,7 +339,6 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
339
339
 
340
340
 
341
341
  # === Scalar messages ===
342
- INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
343
342
 
344
343
 
345
344
  def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
@@ -377,107 +376,21 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
377
376
  # === Record messages ===
378
377
 
379
378
 
380
- _type_to_field: dict[type, str] = {
381
- float: "double",
382
- int: "sint64",
383
- bool: "bool",
384
- str: "string",
385
- bytes: "bytes",
386
- }
387
- _list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
388
- float: (DoubleList, "double_list"),
389
- int: (SintList, "sint_list"),
390
- bool: (BoolList, "bool_list"),
391
- str: (StringList, "string_list"),
392
- bytes: (BytesList, "bytes_list"),
393
- }
394
- T = TypeVar("T")
395
-
396
-
397
- def _is_uint64(value: Any) -> bool:
398
- """Check if a value is uint64."""
399
- return isinstance(value, int) and value > INT64_MAX_VALUE
400
-
401
-
402
- def _record_value_to_proto(
403
- value: Any, allowed_types: list[type], proto_class: type[T]
404
- ) -> T:
405
- """Serialize `*RecordValue` to ProtoBuf.
406
-
407
- Note: `bool` MUST be put in the front of allowd_types if it exists.
408
- """
409
- arg = {}
410
- for t in allowed_types:
411
- # Single element
412
- # Note: `isinstance(False, int) == True`.
413
- if isinstance(value, t):
414
- fld = _type_to_field[t]
415
- if t is int and _is_uint64(value):
416
- fld = "uint64"
417
- arg[fld] = value
418
- return proto_class(**arg)
419
- # List
420
- if isinstance(value, list) and all(isinstance(item, t) for item in value):
421
- list_class, fld = _list_type_to_class_and_field[t]
422
- # Use UintList if any element is of type `uint64`.
423
- if t is int and any(_is_uint64(v) for v in value):
424
- list_class, fld = UintList, "uint_list"
425
- arg[fld] = list_class(vals=value)
426
- return proto_class(**arg)
427
- # Invalid types
428
- raise TypeError(
429
- f"The type of the following value is not allowed "
430
- f"in '{proto_class.__name__}':\n{value}"
431
- )
432
-
433
-
434
- def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
435
- """Deserialize `*RecordValue` from ProtoBuf."""
436
- value_field = cast(str, value_proto.WhichOneof("value"))
437
- if value_field.endswith("list"):
438
- value = list(getattr(value_proto, value_field).vals)
439
- else:
440
- value = getattr(value_proto, value_field)
441
- return value
442
-
443
-
444
- def _record_value_dict_to_proto(
445
- value_dict: TypedDict[str, Any],
446
- allowed_types: list[type],
447
- value_proto_class: type[T],
448
- ) -> dict[str, T]:
449
- """Serialize the record value dict to ProtoBuf.
450
-
451
- Note: `bool` MUST be put in the front of allowd_types if it exists.
452
- """
453
- # Move bool to the front
454
- if bool in allowed_types and allowed_types[0] != bool:
455
- allowed_types.remove(bool)
456
- allowed_types.insert(0, bool)
457
-
458
- def proto(_v: Any) -> T:
459
- return _record_value_to_proto(_v, allowed_types, value_proto_class)
460
-
461
- return {k: proto(v) for k, v in value_dict.items()}
462
-
463
-
464
- def _record_value_dict_from_proto(
465
- value_dict_proto: MutableMapping[str, Any]
466
- ) -> dict[str, Any]:
467
- """Deserialize the record value dict from ProtoBuf."""
468
- return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
469
-
470
-
471
379
  def array_to_proto(array: Array) -> ProtoArray:
472
380
  """Serialize Array to ProtoBuf."""
473
- return ProtoArray(**vars(array))
381
+ return ProtoArray(
382
+ dtype=array.dtype,
383
+ shape=array.shape,
384
+ stype=array.stype,
385
+ data=array.data,
386
+ )
474
387
 
475
388
 
476
389
  def array_from_proto(array_proto: ProtoArray) -> Array:
477
390
  """Deserialize Array from ProtoBuf."""
478
391
  return Array(
479
392
  dtype=array_proto.dtype,
480
- shape=list(array_proto.shape),
393
+ shape=tuple(array_proto.shape),
481
394
  stype=array_proto.stype,
482
395
  data=array_proto.data,
483
396
  )
@@ -486,8 +399,10 @@ def array_from_proto(array_proto: ProtoArray) -> Array:
486
399
  def array_record_to_proto(record: ArrayRecord) -> ProtoArrayRecord:
487
400
  """Serialize ArrayRecord to ProtoBuf."""
488
401
  return ProtoArrayRecord(
489
- data_keys=record.keys(),
490
- data_values=map(array_to_proto, record.values()),
402
+ items=[
403
+ ProtoArrayRecord.Item(key=k, value=array_to_proto(v))
404
+ for k, v in record.items()
405
+ ]
491
406
  )
492
407
 
493
408
 
@@ -497,7 +412,7 @@ def array_record_from_proto(
497
412
  """Deserialize ArrayRecord from ProtoBuf."""
498
413
  return ArrayRecord(
499
414
  array_dict=OrderedDict(
500
- zip(record_proto.data_keys, map(array_from_proto, record_proto.data_values))
415
+ {item.key: array_from_proto(item.value) for item in record_proto.items}
501
416
  ),
502
417
  keep_input=False,
503
418
  )
@@ -505,17 +420,19 @@ def array_record_from_proto(
505
420
 
506
421
  def metric_record_to_proto(record: MetricRecord) -> ProtoMetricRecord:
507
422
  """Serialize MetricRecord to ProtoBuf."""
423
+ protos = record_value_dict_to_proto(record, [float, int], ProtoMetricRecordValue)
508
424
  return ProtoMetricRecord(
509
- data=_record_value_dict_to_proto(record, [float, int], ProtoMetricRecordValue)
425
+ items=[ProtoMetricRecord.Item(key=k, value=v) for k, v in protos.items()]
510
426
  )
511
427
 
512
428
 
513
429
  def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
514
430
  """Deserialize MetricRecord from ProtoBuf."""
431
+ protos = {item.key: item.value for item in record_proto.items}
515
432
  return MetricRecord(
516
433
  metric_dict=cast(
517
434
  dict[str, typing.MetricRecordValues],
518
- _record_value_dict_from_proto(record_proto.data),
435
+ record_value_dict_from_proto(protos),
519
436
  ),
520
437
  keep_input=False,
521
438
  )
@@ -523,68 +440,60 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
523
440
 
524
441
  def config_record_to_proto(record: ConfigRecord) -> ProtoConfigRecord:
525
442
  """Serialize ConfigRecord to ProtoBuf."""
443
+ protos = record_value_dict_to_proto(
444
+ record,
445
+ [bool, int, float, str, bytes],
446
+ ProtoConfigRecordValue,
447
+ )
526
448
  return ProtoConfigRecord(
527
- data=_record_value_dict_to_proto(
528
- record,
529
- [bool, int, float, str, bytes],
530
- ProtoConfigRecordValue,
531
- )
449
+ items=[ProtoConfigRecord.Item(key=k, value=v) for k, v in protos.items()]
532
450
  )
533
451
 
534
452
 
535
453
  def config_record_from_proto(record_proto: ProtoConfigRecord) -> ConfigRecord:
536
454
  """Deserialize ConfigRecord from ProtoBuf."""
455
+ protos = {item.key: item.value for item in record_proto.items}
537
456
  return ConfigRecord(
538
457
  config_dict=cast(
539
458
  dict[str, typing.ConfigRecordValues],
540
- _record_value_dict_from_proto(record_proto.data),
459
+ record_value_dict_from_proto(protos),
541
460
  ),
542
461
  keep_input=False,
543
462
  )
544
463
 
545
464
 
546
- # === Error message ===
547
-
548
-
549
- def error_to_proto(error: Error) -> ProtoError:
550
- """Serialize Error to ProtoBuf."""
551
- reason = error.reason if error.reason else ""
552
- return ProtoError(code=error.code, reason=reason)
553
-
554
-
555
- def error_from_proto(error_proto: ProtoError) -> Error:
556
- """Deserialize Error from ProtoBuf."""
557
- reason = error_proto.reason if len(error_proto.reason) > 0 else None
558
- return Error(code=error_proto.code, reason=reason)
559
-
560
-
561
465
  # === RecordDict message ===
562
466
 
563
467
 
564
468
  def recorddict_to_proto(recorddict: RecordDict) -> ProtoRecordDict:
565
469
  """Serialize RecordDict to ProtoBuf."""
566
- return ProtoRecordDict(
567
- arrays={
568
- k: array_record_to_proto(v) for k, v in recorddict.array_records.items()
569
- },
570
- metrics={
571
- k: metric_record_to_proto(v) for k, v in recorddict.metric_records.items()
572
- },
573
- configs={
574
- k: config_record_to_proto(v) for k, v in recorddict.config_records.items()
575
- },
576
- )
470
+ item_cls = ProtoRecordDict.Item
471
+ items: list[ProtoRecordDict.Item] = []
472
+ for k, v in recorddict.items():
473
+ if isinstance(v, ArrayRecord):
474
+ items += [item_cls(key=k, array_record=array_record_to_proto(v))]
475
+ elif isinstance(v, MetricRecord):
476
+ items += [item_cls(key=k, metric_record=metric_record_to_proto(v))]
477
+ elif isinstance(v, ConfigRecord):
478
+ items += [item_cls(key=k, config_record=config_record_to_proto(v))]
479
+ else:
480
+ raise ValueError(f"Unsupported record type: {type(v)}")
481
+ return ProtoRecordDict(items=items)
577
482
 
578
483
 
579
484
  def recorddict_from_proto(recorddict_proto: ProtoRecordDict) -> RecordDict:
580
485
  """Deserialize RecordDict from ProtoBuf."""
581
486
  ret = RecordDict()
582
- for k, arr_record_proto in recorddict_proto.arrays.items():
583
- ret[k] = array_record_from_proto(arr_record_proto)
584
- for k, m_record_proto in recorddict_proto.metrics.items():
585
- ret[k] = metric_record_from_proto(m_record_proto)
586
- for k, c_record_proto in recorddict_proto.configs.items():
587
- ret[k] = config_record_from_proto(c_record_proto)
487
+ for item in recorddict_proto.items:
488
+ field = item.WhichOneof("value")
489
+ if field == "array_record":
490
+ ret[item.key] = array_record_from_proto(item.array_record)
491
+ elif field == "metric_record":
492
+ ret[item.key] = metric_record_from_proto(item.metric_record)
493
+ elif field == "config_record":
494
+ ret[item.key] = config_record_from_proto(item.config_record)
495
+ else:
496
+ raise ValueError(f"Unsupported record type: {field}")
588
497
  return ret
589
498
 
590
499
 
@@ -646,41 +555,6 @@ def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
646
555
  return cast(typing.UserConfigValue, scalar)
647
556
 
648
557
 
649
- # === Metadata messages ===
650
-
651
-
652
- def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
653
- """Serialize `Metadata` to ProtoBuf."""
654
- proto = ProtoMetadata( # pylint: disable=E1101
655
- run_id=metadata.run_id,
656
- message_id=metadata.message_id,
657
- src_node_id=metadata.src_node_id,
658
- dst_node_id=metadata.dst_node_id,
659
- reply_to_message_id=metadata.reply_to_message_id,
660
- group_id=metadata.group_id,
661
- ttl=metadata.ttl,
662
- message_type=metadata.message_type,
663
- created_at=metadata.created_at,
664
- )
665
- return proto
666
-
667
-
668
- def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
669
- """Deserialize `Metadata` from ProtoBuf."""
670
- metadata = Metadata(
671
- run_id=metadata_proto.run_id,
672
- message_id=metadata_proto.message_id,
673
- src_node_id=metadata_proto.src_node_id,
674
- dst_node_id=metadata_proto.dst_node_id,
675
- reply_to_message_id=metadata_proto.reply_to_message_id,
676
- group_id=metadata_proto.group_id,
677
- created_at=metadata_proto.created_at,
678
- ttl=metadata_proto.ttl,
679
- message_type=metadata_proto.message_type,
680
- )
681
- return metadata
682
-
683
-
684
558
  # === Message messages ===
685
559
 
686
560
 
@@ -756,6 +630,7 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
756
630
  running_at=run.running_at,
757
631
  finished_at=run.finished_at,
758
632
  status=run_status_to_proto(run.status),
633
+ flwr_aid=run.flwr_aid,
759
634
  )
760
635
  return proto
761
636
 
@@ -773,6 +648,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
773
648
  running_at=run_proto.running_at,
774
649
  finished_at=run_proto.finished_at,
775
650
  status=run_status_from_proto(run_proto.status),
651
+ flwr_aid=run_proto.flwr_aid,
776
652
  )
777
653
  return run
778
654