flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package registry's release page for more details.

Files changed (237)
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -16,23 +16,21 @@
16
16
 
17
17
 
18
18
  from dataclasses import dataclass
19
- from typing import Callable, Dict, Optional, Type, TypeVar
19
+ from typing import Dict, Optional, cast
20
20
 
21
21
  from .configsrecord import ConfigsRecord
22
22
  from .metricsrecord import MetricsRecord
23
23
  from .parametersrecord import ParametersRecord
24
24
  from .typeddict import TypedDict
25
25
 
26
- T = TypeVar("T")
27
-
28
26
 
29
27
  @dataclass
30
- class RecordSet:
31
- """RecordSet stores groups of parameters, metrics and configs."""
28
+ class RecordSetData:
29
+ """Inner data container for the RecordSet class."""
32
30
 
33
- _parameters_records: TypedDict[str, ParametersRecord]
34
- _metrics_records: TypedDict[str, MetricsRecord]
35
- _configs_records: TypedDict[str, ConfigsRecord]
31
+ parameters_records: TypedDict[str, ParametersRecord]
32
+ metrics_records: TypedDict[str, MetricsRecord]
33
+ configs_records: TypedDict[str, ConfigsRecord]
36
34
 
37
35
  def __init__(
38
36
  self,
@@ -40,40 +38,93 @@ class RecordSet:
40
38
  metrics_records: Optional[Dict[str, MetricsRecord]] = None,
41
39
  configs_records: Optional[Dict[str, ConfigsRecord]] = None,
42
40
  ) -> None:
43
- def _get_check_fn(__t: Type[T]) -> Callable[[T], None]:
44
- def _check_fn(__v: T) -> None:
45
- if not isinstance(__v, __t):
46
- raise TypeError(f"Expected `{__t}`, but `{type(__v)}` was passed.")
47
-
48
- return _check_fn
49
-
50
- self._parameters_records = TypedDict[str, ParametersRecord](
51
- _get_check_fn(str), _get_check_fn(ParametersRecord)
41
+ self.parameters_records = TypedDict[str, ParametersRecord](
42
+ self._check_fn_str, self._check_fn_params
52
43
  )
53
- self._metrics_records = TypedDict[str, MetricsRecord](
54
- _get_check_fn(str), _get_check_fn(MetricsRecord)
44
+ self.metrics_records = TypedDict[str, MetricsRecord](
45
+ self._check_fn_str, self._check_fn_metrics
55
46
  )
56
- self._configs_records = TypedDict[str, ConfigsRecord](
57
- _get_check_fn(str), _get_check_fn(ConfigsRecord)
47
+ self.configs_records = TypedDict[str, ConfigsRecord](
48
+ self._check_fn_str, self._check_fn_configs
58
49
  )
59
50
  if parameters_records is not None:
60
- self._parameters_records.update(parameters_records)
51
+ self.parameters_records.update(parameters_records)
61
52
  if metrics_records is not None:
62
- self._metrics_records.update(metrics_records)
53
+ self.metrics_records.update(metrics_records)
63
54
  if configs_records is not None:
64
- self._configs_records.update(configs_records)
55
+ self.configs_records.update(configs_records)
56
+
57
+ def _check_fn_str(self, key: str) -> None:
58
+ if not isinstance(key, str):
59
+ raise TypeError(
60
+ f"Expected `{str.__name__}`, but "
61
+ f"received `{type(key).__name__}` for the key."
62
+ )
63
+
64
+ def _check_fn_params(self, record: ParametersRecord) -> None:
65
+ if not isinstance(record, ParametersRecord):
66
+ raise TypeError(
67
+ f"Expected `{ParametersRecord.__name__}`, but "
68
+ f"received `{type(record).__name__}` for the value."
69
+ )
70
+
71
+ def _check_fn_metrics(self, record: MetricsRecord) -> None:
72
+ if not isinstance(record, MetricsRecord):
73
+ raise TypeError(
74
+ f"Expected `{MetricsRecord.__name__}`, but "
75
+ f"received `{type(record).__name__}` for the value."
76
+ )
77
+
78
+ def _check_fn_configs(self, record: ConfigsRecord) -> None:
79
+ if not isinstance(record, ConfigsRecord):
80
+ raise TypeError(
81
+ f"Expected `{ConfigsRecord.__name__}`, but "
82
+ f"received `{type(record).__name__}` for the value."
83
+ )
84
+
85
+
86
+ class RecordSet:
87
+ """RecordSet stores groups of parameters, metrics and configs."""
88
+
89
+ def __init__(
90
+ self,
91
+ parameters_records: Optional[Dict[str, ParametersRecord]] = None,
92
+ metrics_records: Optional[Dict[str, MetricsRecord]] = None,
93
+ configs_records: Optional[Dict[str, ConfigsRecord]] = None,
94
+ ) -> None:
95
+ data = RecordSetData(
96
+ parameters_records=parameters_records,
97
+ metrics_records=metrics_records,
98
+ configs_records=configs_records,
99
+ )
100
+ self.__dict__["_data"] = data
65
101
 
66
102
  @property
67
103
  def parameters_records(self) -> TypedDict[str, ParametersRecord]:
68
104
  """Dictionary holding ParametersRecord instances."""
69
- return self._parameters_records
105
+ data = cast(RecordSetData, self.__dict__["_data"])
106
+ return data.parameters_records
70
107
 
71
108
  @property
72
109
  def metrics_records(self) -> TypedDict[str, MetricsRecord]:
73
110
  """Dictionary holding MetricsRecord instances."""
74
- return self._metrics_records
111
+ data = cast(RecordSetData, self.__dict__["_data"])
112
+ return data.metrics_records
75
113
 
76
114
  @property
77
115
  def configs_records(self) -> TypedDict[str, ConfigsRecord]:
78
116
  """Dictionary holding ConfigsRecord instances."""
79
- return self._configs_records
117
+ data = cast(RecordSetData, self.__dict__["_data"])
118
+ return data.configs_records
119
+
120
+ def __repr__(self) -> str:
121
+ """Return a string representation of this instance."""
122
+ flds = ("parameters_records", "metrics_records", "configs_records")
123
+ view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
124
+ return f"{self.__class__.__qualname__}({view})"
125
+
126
+ def __eq__(self, other: object) -> bool:
127
+ """Compare two instances of the class."""
128
+ if not isinstance(other, self.__class__):
129
+ raise NotImplementedError
130
+ return self.__dict__ == other.__dict__
@@ -35,6 +35,8 @@ from .typing import (
35
35
  Status,
36
36
  )
37
37
 
38
+ EMPTY_TENSOR_KEY = "_empty"
39
+
38
40
 
39
41
  def parametersrecord_to_parameters(
40
42
  record: ParametersRecord, keep_input: bool
@@ -59,7 +61,8 @@ def parametersrecord_to_parameters(
59
61
  parameters = Parameters(tensors=[], tensor_type="")
60
62
 
61
63
  for key in list(record.keys()):
62
- parameters.tensors.append(record[key].data)
64
+ if key != EMPTY_TENSOR_KEY:
65
+ parameters.tensors.append(record[key].data)
63
66
 
64
67
  if not parameters.tensor_type:
65
68
  # Setting from first array in record. Recall the warning in the docstrings
@@ -103,6 +106,10 @@ def parameters_to_parametersrecord(
103
106
  data=tensor, dtype="", stype=tensor_type, shape=[]
104
107
  )
105
108
 
109
+ if num_arrays == 0:
110
+ ordered_dict[EMPTY_TENSOR_KEY] = Array(
111
+ data=b"", dtype="", stype=tensor_type, shape=[]
112
+ )
106
113
  return ParametersRecord(ordered_dict, keep_input=keep_input)
107
114
 
108
115
 
@@ -107,7 +107,7 @@ class RetryInvoker:
107
107
 
108
108
  Parameters
109
109
  ----------
110
- wait_factory: Callable[[], Generator[float, None, None]]
110
+ wait_gen_factory: Callable[[], Generator[float, None, None]]
111
111
  A generator yielding successive wait times in seconds. If the generator
112
112
  is finite, the giveup event will be triggered when the generator raises
113
113
  `StopIteration`.
@@ -129,12 +129,12 @@ class RetryInvoker:
129
129
  data class object detailing the invocation.
130
130
  on_giveup: Optional[Callable[[RetryState], None]] (default: None)
131
131
  A callable to be executed in the event that `max_tries` or `max_time` is
132
- exceeded, `should_giveup` returns True, or `wait_factory()` generator raises
132
+ exceeded, `should_giveup` returns True, or `wait_gen_factory()` generator raises
133
133
  `StopInteration`. The parameter is a data class object detailing the
134
134
  invocation.
135
135
  jitter: Optional[Callable[[float], float]] (default: full_jitter)
136
- A function of the value yielded by `wait_factory()` returning the actual time
137
- to wait. This function helps distribute wait times stochastically to avoid
136
+ A function of the value yielded by `wait_gen_factory()` returning the actual
137
+ time to wait. This function helps distribute wait times stochastically to avoid
138
138
  timing collisions across concurrent clients. Wait times are jittered by
139
139
  default using the `full_jitter` function. To disable jittering, pass
140
140
  `jitter=None`.
@@ -142,6 +142,13 @@ class RetryInvoker:
142
142
  A function accepting an exception instance, returning whether or not
143
143
  to give up prematurely before other give-up conditions are evaluated.
144
144
  If set to None, the strategy is to never give up prematurely.
145
+ wait_function: Optional[Callable[[float], None]] (default: None)
146
+ A function that defines how to wait between retry attempts. It accepts
147
+ one argument, the wait time in seconds, allowing the use of various waiting
148
+ mechanisms (e.g., asynchronous waits or event-based synchronization) suitable
149
+ for different execution environments. If set to `None`, the `wait_function`
150
+ defaults to `time.sleep`, which is ideal for synchronous operations. Custom
151
+ functions should manage execution flow to prevent blocking or interference.
145
152
 
146
153
  Examples
147
154
  --------
@@ -159,7 +166,7 @@ class RetryInvoker:
159
166
  # pylint: disable-next=too-many-arguments
160
167
  def __init__(
161
168
  self,
162
- wait_factory: Callable[[], Generator[float, None, None]],
169
+ wait_gen_factory: Callable[[], Generator[float, None, None]],
163
170
  recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
164
171
  max_tries: Optional[int],
165
172
  max_time: Optional[float],
@@ -169,8 +176,9 @@ class RetryInvoker:
169
176
  on_giveup: Optional[Callable[[RetryState], None]] = None,
170
177
  jitter: Optional[Callable[[float], float]] = full_jitter,
171
178
  should_giveup: Optional[Callable[[Exception], bool]] = None,
179
+ wait_function: Optional[Callable[[float], None]] = None,
172
180
  ) -> None:
173
- self.wait_factory = wait_factory
181
+ self.wait_gen_factory = wait_gen_factory
174
182
  self.recoverable_exceptions = recoverable_exceptions
175
183
  self.max_tries = max_tries
176
184
  self.max_time = max_time
@@ -179,6 +187,9 @@ class RetryInvoker:
179
187
  self.on_giveup = on_giveup
180
188
  self.jitter = jitter
181
189
  self.should_giveup = should_giveup
190
+ if wait_function is None:
191
+ wait_function = time.sleep
192
+ self.wait_function = wait_function
182
193
 
183
194
  # pylint: disable-next=too-many-locals
184
195
  def invoke(
@@ -212,13 +223,13 @@ class RetryInvoker:
212
223
  Raises
213
224
  ------
214
225
  Exception
215
- If the number of tries exceeds `max_tries`, if the total time
216
- exceeds `max_time`, if `wait_factory()` generator raises `StopInteration`,
226
+ If the number of tries exceeds `max_tries`, if the total time exceeds
227
+ `max_time`, if `wait_gen_factory()` generator raises `StopInteration`,
217
228
  or if the `should_giveup` returns True for a raised exception.
218
229
 
219
230
  Notes
220
231
  -----
221
- The time between retries is determined by the provided `wait_factory()`
232
+ The time between retries is determined by the provided `wait_gen_factory()`
222
233
  generator and can optionally be jittered using the `jitter` function.
223
234
  The recoverable exceptions that trigger a retry, as well as conditions to
224
235
  stop retries, are also determined by the class's initialization parameters.
@@ -231,13 +242,13 @@ class RetryInvoker:
231
242
  handler(cast(RetryState, ref_state[0]))
232
243
 
233
244
  try_cnt = 0
234
- wait_generator = self.wait_factory()
235
- start = time.time()
245
+ wait_generator = self.wait_gen_factory()
246
+ start = time.monotonic()
236
247
  ref_state: List[Optional[RetryState]] = [None]
237
248
 
238
249
  while True:
239
250
  try_cnt += 1
240
- elapsed_time = time.time() - start
251
+ elapsed_time = time.monotonic() - start
241
252
  state = RetryState(
242
253
  target=target,
243
254
  args=args,
@@ -250,6 +261,7 @@ class RetryInvoker:
250
261
  try:
251
262
  ret = target(*args, **kwargs)
252
263
  except self.recoverable_exceptions as err:
264
+ state.exception = err
253
265
  # Check if giveup event should be triggered
254
266
  max_tries_exceeded = try_cnt == self.max_tries
255
267
  max_time_exceeded = (
@@ -282,7 +294,7 @@ class RetryInvoker:
282
294
  try_call_event_handler(self.on_backoff)
283
295
 
284
296
  # Sleep
285
- time.sleep(wait_time)
297
+ self.wait_function(state.actual_wait)
286
298
  else:
287
299
  # Trigger success event
288
300
  try_call_event_handler(self.on_success)
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,8 +18,9 @@
18
18
  import base64
19
19
  from typing import Tuple, cast
20
20
 
21
+ from cryptography.exceptions import InvalidSignature
21
22
  from cryptography.fernet import Fernet
22
- from cryptography.hazmat.primitives import hashes, serialization
23
+ from cryptography.hazmat.primitives import hashes, hmac, serialization
23
24
  from cryptography.hazmat.primitives.asymmetric import ec
24
25
  from cryptography.hazmat.primitives.kdf.hkdf import HKDF
25
26
 
@@ -98,3 +99,21 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes:
98
99
  # The input key must be url safe
99
100
  fernet = Fernet(key)
100
101
  return fernet.decrypt(ciphertext)
102
+
103
+
104
+ def compute_hmac(key: bytes, message: bytes) -> bytes:
105
+ """Compute hmac of a message using key as hash."""
106
+ computed_hmac = hmac.HMAC(key, hashes.SHA256())
107
+ computed_hmac.update(message)
108
+ return computed_hmac.finalize()
109
+
110
+
111
+ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
112
+ """Verify hmac of a message using key as hash."""
113
+ computed_hmac = hmac.HMAC(key, hashes.SHA256())
114
+ computed_hmac.update(message)
115
+ try:
116
+ computed_hmac.verify(hmac_value)
117
+ return True
118
+ except InvalidSignature:
119
+ return False
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/serde.py CHANGED
@@ -20,7 +20,11 @@ from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar,
20
20
  from google.protobuf.message import Message as GrpcMessage
21
21
 
22
22
  # pylint: disable=E0611
23
+ from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
23
24
  from flwr.proto.error_pb2 import Error as ProtoError
25
+ from flwr.proto.message_pb2 import Context as ProtoContext
26
+ from flwr.proto.message_pb2 import Message as ProtoMessage
27
+ from flwr.proto.message_pb2 import Metadata as ProtoMetadata
24
28
  from flwr.proto.node_pb2 import Node
25
29
  from flwr.proto.recordset_pb2 import Array as ProtoArray
26
30
  from flwr.proto.recordset_pb2 import BoolList, BytesList
@@ -32,6 +36,7 @@ from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordVal
32
36
  from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
33
37
  from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
34
38
  from flwr.proto.recordset_pb2 import Sint64List, StringList
39
+ from flwr.proto.run_pb2 import Run as ProtoRun
35
40
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
36
41
  from flwr.proto.transport_pb2 import (
37
42
  ClientMessage,
@@ -44,7 +49,15 @@ from flwr.proto.transport_pb2 import (
44
49
  )
45
50
 
46
51
  # pylint: enable=E0611
47
- from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing
52
+ from . import (
53
+ Array,
54
+ ConfigsRecord,
55
+ Context,
56
+ MetricsRecord,
57
+ ParametersRecord,
58
+ RecordSet,
59
+ typing,
60
+ )
48
61
  from .message import Error, Message, Metadata
49
62
  from .record.typeddict import TypedDict
50
63
 
@@ -575,6 +588,7 @@ def message_to_taskins(message: Message) -> TaskIns:
575
588
  task=Task(
576
589
  producer=Node(node_id=0, anonymous=True), # Assume driver node
577
590
  consumer=Node(node_id=md.dst_node_id, anonymous=False),
591
+ created_at=md.created_at,
578
592
  ttl=md.ttl,
579
593
  ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
580
594
  task_type=md.message_type,
@@ -601,7 +615,7 @@ def message_from_taskins(taskins: TaskIns) -> Message:
601
615
  )
602
616
 
603
617
  # Construct Message
604
- return Message(
618
+ message = Message(
605
619
  metadata=metadata,
606
620
  content=(
607
621
  recordset_from_proto(taskins.task.recordset)
@@ -614,6 +628,8 @@ def message_from_taskins(taskins: TaskIns) -> Message:
614
628
  else None
615
629
  ),
616
630
  )
631
+ message.metadata.created_at = taskins.task.created_at
632
+ return message
617
633
 
618
634
 
619
635
  def message_to_taskres(message: Message) -> TaskRes:
@@ -626,6 +642,7 @@ def message_to_taskres(message: Message) -> TaskRes:
626
642
  task=Task(
627
643
  producer=Node(node_id=md.src_node_id, anonymous=False),
628
644
  consumer=Node(node_id=0, anonymous=True), # Assume driver node
645
+ created_at=md.created_at,
629
646
  ttl=md.ttl,
630
647
  ancestry=[md.reply_to_message] if md.reply_to_message != "" else [],
631
648
  task_type=md.message_type,
@@ -652,7 +669,7 @@ def message_from_taskres(taskres: TaskRes) -> Message:
652
669
  )
653
670
 
654
671
  # Construct the Message
655
- return Message(
672
+ message = Message(
656
673
  metadata=metadata,
657
674
  content=(
658
675
  recordset_from_proto(taskres.task.recordset)
@@ -665,3 +682,192 @@ def message_from_taskres(taskres: TaskRes) -> Message:
665
682
  else None
666
683
  ),
667
684
  )
685
+ message.metadata.created_at = taskres.task.created_at
686
+ return message
687
+
688
+
689
+ # === User configs ===
690
+
691
+
692
def user_config_to_proto(user_config: typing.UserConfig) -> Any:
    """Convert a `UserConfig` mapping into its ProtoBuf representation.

    Keys are copied unchanged; every value is converted individually with
    `user_config_value_to_proto`, preserving the mapping's iteration order.
    """
    return {
        name: user_config_value_to_proto(entry)
        for name, entry in user_config.items()
    }
698
+
699
+
700
def user_config_from_proto(proto: Any) -> typing.UserConfig:
    """Rebuild a `UserConfig` mapping from its ProtoBuf representation.

    `proto` is iterated via `.items()`; keys are copied unchanged and each
    value is converted back with `user_config_value_from_proto`.
    """
    return {
        name: user_config_value_from_proto(scalar)
        for name, scalar in proto.items()
    }
706
+
707
+
708
def user_config_value_to_proto(user_config_value: typing.UserConfigValue) -> Scalar:
    """Serialize `UserConfigValue` to ProtoBuf.

    The value is wrapped in a `Scalar` message whose populated field is
    selected from the Python type of the value (`bool`, `double`, `sint64`
    or `string`).

    Raises
    ------
    ValueError
        If the value is not a `bool`, `float`, `int`, or `str`.
    """
    # `bool` must be tested before `int`: `bool` is a subclass of `int`, so
    # checking `int` first would store `True`/`False` as `sint64`.
    if isinstance(user_config_value, bool):
        return Scalar(bool=user_config_value)

    if isinstance(user_config_value, float):
        return Scalar(double=user_config_value)

    if isinstance(user_config_value, int):
        return Scalar(sint64=user_config_value)

    if isinstance(user_config_value, str):
        return Scalar(string=user_config_value)

    # List the accepted types literally. Interpolating a set literal of types
    # would render `<class ...>` reprs in an id-hash-dependent order that
    # changes from run to run.
    raise ValueError(
        "Accepted types: bool, float, int, str "
        f"(but not {type(user_config_value)})"
    )
725
+
726
+
727
def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
    """Extract the Python value stored in a `Scalar` ProtoBuf message.

    The name of the populated member of the `scalar` oneof is looked up and
    the attribute of that name is returned as-is.
    """
    # NOTE(review): `WhichOneof` returns None when no member is set; the
    # `cast` only informs the type checker, so `getattr` would then fail.
    field_name = cast(str, scalar_msg.WhichOneof("scalar"))
    value = getattr(scalar_msg, field_name)
    return cast(typing.UserConfigValue, value)
732
+
733
+
734
+ # === Metadata messages ===
735
+
736
+
737
def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
    """Serialize `Metadata` to ProtoBuf.

    Copies the run id, message id, source/destination node ids, reply-to
    message id, group id, TTL and message type into a `ProtoMetadata`.
    """
    # NOTE(review): `metadata.created_at` is not copied here (the TaskIns /
    # TaskRes converters carry it on `Task` instead) — confirm whether
    # `ProtoMetadata` should transport it too.
    return ProtoMetadata(  # pylint: disable=E1101
        run_id=metadata.run_id,
        message_id=metadata.message_id,
        src_node_id=metadata.src_node_id,
        dst_node_id=metadata.dst_node_id,
        reply_to_message=metadata.reply_to_message,
        group_id=metadata.group_id,
        ttl=metadata.ttl,
        message_type=metadata.message_type,
    )
750
+
751
+
752
def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
    """Deserialize `Metadata` from ProtoBuf.

    Builds a `Metadata` from the same eight fields that `metadata_to_proto`
    writes (run id, message id, node ids, reply-to, group id, TTL, type).
    """
    return Metadata(
        run_id=metadata_proto.run_id,
        message_id=metadata_proto.message_id,
        src_node_id=metadata_proto.src_node_id,
        dst_node_id=metadata_proto.dst_node_id,
        reply_to_message=metadata_proto.reply_to_message,
        group_id=metadata_proto.group_id,
        ttl=metadata_proto.ttl,
        message_type=metadata_proto.message_type,
    )
765
+
766
+
767
+ # === Message messages ===
768
+
769
+
770
def message_to_proto(message: Message) -> ProtoMessage:
    """Serialize `Message` to ProtoBuf.

    Metadata is always serialized. `content` is serialized only when the
    message has content, and `error` only when the message has an error;
    an absent part is left unset on the returned `ProtoMessage`.
    """
    # Guard `content` with `has_content()`. The `Message.content` getter
    # raises when no content is set, e.g. on a reply that carries only an
    # error. This mirrors the `HasField("content")` check in
    # `message_from_proto`.
    proto = ProtoMessage(
        metadata=metadata_to_proto(message.metadata),
        content=(
            recordset_to_proto(message.content) if message.has_content() else None
        ),
        error=error_to_proto(message.error) if message.has_error() else None,
    )
    return proto
778
+
779
+
780
def message_from_proto(message_proto: ProtoMessage) -> Message:
    """Deserialize `Message` from ProtoBuf.

    The optional `content` and `error` fields are only converted when they
    are set on `message_proto`; otherwise `None` is passed for them.
    """
    metadata = metadata_from_proto(message_proto.metadata)

    content = None
    if message_proto.HasField("content"):
        content = recordset_from_proto(message_proto.content)

    error = None
    if message_proto.HasField("error"):
        error = error_from_proto(message_proto.error)

    return Message(metadata=metadata, content=content, error=error)
796
+
797
+
798
+ # === Context messages ===
799
+
800
+
801
def context_to_proto(context: Context) -> ProtoContext:
    """Serialize `Context` to ProtoBuf.

    `node_config` and `run_config` are converted with `user_config_to_proto`
    and `state` with `recordset_to_proto`; `node_id` is copied directly.
    """
    node_id = context.node_id
    node_config = user_config_to_proto(context.node_config)
    state = recordset_to_proto(context.state)
    run_config = user_config_to_proto(context.run_config)
    return ProtoContext(
        node_id=node_id,
        node_config=node_config,
        state=state,
        run_config=run_config,
    )
810
+
811
+
812
def context_from_proto(context_proto: ProtoContext) -> Context:
    """Deserialize `Context` from ProtoBuf.

    Inverse of `context_to_proto`: configs go through
    `user_config_from_proto` and `state` through `recordset_from_proto`.
    """
    node_id = context_proto.node_id
    node_config = user_config_from_proto(context_proto.node_config)
    state = recordset_from_proto(context_proto.state)
    run_config = user_config_from_proto(context_proto.run_config)
    return Context(
        node_id=node_id,
        node_config=node_config,
        state=state,
        run_config=run_config,
    )
821
+
822
+
823
+ # === Run messages ===
824
+
825
+
826
def run_to_proto(run: typing.Run) -> ProtoRun:
    """Serialize `Run` to ProtoBuf.

    Copies run id, FAB id and FAB version, and converts `override_config`
    with `user_config_to_proto`.
    """
    run_id, fab_id, fab_version = run.run_id, run.fab_id, run.fab_version
    override_config = user_config_to_proto(run.override_config)
    # NOTE(review): `fab_hash` is always sent as an empty string and is not
    # read back by `run_from_proto` — confirm this placeholder is intended.
    return ProtoRun(
        run_id=run_id,
        fab_id=fab_id,
        fab_version=fab_version,
        override_config=override_config,
        fab_hash="",
    )
836
+
837
+
838
def run_from_proto(run_proto: ProtoRun) -> typing.Run:
    """Deserialize `Run` from ProtoBuf.

    Reads run id, FAB id and FAB version, and converts `override_config`
    with `user_config_from_proto`.
    """
    run_id, fab_id, fab_version = (
        run_proto.run_id,
        run_proto.fab_id,
        run_proto.fab_version,
    )
    override_config = user_config_from_proto(run_proto.override_config)
    return typing.Run(
        run_id=run_id,
        fab_id=fab_id,
        fab_version=fab_version,
        override_config=override_config,
    )
847
+
848
+
849
+ # === ClientApp status messages ===
850
+
851
+
852
def clientappstatus_to_proto(
    status: typing.ClientAppOutputStatus,
) -> ClientAppOutputStatus:
    """Serialize `ClientAppOutputStatus` to ProtoBuf.

    `DEADLINE_EXCEEDED` and `UNKNOWN_ERROR` map to their ProtoBuf
    counterparts; any other code is serialized as `SUCCESS`. The status
    message is copied unchanged.
    """
    if status.code == typing.ClientAppOutputCode.DEADLINE_EXCEEDED:
        proto_code = ClientAppOutputCode.DEADLINE_EXCEEDED
    elif status.code == typing.ClientAppOutputCode.UNKNOWN_ERROR:
        proto_code = ClientAppOutputCode.UNKNOWN_ERROR
    else:
        proto_code = ClientAppOutputCode.SUCCESS
    return ClientAppOutputStatus(code=proto_code, message=status.message)
862
+
863
+
864
def clientappstatus_from_proto(
    msg: ClientAppOutputStatus,
) -> typing.ClientAppOutputStatus:
    """Deserialize `ClientAppOutputStatus` from ProtoBuf.

    `DEADLINE_EXCEEDED` and `UNKNOWN_ERROR` map to the matching
    `typing.ClientAppOutputCode` members; any other code becomes `SUCCESS`.
    The status message is copied unchanged.
    """
    if msg.code == ClientAppOutputCode.DEADLINE_EXCEEDED:
        status_code = typing.ClientAppOutputCode.DEADLINE_EXCEEDED
    elif msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
        status_code = typing.ClientAppOutputCode.UNKNOWN_ERROR
    else:
        status_code = typing.ClientAppOutputCode.SUCCESS
    return typing.ClientAppOutputStatus(code=status_code, message=msg.message)