flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +94 -59
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/new.py +12 -4
  9. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  10. flwr/cli/new/templates/app/README.md.tpl +5 -0
  11. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  23. flwr/cli/run/run.py +48 -49
  24. flwr/cli/stop.py +2 -2
  25. flwr/cli/utils.py +38 -5
  26. flwr/client/__init__.py +2 -2
  27. flwr/client/client_app.py +1 -1
  28. flwr/client/clientapp/__init__.py +0 -7
  29. flwr/client/grpc_adapter_client/connection.py +15 -8
  30. flwr/client/grpc_rere_client/connection.py +142 -97
  31. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  32. flwr/client/message_handler/message_handler.py +1 -1
  33. flwr/client/mod/comms_mods.py +36 -17
  34. flwr/client/rest_client/connection.py +176 -103
  35. flwr/clientapp/__init__.py +15 -0
  36. flwr/common/__init__.py +2 -2
  37. flwr/common/auth_plugin/__init__.py +2 -0
  38. flwr/common/auth_plugin/auth_plugin.py +29 -3
  39. flwr/common/constant.py +39 -8
  40. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  41. flwr/common/exit/exit_code.py +16 -1
  42. flwr/common/exit_handlers.py +30 -0
  43. flwr/common/grpc.py +12 -1
  44. flwr/common/heartbeat.py +165 -0
  45. flwr/common/inflatable.py +290 -0
  46. flwr/common/inflatable_protobuf_utils.py +141 -0
  47. flwr/common/inflatable_utils.py +508 -0
  48. flwr/common/message.py +110 -242
  49. flwr/common/record/__init__.py +2 -1
  50. flwr/common/record/array.py +402 -0
  51. flwr/common/record/arraychunk.py +59 -0
  52. flwr/common/record/arrayrecord.py +103 -225
  53. flwr/common/record/configrecord.py +59 -4
  54. flwr/common/record/conversion_utils.py +1 -1
  55. flwr/common/record/metricrecord.py +55 -4
  56. flwr/common/record/recorddict.py +69 -1
  57. flwr/common/recorddict_compat.py +2 -2
  58. flwr/common/retry_invoker.py +5 -1
  59. flwr/common/serde.py +59 -211
  60. flwr/common/serde_utils.py +175 -0
  61. flwr/common/typing.py +5 -3
  62. flwr/compat/__init__.py +15 -0
  63. flwr/compat/client/__init__.py +15 -0
  64. flwr/{client → compat/client}/app.py +28 -185
  65. flwr/compat/common/__init__.py +15 -0
  66. flwr/compat/server/__init__.py +15 -0
  67. flwr/compat/server/app.py +174 -0
  68. flwr/compat/simulation/__init__.py +15 -0
  69. flwr/proto/appio_pb2.py +43 -0
  70. flwr/proto/appio_pb2.pyi +151 -0
  71. flwr/proto/appio_pb2_grpc.py +4 -0
  72. flwr/proto/appio_pb2_grpc.pyi +4 -0
  73. flwr/proto/clientappio_pb2.py +12 -19
  74. flwr/proto/clientappio_pb2.pyi +23 -101
  75. flwr/proto/clientappio_pb2_grpc.py +269 -28
  76. flwr/proto/clientappio_pb2_grpc.pyi +114 -20
  77. flwr/proto/fleet_pb2.py +24 -27
  78. flwr/proto/fleet_pb2.pyi +19 -35
  79. flwr/proto/fleet_pb2_grpc.py +117 -13
  80. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  81. flwr/proto/heartbeat_pb2.py +33 -0
  82. flwr/proto/heartbeat_pb2.pyi +66 -0
  83. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  84. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  85. flwr/proto/message_pb2.py +28 -11
  86. flwr/proto/message_pb2.pyi +125 -0
  87. flwr/proto/recorddict_pb2.py +16 -28
  88. flwr/proto/recorddict_pb2.pyi +46 -64
  89. flwr/proto/run_pb2.py +24 -32
  90. flwr/proto/run_pb2.pyi +4 -52
  91. flwr/proto/serverappio_pb2.py +9 -23
  92. flwr/proto/serverappio_pb2.pyi +0 -110
  93. flwr/proto/serverappio_pb2_grpc.py +177 -72
  94. flwr/proto/serverappio_pb2_grpc.pyi +75 -33
  95. flwr/proto/simulationio_pb2.py +12 -11
  96. flwr/proto/simulationio_pb2_grpc.py +35 -0
  97. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  98. flwr/server/__init__.py +1 -1
  99. flwr/server/app.py +69 -187
  100. flwr/server/compat/app_utils.py +50 -28
  101. flwr/server/fleet_event_log_interceptor.py +6 -2
  102. flwr/server/grid/grpc_grid.py +148 -41
  103. flwr/server/grid/inmemory_grid.py +5 -4
  104. flwr/server/serverapp/app.py +45 -17
  105. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
  106. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  107. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  108. flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
  109. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
  110. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  111. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  112. flwr/server/superlink/linkstate/linkstate.py +53 -20
  113. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  114. flwr/server/superlink/linkstate/utils.py +33 -29
  115. flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
  116. flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
  117. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  118. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  119. flwr/server/superlink/utils.py +9 -2
  120. flwr/server/utils/validator.py +2 -2
  121. flwr/serverapp/__init__.py +15 -0
  122. flwr/simulation/app.py +25 -0
  123. flwr/simulation/run_simulation.py +17 -0
  124. flwr/supercore/__init__.py +15 -0
  125. flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
  126. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  127. flwr/supercore/grpc_health/__init__.py +22 -0
  128. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  129. flwr/supercore/license_plugin/__init__.py +22 -0
  130. flwr/supercore/license_plugin/license_plugin.py +26 -0
  131. flwr/supercore/object_store/__init__.py +24 -0
  132. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  133. flwr/supercore/object_store/object_store.py +170 -0
  134. flwr/supercore/object_store/object_store_factory.py +44 -0
  135. flwr/supercore/object_store/utils.py +43 -0
  136. flwr/supercore/scheduler/__init__.py +22 -0
  137. flwr/supercore/scheduler/plugin.py +71 -0
  138. flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
  139. flwr/superexec/deployment.py +7 -4
  140. flwr/superexec/exec_event_log_interceptor.py +8 -4
  141. flwr/superexec/exec_grpc.py +25 -5
  142. flwr/superexec/exec_license_interceptor.py +82 -0
  143. flwr/superexec/exec_servicer.py +135 -24
  144. flwr/superexec/exec_user_auth_interceptor.py +45 -8
  145. flwr/superexec/executor.py +5 -1
  146. flwr/superexec/simulation.py +8 -3
  147. flwr/superlink/__init__.py +15 -0
  148. flwr/{client/supernode → supernode}/__init__.py +0 -7
  149. flwr/supernode/cli/__init__.py +24 -0
  150. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
  151. flwr/supernode/cli/flwr_clientapp.py +88 -0
  152. flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
  153. flwr/supernode/nodestate/nodestate.py +227 -0
  154. flwr/supernode/runtime/__init__.py +15 -0
  155. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
  156. flwr/supernode/scheduler/__init__.py +22 -0
  157. flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
  158. flwr/supernode/servicer/__init__.py +15 -0
  159. flwr/supernode/servicer/clientappio/__init__.py +22 -0
  160. flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
  161. flwr/supernode/start_client_internal.py +589 -0
  162. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
  163. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
  164. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
  165. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
  166. flwr/client/clientapp/clientappio_servicer.py +0 -244
  167. flwr/client/heartbeat.py +0 -74
  168. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  169. /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  170. /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  171. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  172. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  173. /flwr/{client → supernode}/nodestate/__init__.py +0 -0
  174. /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
@@ -38,8 +38,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
38
38
  CreateNodeResponse,
39
39
  DeleteNodeRequest,
40
40
  DeleteNodeResponse,
41
- PingRequest,
42
- PingResponse,
43
41
  PullMessagesRequest,
44
42
  PullMessagesResponse,
45
43
  PushMessagesRequest,
@@ -47,6 +45,18 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
47
45
  )
48
46
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
49
47
  from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
48
+ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
49
+ SendNodeHeartbeatRequest,
50
+ SendNodeHeartbeatResponse,
51
+ )
52
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
53
+ ConfirmMessageReceivedRequest,
54
+ ConfirmMessageReceivedResponse,
55
+ PullObjectRequest,
56
+ PullObjectResponse,
57
+ PushObjectRequest,
58
+ PushObjectResponse,
59
+ )
50
60
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
51
61
 
52
62
  T = TypeVar("T", bound=GrpcMessage)
@@ -120,11 +130,11 @@ class GrpcAdapter:
120
130
  """."""
121
131
  return self._send_and_receive(request, DeleteNodeResponse, **kwargs)
122
132
 
123
- def Ping( # pylint: disable=C0103
124
- self, request: PingRequest, **kwargs: Any
125
- ) -> PingResponse:
133
+ def SendNodeHeartbeat( # pylint: disable=C0103
134
+ self, request: SendNodeHeartbeatRequest, **kwargs: Any
135
+ ) -> SendNodeHeartbeatResponse:
126
136
  """."""
127
- return self._send_and_receive(request, PingResponse, **kwargs)
137
+ return self._send_and_receive(request, SendNodeHeartbeatResponse, **kwargs)
128
138
 
129
139
  def PullMessages( # pylint: disable=C0103
130
140
  self, request: PullMessagesRequest, **kwargs: Any
@@ -149,3 +159,21 @@ class GrpcAdapter:
149
159
  ) -> GetFabResponse:
150
160
  """."""
151
161
  return self._send_and_receive(request, GetFabResponse, **kwargs)
162
+
163
+ def PushObject( # pylint: disable=C0103
164
+ self, request: PushObjectRequest, **kwargs: Any
165
+ ) -> PushObjectResponse:
166
+ """."""
167
+ return self._send_and_receive(request, PushObjectResponse, **kwargs)
168
+
169
+ def PullObject( # pylint: disable=C0103
170
+ self, request: PullObjectRequest, **kwargs: Any
171
+ ) -> PullObjectResponse:
172
+ """."""
173
+ return self._send_and_receive(request, PullObjectResponse, **kwargs)
174
+
175
+ def ConfirmMessageReceived( # pylint: disable=C0103
176
+ self, request: ConfirmMessageReceivedRequest, **kwargs: Any
177
+ ) -> ConfirmMessageReceivedResponse:
178
+ """."""
179
+ return self._send_and_receive(request, ConfirmMessageReceivedResponse, **kwargs)
@@ -164,7 +164,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
164
164
  in_meta = in_message_metadata
165
165
  if ( # pylint: disable-next=too-many-boolean-expressions
166
166
  out_meta.run_id == in_meta.run_id
167
- and out_meta.message_id == "" # This will be generated by the server
167
+ and out_meta.message_id == out_message.object_id # Should match the object id
168
168
  and out_meta.src_node_id == in_meta.dst_node_id
169
169
  and out_meta.dst_node_id == in_meta.src_node_id
170
170
  and out_meta.reply_to_message_id == in_meta.message_id
@@ -32,14 +32,17 @@ def message_size_mod(
32
32
 
33
33
  This mod logs the size in bytes of the message being transmited.
34
34
  """
35
- message_size_in_bytes = 0
35
+ # Log the size of the incoming message in bytes
36
+ total_bytes = sum(record.count_bytes() for record in msg.content.values())
37
+ log(INFO, "Incoming message size: %i bytes", total_bytes)
36
38
 
37
- for record in msg.content.values():
38
- message_size_in_bytes += record.count_bytes()
39
+ # Call the next layer
40
+ msg = call_next(msg, ctxt)
39
41
 
40
- log(INFO, "Message size: %i bytes", message_size_in_bytes)
41
-
42
- return call_next(msg, ctxt)
42
+ # Log the size of the outgoing message in bytes
43
+ total_bytes = sum(record.count_bytes() for record in msg.content.values())
44
+ log(INFO, "Outgoing message size: %i bytes", total_bytes)
45
+ return msg
43
46
 
44
47
 
45
48
  def arrays_size_mod(
@@ -50,25 +53,41 @@ def arrays_size_mod(
50
53
  This mod logs the number of array elements transmitted in ``ArrayRecord`` objects
51
54
  of the message as well as their sizes in bytes.
52
55
  """
53
- model_size_stats = {}
54
- arrays_size_in_bytes = 0
56
+ # Log the ArrayRecord size statistics and the total size in the incoming message
57
+ array_record_size_stats = _get_array_record_size_stats(msg)
58
+ total_bytes = sum(stat["bytes"] for stat in array_record_size_stats.values())
59
+ if array_record_size_stats:
60
+ log(INFO, "Incoming `ArrayRecord` size statistics:")
61
+ log(INFO, array_record_size_stats)
62
+ log(INFO, "Total array elements received: %i bytes", total_bytes)
63
+
64
+ msg = call_next(msg, ctxt)
65
+
66
+ # Log the ArrayRecord size statistics and the total size in the outgoing message
67
+ array_record_size_stats = _get_array_record_size_stats(msg)
68
+ total_bytes = sum(stat["bytes"] for stat in array_record_size_stats.values())
69
+ if array_record_size_stats:
70
+ log(INFO, "Outgoing `ArrayRecord` size statistics:")
71
+ log(INFO, array_record_size_stats)
72
+ log(INFO, "Total array elements sent: %i bytes", total_bytes)
73
+ return msg
74
+
75
+
76
+ def _get_array_record_size_stats(
77
+ msg: Message,
78
+ ) -> dict[str, dict[str, int]]:
79
+ """Get `ArrayRecord` size statistics from the message."""
80
+ array_record_size_stats = {}
55
81
  for record_name, arr_record in msg.content.array_records.items():
56
82
  arr_record_bytes = arr_record.count_bytes()
57
- arrays_size_in_bytes += arr_record_bytes
58
83
  element_count = 0
59
84
  for array in arr_record.values():
60
85
  element_count += (
61
86
  int(np.prod(array.shape)) if array.shape else array.numpy().size
62
87
  )
63
88
 
64
- model_size_stats[f"{record_name}"] = {
89
+ array_record_size_stats[record_name] = {
65
90
  "elements": element_count,
66
91
  "bytes": arr_record_bytes,
67
92
  }
68
-
69
- if model_size_stats:
70
- log(INFO, model_size_stats)
71
-
72
- log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
73
-
74
- return call_next(msg, ctxt)
93
+ return array_record_size_stats
@@ -15,30 +15,26 @@
15
15
  """Contextmanager for a REST request-response channel to the Flower server."""
16
16
 
17
17
 
18
- import random
19
- import threading
20
18
  from collections.abc import Iterator
21
19
  from contextlib import contextmanager
22
- from copy import copy
23
- from logging import ERROR, INFO, WARN
20
+ from logging import ERROR, WARN
24
21
  from typing import Callable, Optional, TypeVar, Union
25
22
 
26
23
  from cryptography.hazmat.primitives.asymmetric import ec
27
24
  from google.protobuf.message import Message as GrpcMessage
28
25
  from requests.exceptions import ConnectionError as RequestsConnectionError
29
26
 
30
- from flwr.client.heartbeat import start_ping_loop
31
- from flwr.client.message_handler.message_handler import validate_out_message
32
27
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
33
- from flwr.common.constant import (
34
- PING_BASE_MULTIPLIER,
35
- PING_CALL_TIMEOUT,
36
- PING_DEFAULT_INTERVAL,
37
- PING_RANDOM_RANGE,
38
- )
28
+ from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
39
29
  from flwr.common.exit import ExitCode, flwr_exit
30
+ from flwr.common.heartbeat import HeartbeatSender
31
+ from flwr.common.inflatable_protobuf_utils import (
32
+ make_confirm_message_received_fn_protobuf,
33
+ make_pull_object_fn_protobuf,
34
+ make_push_object_fn_protobuf,
35
+ )
40
36
  from flwr.common.logger import log
41
- from flwr.common.message import Message, Metadata
37
+ from flwr.common.message import Message, remove_content_from_message
42
38
  from flwr.common.retry_invoker import RetryInvoker
43
39
  from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
44
40
  from flwr.common.typing import Fab, Run
@@ -48,13 +44,24 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
48
44
  CreateNodeResponse,
49
45
  DeleteNodeRequest,
50
46
  DeleteNodeResponse,
51
- PingRequest,
52
- PingResponse,
53
47
  PullMessagesRequest,
54
48
  PullMessagesResponse,
55
49
  PushMessagesRequest,
56
50
  PushMessagesResponse,
57
51
  )
52
+ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
53
+ SendNodeHeartbeatRequest,
54
+ SendNodeHeartbeatResponse,
55
+ )
56
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
57
+ ConfirmMessageReceivedRequest,
58
+ ConfirmMessageReceivedResponse,
59
+ ObjectTree,
60
+ PullObjectRequest,
61
+ PullObjectResponse,
62
+ PushObjectRequest,
63
+ PushObjectResponse,
64
+ )
58
65
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
59
66
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
60
67
 
@@ -68,9 +75,12 @@ PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
68
75
  PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
69
76
  PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
70
77
  PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
71
- PATH_PING: str = "api/v0/fleet/ping"
78
+ PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
79
+ PATH_PUSH_OBJECT: str = "/api/v0/fleet/push-object"
80
+ PATH_SEND_NODE_HEARTBEAT: str = "api/v0/fleet/send-node-heartbeat"
72
81
  PATH_GET_RUN: str = "/api/v0/fleet/get-run"
73
82
  PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
83
+ PATH_CONFIRM_MESSAGE_RECEIVED: str = "/api/v0/fleet/confirm-message-received"
74
84
 
75
85
  T = TypeVar("T", bound=GrpcMessage)
76
86
 
@@ -89,12 +99,15 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
89
99
  ] = None,
90
100
  ) -> Iterator[
91
101
  tuple[
92
- Callable[[], Optional[Message]],
93
- Callable[[Message], None],
94
- Optional[Callable[[], Optional[int]]],
95
- Optional[Callable[[], None]],
96
- Optional[Callable[[int], Run]],
97
- Optional[Callable[[str, int], Fab]],
102
+ Callable[[], Optional[tuple[Message, ObjectTree]]],
103
+ Callable[[Message, ObjectTree], set[str]],
104
+ Callable[[], Optional[int]],
105
+ Callable[[], None],
106
+ Callable[[int], Run],
107
+ Callable[[str, int], Fab],
108
+ Callable[[int, str], bytes],
109
+ Callable[[int, str, bytes], None],
110
+ Callable[[int, str], None],
98
111
  ]
99
112
  ]:
100
113
  """Primitives for request/response-based interaction with a server.
@@ -130,6 +143,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
130
143
  create_node : Optional[Callable]
131
144
  delete_node : Optional[Callable]
132
145
  get_run : Optional[Callable]
146
+ pull_object : Callable[[str], bytes]
147
+ push_object : Callable[[str, bytes], None]
148
+ confirm_message_received : Callable[[str], None]
133
149
  """
134
150
  log(
135
151
  WARN,
@@ -158,13 +174,10 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
158
174
  log(ERROR, "Client authentication is not supported for this transport type.")
159
175
 
160
176
  # Shared variables for inner functions
161
- metadata: Optional[Metadata] = None
162
177
  node: Optional[Node] = None
163
- ping_thread: Optional[threading.Thread] = None
164
- ping_stop_event = threading.Event()
165
178
 
166
179
  ###########################################################################
167
- # ping/create_node/delete_node/receive/send/get_run functions
180
+ # heartbeat/create_node/delete_node/receive/send/get_run functions
168
181
  ###########################################################################
169
182
 
170
183
  def _request(
@@ -214,57 +227,89 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
214
227
  grpc_res.ParseFromString(res.content)
215
228
  return grpc_res
216
229
 
217
- def ping() -> None:
230
+ def _pull_object_protobuf(request: PullObjectRequest) -> PullObjectResponse:
231
+ res = _request(
232
+ req=request,
233
+ res_type=PullObjectResponse,
234
+ api_path=PATH_PULL_OBJECT,
235
+ )
236
+ if res is None:
237
+ raise ValueError(f"{PullObjectResponse.__name__} is None.")
238
+ return res
239
+
240
+ def _push_object_protobuf(request: PushObjectRequest) -> PushObjectResponse:
241
+ res = _request(
242
+ req=request,
243
+ res_type=PushObjectResponse,
244
+ api_path=PATH_PUSH_OBJECT,
245
+ )
246
+ if res is None:
247
+ raise ValueError(f"{PushObjectResponse.__name__} is None.")
248
+ return res
249
+
250
+ def _confirm_message_received_protobuf(
251
+ request: ConfirmMessageReceivedRequest,
252
+ ) -> ConfirmMessageReceivedResponse:
253
+ res = _request(
254
+ req=request,
255
+ res_type=ConfirmMessageReceivedResponse,
256
+ api_path=PATH_CONFIRM_MESSAGE_RECEIVED,
257
+ )
258
+ if res is None:
259
+ raise ValueError(f"{ConfirmMessageReceivedResponse.__name__} is None.")
260
+ return res
261
+
262
+ def send_node_heartbeat() -> bool:
218
263
  # Get Node
219
264
  if node is None:
220
265
  log(ERROR, "Node instance missing")
221
- return
266
+ return False
222
267
 
223
- # Construct the ping request
224
- req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)
268
+ # Construct the heartbeat request
269
+ req = SendNodeHeartbeatRequest(
270
+ node=node, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
271
+ )
225
272
 
226
273
  # Send the request
227
- res = _request(req, PingResponse, PATH_PING, retry=False)
274
+ res = _request(
275
+ req, SendNodeHeartbeatResponse, PATH_SEND_NODE_HEARTBEAT, retry=False
276
+ )
228
277
  if res is None:
229
- return
278
+ return False
230
279
 
231
280
  # Check if success
232
281
  if not res.success:
233
- raise RuntimeError("Ping failed unexpectedly.")
282
+ raise RuntimeError(
283
+ "Heartbeat failed unexpectedly. The SuperLink does not "
284
+ "recognize this SuperNode."
285
+ )
286
+ return True
234
287
 
235
- # Wait
236
- rd = random.uniform(*PING_RANDOM_RANGE)
237
- next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
238
- next_interval *= PING_BASE_MULTIPLIER + rd
239
- if not ping_stop_event.is_set():
240
- ping_stop_event.wait(next_interval)
288
+ heartbeat_sender = HeartbeatSender(send_node_heartbeat)
241
289
 
242
290
  def create_node() -> Optional[int]:
243
291
  """Set create_node."""
244
- req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
292
+ req = CreateNodeRequest(heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL)
245
293
 
246
294
  # Send the request
247
295
  res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
248
296
  if res is None:
249
297
  return None
250
298
 
251
- # Remember the node and the ping-loop thread
252
- nonlocal node, ping_thread
299
+ # Remember the node and start the heartbeat sender
300
+ nonlocal node
253
301
  node = res.node
254
- ping_thread = start_ping_loop(ping, ping_stop_event)
302
+ heartbeat_sender.start()
255
303
  return node.node_id
256
304
 
257
305
  def delete_node() -> None:
258
306
  """Set delete_node."""
259
307
  nonlocal node
260
308
  if node is None:
261
- log(ERROR, "Node instance missing")
262
- return
309
+ raise RuntimeError("Node instance missing")
263
310
 
264
- # Stop the ping-loop thread
265
- ping_stop_event.set()
266
- if ping_thread is not None:
267
- ping_thread.join()
311
+ # Stop the heartbeat sender
312
+ heartbeat_sender.stop()
268
313
 
269
314
  # Send DeleteNode request
270
315
  req = DeleteNodeRequest(node=node)
@@ -277,75 +322,54 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
277
322
  # Cleanup
278
323
  node = None
279
324
 
280
- def receive() -> Optional[Message]:
281
- """Receive next Message from server."""
325
+ def receive() -> Optional[tuple[Message, ObjectTree]]:
326
+ """Pull a message with its ObjectTree from SuperLink."""
282
327
  # Get Node
283
328
  if node is None:
284
- log(ERROR, "Node instance missing")
285
- return None
329
+ raise RuntimeError("Node instance missing")
286
330
 
287
- # Request instructions (message) from server
331
+ # Try to pull a message with its object tree from SuperLink
288
332
  req = PullMessagesRequest(node=node)
289
-
290
- # Send the request
291
333
  res = _request(req, PullMessagesResponse, PATH_PULL_MESSAGES)
292
334
  if res is None:
335
+ raise ValueError("PushMessagesResponse is None.")
336
+
337
+ # If no messages are available, return None
338
+ if len(res.messages_list) == 0:
293
339
  return None
294
340
 
295
- # Get the current Messages
296
- message_proto = None if len(res.messages_list) == 0 else res.messages_list[0]
297
-
298
- # Discard the current message if not valid
299
- if message_proto is not None and not (
300
- message_proto.metadata.dst_node_id == node.node_id
301
- ):
302
- message_proto = None
303
-
304
- # Return the Message if available
305
- nonlocal metadata
306
- message = None
307
- if message_proto is not None:
308
- message = message_from_proto(message_proto)
309
- metadata = copy(message.metadata)
310
- log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
311
- return message
312
-
313
- def send(message: Message) -> None:
314
- """Send Message result back to server."""
315
- # Get Node
316
- if node is None:
317
- log(ERROR, "Node instance missing")
318
- return
341
+ # Get the current Message and its object tree
342
+ message_proto = res.messages_list[0]
343
+ object_tree = res.message_object_trees[0]
319
344
 
320
- # Get incoming message
321
- nonlocal metadata
322
- if metadata is None:
323
- log(ERROR, "No current message")
324
- return
345
+ # Construct the Message
346
+ in_message = message_from_proto(message_proto)
325
347
 
326
- # Validate out message
327
- if not validate_out_message(message, metadata):
328
- log(ERROR, "Invalid out message")
329
- return
330
- metadata = None
348
+ # Return the Message and its object tree
349
+ return in_message, object_tree
331
350
 
332
- # Serialize ProtoBuf to bytes
333
- message_proto = message_to_proto(message=message)
351
+ def send(message: Message, object_tree: ObjectTree) -> set[str]:
352
+ """Send the message with its ObjectTree to SuperLink."""
353
+ # Get Node
354
+ if node is None:
355
+ raise RuntimeError("Node instance missing")
334
356
 
335
- # Serialize ProtoBuf to bytes
336
- req = PushMessagesRequest(node=node, messages_list=[message_proto])
357
+ # Remove the content from the message if it has
358
+ if message.has_content():
359
+ message = remove_content_from_message(message)
337
360
 
338
- # Send the request
361
+ # Send the message with its ObjectTree to SuperLink
362
+ req = PushMessagesRequest(
363
+ node=node,
364
+ messages_list=[message_to_proto(message)],
365
+ message_object_trees=[object_tree],
366
+ )
339
367
  res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
340
368
  if res is None:
341
- return
369
+ raise ValueError("PushMessagesResponse is None.")
342
370
 
343
- log(
344
- INFO,
345
- "[Node] POST /%s: success, created result %s",
346
- PATH_PUSH_MESSAGES,
347
- res.results, # pylint: disable=no-member
348
- )
371
+ # Get and return the object IDs to push
372
+ return set(res.objects_to_push)
349
373
 
350
374
  def get_run(run_id: int) -> Run:
351
375
  # Construct the request
@@ -372,9 +396,58 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
372
396
  res.fab.content,
373
397
  )
374
398
 
399
+ def pull_object(run_id: int, object_id: str) -> bytes:
400
+ """Pull the object from the SuperLink."""
401
+ # Check Node
402
+ if node is None:
403
+ raise RuntimeError("Node instance missing")
404
+
405
+ fn = make_pull_object_fn_protobuf(
406
+ pull_object_protobuf=_pull_object_protobuf,
407
+ node=node,
408
+ run_id=run_id,
409
+ )
410
+ return fn(object_id)
411
+
412
+ def push_object(run_id: int, object_id: str, contents: bytes) -> None:
413
+ """Push the object to the SuperLink."""
414
+ # Check Node
415
+ if node is None:
416
+ raise RuntimeError("Node instance missing")
417
+
418
+ fn = make_push_object_fn_protobuf(
419
+ push_object_protobuf=_push_object_protobuf,
420
+ node=node,
421
+ run_id=run_id,
422
+ )
423
+ fn(object_id, contents)
424
+
425
+ def confirm_message_received(run_id: int, object_id: str) -> None:
426
+ """Confirm that the message has been received."""
427
+ # Check Node
428
+ if node is None:
429
+ raise RuntimeError("Node instance missing")
430
+
431
+ fn = make_confirm_message_received_fn_protobuf(
432
+ confirm_message_received_protobuf=_confirm_message_received_protobuf,
433
+ node=node,
434
+ run_id=run_id,
435
+ )
436
+ fn(object_id)
437
+
375
438
  try:
376
439
  # Yield methods
377
- yield (receive, send, create_node, delete_node, get_run, get_fab)
440
+ yield (
441
+ receive,
442
+ send,
443
+ create_node,
444
+ delete_node,
445
+ get_run,
446
+ get_fab,
447
+ pull_object,
448
+ push_object,
449
+ confirm_message_received,
450
+ )
378
451
  except Exception as exc: # pylint: disable=broad-except
379
452
  log(ERROR, exc)
380
453
  # Cleanup
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Public Flower ClientApp APIs."""
flwr/common/__init__.py CHANGED
@@ -15,6 +15,8 @@
15
15
  """Common components shared between server and client."""
16
16
 
17
17
 
18
+ from ..app.error import Error as Error
19
+ from ..app.metadata import Metadata as Metadata
18
20
  from .constant import MessageType as MessageType
19
21
  from .constant import MessageTypeLegacy as MessageTypeLegacy
20
22
  from .context import Context as Context
@@ -23,9 +25,7 @@ from .grpc import GRPC_MAX_MESSAGE_LENGTH
23
25
  from .logger import configure as configure
24
26
  from .logger import log as log
25
27
  from .message import DEFAULT_TTL
26
- from .message import Error as Error
27
28
  from .message import Message as Message
28
- from .message import Metadata as Metadata
29
29
  from .parameter import bytes_to_ndarray as bytes_to_ndarray
30
30
  from .parameter import ndarray_to_bytes as ndarray_to_bytes
31
31
  from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
@@ -17,8 +17,10 @@
17
17
 
18
18
  from .auth_plugin import CliAuthPlugin as CliAuthPlugin
19
19
  from .auth_plugin import ExecAuthPlugin as ExecAuthPlugin
20
+ from .auth_plugin import ExecAuthzPlugin as ExecAuthzPlugin
20
21
 
21
22
  __all__ = [
22
23
  "CliAuthPlugin",
23
24
  "ExecAuthPlugin",
25
+ "ExecAuthzPlugin",
24
26
  ]
@@ -20,7 +20,7 @@ from collections.abc import Sequence
20
20
  from pathlib import Path
21
21
  from typing import Optional, Union
22
22
 
23
- from flwr.common.typing import UserInfo
23
+ from flwr.common.typing import AccountInfo
24
24
  from flwr.proto.exec_pb2_grpc import ExecStub
25
25
 
26
26
  from ..typing import UserAuthCredentials, UserAuthLoginDetails
@@ -33,6 +33,9 @@ class ExecAuthPlugin(ABC):
33
33
  ----------
34
34
  user_auth_config_path : Path
35
35
  Path to the YAML file containing the authentication configuration.
36
+ verify_tls_cert : bool
37
+ Boolean indicating whether to verify the TLS certificate
38
+ when making requests to the server.
36
39
  """
37
40
 
38
41
  @abstractmethod
@@ -50,7 +53,7 @@ class ExecAuthPlugin(ABC):
50
53
  @abstractmethod
51
54
  def validate_tokens_in_metadata(
52
55
  self, metadata: Sequence[tuple[str, Union[str, bytes]]]
53
- ) -> tuple[bool, Optional[UserInfo]]:
56
+ ) -> tuple[bool, Optional[AccountInfo]]:
54
57
  """Validate authentication tokens in the provided metadata."""
55
58
 
56
59
  @abstractmethod
@@ -60,10 +63,33 @@ class ExecAuthPlugin(ABC):
60
63
  @abstractmethod
61
64
  def refresh_tokens(
62
65
  self, metadata: Sequence[tuple[str, Union[str, bytes]]]
63
- ) -> Optional[Sequence[tuple[str, Union[str, bytes]]]]:
66
+ ) -> tuple[
67
+ Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[AccountInfo]
68
+ ]:
64
69
  """Refresh authentication tokens in the provided metadata."""
65
70
 
66
71
 
72
+ class ExecAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
73
+ """Abstract Flower Authorization Plugin class for ExecServicer.
74
+
75
+ Parameters
76
+ ----------
77
+ user_auth_config_path : Path
78
+ Path to the YAML file containing the authorization configuration.
79
+ verify_tls_cert : bool
80
+ Boolean indicating whether to verify the TLS certificate
81
+ when making requests to the server.
82
+ """
83
+
84
+ @abstractmethod
85
+ def __init__(self, user_auth_config_path: Path, verify_tls_cert: bool):
86
+ """Abstract constructor."""
87
+
88
+ @abstractmethod
89
+ def verify_user_authorization(self, account_info: AccountInfo) -> bool:
90
+ """Verify user authorization request."""
91
+
92
+
67
93
  class CliAuthPlugin(ABC):
68
94
  """Abstract Flower Auth Plugin class for CLI.
69
95