flwr-1.19.0-py3-none-any.whl → flwr-1.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  5. flwr/cli/build.py +15 -5
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +23 -4
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  14. flwr/cli/new/templates/app/README.md.tpl +5 -0
  15. flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
  16. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
  17. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
  18. flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
  19. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
  20. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  21. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  22. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  23. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  24. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  25. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  26. flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
  27. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  28. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  29. flwr/cli/run/run.py +53 -50
  30. flwr/cli/stop.py +7 -4
  31. flwr/cli/utils.py +29 -11
  32. flwr/client/grpc_adapter_client/connection.py +11 -4
  33. flwr/client/grpc_rere_client/connection.py +93 -129
  34. flwr/client/rest_client/connection.py +134 -164
  35. flwr/clientapp/__init__.py +10 -0
  36. flwr/clientapp/mod/__init__.py +26 -0
  37. flwr/clientapp/mod/centraldp_mods.py +132 -0
  38. flwr/common/args.py +20 -6
  39. flwr/common/auth_plugin/__init__.py +4 -4
  40. flwr/common/auth_plugin/auth_plugin.py +7 -7
  41. flwr/common/constant.py +26 -5
  42. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  43. flwr/common/exit/__init__.py +4 -0
  44. flwr/common/exit/exit.py +8 -1
  45. flwr/common/exit/exit_code.py +42 -8
  46. flwr/common/exit/exit_handler.py +62 -0
  47. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  48. flwr/common/grpc.py +1 -1
  49. flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
  50. flwr/common/inflatable_utils.py +191 -24
  51. flwr/common/logger.py +1 -1
  52. flwr/common/record/array.py +101 -22
  53. flwr/common/record/arraychunk.py +59 -0
  54. flwr/common/retry_invoker.py +30 -11
  55. flwr/common/serde.py +0 -28
  56. flwr/common/telemetry.py +4 -0
  57. flwr/compat/client/app.py +14 -31
  58. flwr/compat/server/app.py +2 -2
  59. flwr/proto/appio_pb2.py +51 -0
  60. flwr/proto/appio_pb2.pyi +195 -0
  61. flwr/proto/appio_pb2_grpc.py +4 -0
  62. flwr/proto/appio_pb2_grpc.pyi +4 -0
  63. flwr/proto/clientappio_pb2.py +4 -19
  64. flwr/proto/clientappio_pb2.pyi +0 -125
  65. flwr/proto/clientappio_pb2_grpc.py +269 -29
  66. flwr/proto/clientappio_pb2_grpc.pyi +114 -21
  67. flwr/proto/control_pb2.py +62 -0
  68. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
  69. flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
  70. flwr/proto/fleet_pb2.py +12 -20
  71. flwr/proto/fleet_pb2.pyi +6 -36
  72. flwr/proto/serverappio_pb2.py +8 -31
  73. flwr/proto/serverappio_pb2.pyi +0 -152
  74. flwr/proto/serverappio_pb2_grpc.py +107 -38
  75. flwr/proto/serverappio_pb2_grpc.pyi +47 -20
  76. flwr/proto/simulationio_pb2.py +4 -11
  77. flwr/proto/simulationio_pb2.pyi +0 -58
  78. flwr/proto/simulationio_pb2_grpc.py +129 -27
  79. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  80. flwr/server/app.py +130 -153
  81. flwr/server/fleet_event_log_interceptor.py +4 -0
  82. flwr/server/grid/grpc_grid.py +94 -54
  83. flwr/server/grid/inmemory_grid.py +1 -0
  84. flwr/server/serverapp/app.py +165 -144
  85. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
  86. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  87. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  88. flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
  89. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
  90. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  91. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  92. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  93. flwr/server/superlink/linkstate/linkstate.py +2 -1
  94. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  95. flwr/server/superlink/serverappio/serverappio_grpc.py +2 -2
  96. flwr/server/superlink/serverappio/serverappio_servicer.py +95 -48
  97. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  98. flwr/server/superlink/simulation/simulationio_servicer.py +98 -22
  99. flwr/server/superlink/utils.py +0 -35
  100. flwr/serverapp/__init__.py +12 -0
  101. flwr/serverapp/dp_fixed_clipping.py +352 -0
  102. flwr/serverapp/exception.py +38 -0
  103. flwr/serverapp/strategy/__init__.py +38 -0
  104. flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
  105. flwr/serverapp/strategy/fedadagrad.py +162 -0
  106. flwr/serverapp/strategy/fedadam.py +181 -0
  107. flwr/serverapp/strategy/fedavg.py +295 -0
  108. flwr/serverapp/strategy/fedopt.py +218 -0
  109. flwr/serverapp/strategy/fedyogi.py +173 -0
  110. flwr/serverapp/strategy/result.py +105 -0
  111. flwr/serverapp/strategy/strategy.py +285 -0
  112. flwr/serverapp/strategy/strategy_utils.py +251 -0
  113. flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
  114. flwr/simulation/app.py +159 -154
  115. flwr/simulation/run_simulation.py +17 -0
  116. flwr/supercore/app_utils.py +58 -0
  117. flwr/supercore/cli/__init__.py +22 -0
  118. flwr/supercore/cli/flower_superexec.py +141 -0
  119. flwr/supercore/corestate/__init__.py +22 -0
  120. flwr/supercore/corestate/corestate.py +81 -0
  121. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  122. flwr/supercore/grpc_health/__init__.py +25 -0
  123. flwr/supercore/grpc_health/health_server.py +53 -0
  124. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  125. flwr/supercore/license_plugin/__init__.py +22 -0
  126. flwr/supercore/license_plugin/license_plugin.py +26 -0
  127. flwr/supercore/object_store/in_memory_object_store.py +31 -31
  128. flwr/supercore/object_store/object_store.py +20 -42
  129. flwr/supercore/object_store/utils.py +43 -0
  130. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  131. flwr/supercore/superexec/plugin/__init__.py +28 -0
  132. flwr/supercore/superexec/plugin/base_exec_plugin.py +53 -0
  133. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  134. flwr/supercore/superexec/plugin/exec_plugin.py +71 -0
  135. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  136. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  137. flwr/supercore/superexec/run_superexec.py +185 -0
  138. flwr/supercore/utils.py +32 -0
  139. flwr/superlink/servicer/__init__.py +15 -0
  140. flwr/superlink/servicer/control/__init__.py +22 -0
  141. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +9 -5
  142. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +39 -28
  143. flwr/superlink/servicer/control/control_license_interceptor.py +82 -0
  144. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +79 -31
  145. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +18 -10
  146. flwr/supernode/cli/flower_supernode.py +3 -7
  147. flwr/supernode/cli/flwr_clientapp.py +20 -16
  148. flwr/supernode/nodestate/in_memory_nodestate.py +13 -4
  149. flwr/supernode/nodestate/nodestate.py +3 -44
  150. flwr/supernode/runtime/run_clientapp.py +129 -115
  151. flwr/supernode/servicer/clientappio/__init__.py +1 -3
  152. flwr/supernode/servicer/clientappio/clientappio_servicer.py +217 -165
  153. flwr/supernode/start_client_internal.py +205 -148
  154. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/METADATA +5 -3
  155. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/RECORD +161 -117
  156. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
  157. flwr/common/inflatable_rest_utils.py +0 -99
  158. flwr/proto/exec_pb2.py +0 -62
  159. flwr/superexec/app.py +0 -45
  160. flwr/superexec/deployment.py +0 -192
  161. flwr/superexec/executor.py +0 -100
  162. flwr/superexec/simulation.py +0 -130
  163. /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
  164. /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
  165. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  166. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  167. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
@@ -81,12 +81,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
81
81
  metadata sent by the node. Continue RPC call if node is authenticated, else,
82
82
  terminate RPC call by setting context to abort.
83
83
  """
84
- # Filter out non-Fleet service calls
84
+ # Only apply to Fleet service
85
85
  if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
86
- return _unary_unary_rpc_terminator(
87
- "This request should be sent to a different service.",
88
- grpc.StatusCode.FAILED_PRECONDITION,
89
- )
86
+ return continuation(handler_call_details)
90
87
 
91
88
  state = self.state_factory.state()
92
89
  metadata_dict = dict(handler_call_details.invocation_metadata)
@@ -46,7 +46,6 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
46
46
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
47
47
  ConfirmMessageReceivedRequest,
48
48
  ConfirmMessageReceivedResponse,
49
- ObjectIDs,
50
49
  PullObjectRequest,
51
50
  PullObjectResponse,
52
51
  PushObjectRequest,
@@ -58,12 +57,11 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
58
57
  GetRunResponse,
59
58
  Run,
60
59
  )
61
- from flwr.server.superlink.ffs.ffs import Ffs
62
60
  from flwr.server.superlink.linkstate import LinkState
63
61
  from flwr.server.superlink.utils import check_abort
62
+ from flwr.supercore.ffs import Ffs
64
63
  from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
65
-
66
- from ...utils import store_mapping_and_register_objects
64
+ from flwr.supercore.object_store.utils import store_mapping_and_register_objects
67
65
 
68
66
 
69
67
  def create_node(
@@ -113,25 +111,22 @@ def pull_messages(
113
111
 
114
112
  # Convert to Messages
115
113
  msg_proto = []
116
- objects_to_pull: dict[str, ObjectIDs] = {}
114
+ trees = []
117
115
  for msg in message_list:
118
116
  try:
119
- msg_proto.append(message_to_proto(msg))
120
-
117
+ # Retrieve Message object tree from ObjectStore
121
118
  msg_object_id = msg.metadata.message_id
122
- descendants = store.get_message_descendant_ids(msg_object_id)
123
- # Include the object_id of the message itself
124
- objects_to_pull[msg_object_id] = ObjectIDs(
125
- object_ids=descendants + [msg_object_id]
126
- )
119
+ obj_tree = store.get_object_tree(msg_object_id)
120
+
121
+ # Add Message and its object tree to the response
122
+ msg_proto.append(message_to_proto(msg))
123
+ trees.append(obj_tree)
127
124
  except NoObjectInStoreError as e:
128
125
  log(ERROR, e.message)
129
126
  # Delete message ins from state
130
127
  state.delete_messages(message_ins_ids={msg_object_id})
131
128
 
132
- return PullMessagesResponse(
133
- messages_list=msg_proto, objects_to_pull=objects_to_pull
134
- )
129
+ return PullMessagesResponse(messages_list=msg_proto, message_object_trees=trees)
135
130
 
136
131
 
137
132
  def push_messages(
@@ -287,6 +282,5 @@ def confirm_message_received(
287
282
 
288
283
  # Delete the message object
289
284
  store.delete(request.message_object_id)
290
- store.delete_message_descendant_ids(request.message_object_id)
291
285
 
292
286
  return ConfirmMessageReceivedResponse()
@@ -47,10 +47,9 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
47
47
  PushObjectResponse,
48
48
  )
49
49
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
50
- from flwr.server.superlink.ffs.ffs import Ffs
51
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
52
50
  from flwr.server.superlink.fleet.message_handler import message_handler
53
51
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
52
+ from flwr.supercore.ffs import Ffs, FfsFactory
54
53
  from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
55
54
 
56
55
  try:
@@ -161,6 +161,7 @@ class RayBackend(Backend):
161
161
  "Call the backend's `build()` method before processing messages."
162
162
  )
163
163
 
164
+ future = None
164
165
  try:
165
166
  # Submit a task to the pool
166
167
  future = self.pool.submit(
@@ -183,7 +184,8 @@ class RayBackend(Backend):
183
184
  self.__class__.__name__,
184
185
  )
185
186
  # add actor back into pool
186
- self.pool.add_actor_back_to_pool(future)
187
+ if future is not None:
188
+ self.pool.add_actor_back_to_pool(future)
187
189
  raise ex
188
190
 
189
191
  def terminate(self) -> None:
@@ -23,7 +23,6 @@ from concurrent.futures import ThreadPoolExecutor
23
23
  from logging import DEBUG, ERROR, INFO, WARN
24
24
  from pathlib import Path
25
25
  from queue import Empty, Queue
26
- from time import sleep
27
26
  from typing import Callable, Optional
28
27
  from uuid import uuid4
29
28
 
@@ -153,7 +152,7 @@ def add_messages_to_queue(
153
152
  message_ins_list = state.get_message_ins(node_id=node_id, limit=1)
154
153
  for msg in message_ins_list:
155
154
  queue.put(msg)
156
- sleep(0.1)
155
+ f_stop.wait(0.1)
157
156
 
158
157
 
159
158
  def put_message_into_state(
@@ -182,6 +181,7 @@ def run_api(
182
181
  messageins_queue: Queue[Message] = Queue()
183
182
  messageres_queue: Queue[Message] = Queue()
184
183
 
184
+ backend = None
185
185
  try:
186
186
 
187
187
  # Instantiate backend
@@ -236,16 +236,16 @@ def run_api(
236
236
  log(ERROR, traceback.format_exc())
237
237
  log(WARN, "Stopping Simulation Engine.")
238
238
 
239
- # Manually trigger stopping event
240
- f_stop.set()
241
-
242
239
  # Raise exception
243
240
  raise RuntimeError("Simulation Engine crashed.") from ex
244
241
 
245
242
  finally:
243
+ # Manually trigger stopping event
244
+ f_stop.set()
246
245
 
247
246
  # Terminate backend
248
- backend.terminate()
247
+ if backend is not None:
248
+ backend.terminate()
249
249
 
250
250
 
251
251
  # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
@@ -15,6 +15,7 @@
15
15
  """In-memory LinkState implementation."""
16
16
 
17
17
 
18
+ import secrets
18
19
  import threading
19
20
  import time
20
21
  from bisect import bisect_right
@@ -25,6 +26,7 @@ from typing import Optional
25
26
 
26
27
  from flwr.common import Context, Message, log, now
27
28
  from flwr.common.constant import (
29
+ FLWR_APP_TOKEN_LENGTH,
28
30
  HEARTBEAT_MAX_INTERVAL,
29
31
  HEARTBEAT_PATIENCE,
30
32
  MESSAGE_TTL_TOLERANCE,
@@ -80,6 +82,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
80
82
  self.message_res_store: dict[str, Message] = {}
81
83
  self.message_ins_id_to_message_res_id: dict[str, str] = {}
82
84
 
85
+ # Store run ID to token mapping and token to run ID mapping
86
+ self.token_store: dict[int, str] = {}
87
+ self.token_to_run_id: dict[str, int] = {}
88
+ self.lock_token_store = threading.Lock()
89
+
83
90
  # Map flwr_aid to run_ids for O(1) reverse index lookup
84
91
  self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
85
92
 
@@ -678,3 +685,30 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
678
685
  index = bisect_right(run.logs, (after_timestamp, ""))
679
686
  latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
680
687
  return "".join(log for _, log in run.logs[index:]), latest_timestamp
688
+
689
+ def create_token(self, run_id: int) -> Optional[str]:
690
+ """Create a token for the given run ID."""
691
+ token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
692
+ with self.lock_token_store:
693
+ if run_id in self.token_store:
694
+ return None # Token already created for this run ID
695
+ self.token_store[run_id] = token
696
+ self.token_to_run_id[token] = run_id
697
+ return token
698
+
699
+ def verify_token(self, run_id: int, token: str) -> bool:
700
+ """Verify a token for the given run ID."""
701
+ with self.lock_token_store:
702
+ return self.token_store.get(run_id) == token
703
+
704
+ def delete_token(self, run_id: int) -> None:
705
+ """Delete the token for the given run ID."""
706
+ with self.lock_token_store:
707
+ token = self.token_store.pop(run_id, None)
708
+ if token is not None:
709
+ self.token_to_run_id.pop(token, None)
710
+
711
+ def get_run_id_by_token(self, token: str) -> Optional[int]:
712
+ """Get the run ID associated with a given token."""
713
+ with self.lock_token_store:
714
+ return self.token_to_run_id.get(token)
@@ -21,9 +21,10 @@ from typing import Optional
21
21
  from flwr.common import Context, Message
22
22
  from flwr.common.record import ConfigRecord
23
23
  from flwr.common.typing import Run, RunStatus, UserConfig
24
+ from flwr.supercore.corestate import CoreState
24
25
 
25
26
 
26
- class LinkState(abc.ABC): # pylint: disable=R0904
27
+ class LinkState(CoreState): # pylint: disable=R0904
27
28
  """Abstract LinkState."""
28
29
 
29
30
  @abc.abstractmethod
@@ -19,6 +19,7 @@
19
19
 
20
20
  import json
21
21
  import re
22
+ import secrets
22
23
  import sqlite3
23
24
  import time
24
25
  from collections.abc import Sequence
@@ -27,6 +28,7 @@ from typing import Any, Optional, Union, cast
27
28
 
28
29
  from flwr.common import Context, Message, Metadata, log, now
29
30
  from flwr.common.constant import (
31
+ FLWR_APP_TOKEN_LENGTH,
30
32
  HEARTBEAT_MAX_INTERVAL,
31
33
  HEARTBEAT_PATIENCE,
32
34
  MESSAGE_TTL_TOLERANCE,
@@ -163,6 +165,13 @@ CREATE TABLE IF NOT EXISTS message_res(
163
165
  );
164
166
  """
165
167
 
168
+ SQL_CREATE_TABLE_TOKEN_STORE = """
169
+ CREATE TABLE IF NOT EXISTS token_store (
170
+ run_id INTEGER PRIMARY KEY,
171
+ token TEXT UNIQUE NOT NULL
172
+ );
173
+ """
174
+
166
175
  DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
167
176
 
168
177
 
@@ -212,6 +221,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
212
221
  cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
213
222
  cur.execute(SQL_CREATE_TABLE_NODE)
214
223
  cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
224
+ cur.execute(SQL_CREATE_TABLE_TOKEN_STORE)
215
225
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
216
226
  res = cur.execute("SELECT name FROM sqlite_schema;")
217
227
  return res.fetchall()
@@ -1138,6 +1148,41 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1138
1148
 
1139
1149
  return message_ins
1140
1150
 
1151
+ def create_token(self, run_id: int) -> Optional[str]:
1152
+ """Create a token for the given run ID."""
1153
+ token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
1154
+ query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
1155
+ data = {"run_id": convert_uint64_to_sint64(run_id), "token": token}
1156
+ try:
1157
+ self.query(query, data)
1158
+ except sqlite3.IntegrityError:
1159
+ return None # Token already created for this run ID
1160
+ return token
1161
+
1162
+ def verify_token(self, run_id: int, token: str) -> bool:
1163
+ """Verify a token for the given run ID."""
1164
+ query = "SELECT token FROM token_store WHERE run_id = :run_id;"
1165
+ data = {"run_id": convert_uint64_to_sint64(run_id)}
1166
+ rows = self.query(query, data)
1167
+ if not rows:
1168
+ return False
1169
+ return cast(str, rows[0]["token"]) == token
1170
+
1171
+ def delete_token(self, run_id: int) -> None:
1172
+ """Delete the token for the given run ID."""
1173
+ query = "DELETE FROM token_store WHERE run_id = :run_id;"
1174
+ data = {"run_id": convert_uint64_to_sint64(run_id)}
1175
+ self.query(query, data)
1176
+
1177
+ def get_run_id_by_token(self, token: str) -> Optional[int]:
1178
+ """Get the run ID associated with a given token."""
1179
+ query = "SELECT run_id FROM token_store WHERE token = :token;"
1180
+ data = {"token": token}
1181
+ rows = self.query(query, data)
1182
+ if not rows:
1183
+ return None
1184
+ return convert_sint64_to_uint64(rows[0]["run_id"])
1185
+
1141
1186
 
1142
1187
  def dict_factory(
1143
1188
  cursor: sqlite3.Cursor,
@@ -26,8 +26,8 @@ from flwr.common.logger import log
26
26
  from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
27
27
  add_ServerAppIoServicer_to_server,
28
28
  )
29
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
30
29
  from flwr.server.superlink.linkstate import LinkStateFactory
30
+ from flwr.supercore.ffs import FfsFactory
31
31
  from flwr.supercore.object_store import ObjectStoreFactory
32
32
 
33
33
  from .serverappio_servicer import ServerAppIoServicer
@@ -58,7 +58,7 @@ def run_serverappio_api_grpc(
58
58
  certificates=certificates,
59
59
  )
60
60
 
61
- log(INFO, "Flower ECE: Starting ServerAppIo API (gRPC-rere) on %s", address)
61
+ log(INFO, "Flower Deployment Runtime: Starting ServerAppIo API on %s", address)
62
62
  serverappio_grpc_server.start()
63
63
 
64
64
  return serverappio_grpc_server
@@ -42,6 +42,20 @@ from flwr.common.serde import (
42
42
  )
43
43
  from flwr.common.typing import Fab, RunStatus
44
44
  from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
45
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
46
+ ListAppsToLaunchRequest,
47
+ ListAppsToLaunchResponse,
48
+ PullAppInputsRequest,
49
+ PullAppInputsResponse,
50
+ PullAppMessagesRequest,
51
+ PullAppMessagesResponse,
52
+ PushAppMessagesRequest,
53
+ PushAppMessagesResponse,
54
+ PushAppOutputsRequest,
55
+ PushAppOutputsResponse,
56
+ RequestTokenRequest,
57
+ RequestTokenResponse,
58
+ )
45
59
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
46
60
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
47
61
  SendAppHeartbeatRequest,
@@ -54,7 +68,6 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
54
68
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
55
69
  ConfirmMessageReceivedRequest,
56
70
  ConfirmMessageReceivedResponse,
57
- ObjectIDs,
58
71
  PullObjectRequest,
59
72
  PullObjectResponse,
60
73
  PushObjectRequest,
@@ -72,23 +85,13 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
72
85
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
73
86
  GetNodesRequest,
74
87
  GetNodesResponse,
75
- PullResMessagesRequest,
76
- PullResMessagesResponse,
77
- PullServerAppInputsRequest,
78
- PullServerAppInputsResponse,
79
- PushInsMessagesRequest,
80
- PushInsMessagesResponse,
81
- PushServerAppOutputsRequest,
82
- PushServerAppOutputsResponse,
83
88
  )
84
- from flwr.server.superlink.ffs.ffs import Ffs
85
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
86
89
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
87
90
  from flwr.server.superlink.utils import abort_if
88
91
  from flwr.server.utils.validator import validate_message
92
+ from flwr.supercore.ffs import Ffs, FfsFactory
89
93
  from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
90
-
91
- from ..utils import store_mapping_and_register_objects
94
+ from flwr.supercore.object_store.utils import store_mapping_and_register_objects
92
95
 
93
96
 
94
97
  class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
@@ -105,6 +108,42 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
105
108
  self.objectstore_factory = objectstore_factory
106
109
  self.lock = threading.RLock()
107
110
 
111
+ def ListAppsToLaunch(
112
+ self,
113
+ request: ListAppsToLaunchRequest,
114
+ context: grpc.ServicerContext,
115
+ ) -> ListAppsToLaunchResponse:
116
+ """Get run IDs with pending messages."""
117
+ log(DEBUG, "ServerAppIoServicer.ListAppsToLaunch")
118
+
119
+ # Initialize state connection
120
+ state = self.state_factory.state()
121
+
122
+ # Get IDs of runs in pending status
123
+ run_ids = state.get_run_ids(flwr_aid=None)
124
+ pending_run_ids = []
125
+ for run_id, status in state.get_run_status(run_ids).items():
126
+ if status.status == Status.PENDING:
127
+ pending_run_ids.append(run_id)
128
+
129
+ # Return run IDs
130
+ return ListAppsToLaunchResponse(run_ids=pending_run_ids)
131
+
132
+ def RequestToken(
133
+ self, request: RequestTokenRequest, context: grpc.ServicerContext
134
+ ) -> RequestTokenResponse:
135
+ """Request token."""
136
+ log(DEBUG, "ServerAppIoServicer.RequestToken")
137
+
138
+ # Initialize state connection
139
+ state = self.state_factory.state()
140
+
141
+ # Attempt to create a token for the provided run ID
142
+ token = state.create_token(request.run_id)
143
+
144
+ # Return the token
145
+ return RequestTokenResponse(token=token or "")
146
+
108
147
  def GetNodes(
109
148
  self, request: GetNodesRequest, context: grpc.ServicerContext
110
149
  ) -> GetNodesResponse:
@@ -129,8 +168,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
129
168
  return GetNodesResponse(nodes=nodes)
130
169
 
131
170
  def PushMessages(
132
- self, request: PushInsMessagesRequest, context: grpc.ServicerContext
133
- ) -> PushInsMessagesResponse:
171
+ self, request: PushAppMessagesRequest, context: grpc.ServicerContext
172
+ ) -> PushAppMessagesResponse:
134
173
  """Push a set of Messages."""
135
174
  log(DEBUG, "ServerAppIoServicer.PushMessages")
136
175
 
@@ -174,7 +213,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
174
213
  # Store Message object to descendants mapping and preregister objects
175
214
  objects_to_push = store_mapping_and_register_objects(store, request=request)
176
215
 
177
- return PushInsMessagesResponse(
216
+ return PushAppMessagesResponse(
178
217
  message_ids=[
179
218
  str(message_id) if message_id else "" for message_id in message_ids
180
219
  ],
@@ -182,8 +221,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
182
221
  )
183
222
 
184
223
  def PullMessages( # pylint: disable=R0914
185
- self, request: PullResMessagesRequest, context: grpc.ServicerContext
186
- ) -> PullResMessagesResponse:
224
+ self, request: PullAppMessagesRequest, context: grpc.ServicerContext
225
+ ) -> PullAppMessagesResponse:
187
226
  """Pull a set of Messages."""
188
227
  log(DEBUG, "ServerAppIoServicer.PullMessages")
189
228
 
@@ -210,12 +249,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
210
249
  if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
211
250
  with no_object_id_recompute():
212
251
  all_objects = get_all_nested_objects(msg_res)
213
- descendants = list(all_objects.keys())[:-1]
214
- message_obj_id = msg_res.metadata.message_id
215
- # Store mapping
216
- store.set_message_descendant_ids(
217
- msg_object_id=message_obj_id, descendant_ids=descendants
218
- )
219
252
  # Preregister
220
253
  store.preregister(request.run_id, get_object_tree(msg_res))
221
254
  # Store objects
@@ -231,7 +264,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
231
264
 
232
265
  # Convert Messages to proto
233
266
  messages_list = []
234
- objects_to_pull: dict[str, ObjectIDs] = {}
267
+ trees = []
235
268
  while messages_res:
236
269
  msg = messages_res.pop(0)
237
270
 
@@ -242,20 +275,20 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
242
275
  request_name="PullMessages",
243
276
  detail="`message.metadata` has mismatched `run_id`",
244
277
  )
245
- messages_list.append(message_to_proto(msg))
246
278
 
247
279
  try:
248
280
  msg_object_id = msg.metadata.message_id
249
- descendants = store.get_message_descendant_ids(msg_object_id)
250
- # Add mapping of message object ID to its descendants
251
- objects_to_pull[msg_object_id] = ObjectIDs(object_ids=descendants)
281
+ obj_tree = store.get_object_tree(msg_object_id)
282
+ # Add message and object tree to the response
283
+ messages_list.append(message_to_proto(msg))
284
+ trees.append(obj_tree)
252
285
  except NoObjectInStoreError as e:
253
286
  log(ERROR, e.message)
254
287
  # Delete message ins from state
255
288
  state.delete_messages(message_ins_ids={msg_object_id})
256
289
 
257
- return PullResMessagesResponse(
258
- messages_list=messages_list, objects_to_pull=objects_to_pull
290
+ return PullAppMessagesResponse(
291
+ messages_list=messages_list, message_object_trees=trees
259
292
  )
260
293
 
261
294
  def GetRun(
@@ -288,22 +321,19 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
288
321
 
289
322
  raise ValueError(f"Found no FAB with hash: {request.hash_str}")
290
323
 
291
- def PullServerAppInputs(
292
- self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
293
- ) -> PullServerAppInputsResponse:
324
+ def PullAppInputs(
325
+ self, request: PullAppInputsRequest, context: grpc.ServicerContext
326
+ ) -> PullAppInputsResponse:
294
327
  """Pull ServerApp process inputs."""
295
- log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
328
+ log(DEBUG, "ServerAppIoServicer.PullAppInputs")
296
329
  # Init access to LinkState
297
330
  state = self.state_factory.state()
298
331
 
332
+ # Validate the token
333
+ run_id = self._verify_token(request.token, context)
334
+
299
335
  # Lock access to LinkState, preventing obtaining the same pending run_id
300
336
  with self.lock:
301
- # Attempt getting the run_id of a pending run
302
- run_id = state.get_pending_run_id()
303
- # If there's no pending run, return an empty response
304
- if run_id is None:
305
- return PullServerAppInputsResponse()
306
-
307
337
  # Init access to Ffs
308
338
  ffs = self.ffs_factory.ffs()
309
339
 
@@ -318,7 +348,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
318
348
  # Update run status to STARTING
319
349
  if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
320
350
  log(INFO, "Starting run %d", run_id)
321
- return PullServerAppInputsResponse(
351
+ return PullAppInputsResponse(
322
352
  context=context_to_proto(serverapp_ctxt),
323
353
  run=run_to_proto(run),
324
354
  fab=fab_to_proto(fab),
@@ -328,11 +358,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
328
358
  # or if the status cannot be updated to STARTING
329
359
  raise RuntimeError(f"Failed to start run {run_id}")
330
360
 
331
- def PushServerAppOutputs(
332
- self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
333
- ) -> PushServerAppOutputsResponse:
361
+ def PushAppOutputs(
362
+ self, request: PushAppOutputsRequest, context: grpc.ServicerContext
363
+ ) -> PushAppOutputsResponse:
334
364
  """Push ServerApp process outputs."""
335
- log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
365
+ log(DEBUG, "ServerAppIoServicer.PushAppOutputs")
366
+
367
+ # Validate the token
368
+ run_id = self._verify_token(request.token, context)
336
369
 
337
370
  # Init state and store
338
371
  state = self.state_factory.state()
@@ -348,7 +381,10 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
348
381
  )
349
382
 
350
383
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
351
- return PushServerAppOutputsResponse()
384
+
385
+ # Remove the token
386
+ state.delete_token(run_id)
387
+ return PushAppOutputsResponse()
352
388
 
353
389
  def UpdateRunStatus(
354
390
  self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
@@ -512,10 +548,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
512
548
 
513
549
  # Delete the message object
514
550
  store.delete(request.message_object_id)
515
- store.delete_message_descendant_ids(request.message_object_id)
516
551
 
517
552
  return ConfirmMessageReceivedResponse()
518
553
 
554
+ def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
555
+ """Verify the token and return the associated run ID."""
556
+ state = self.state_factory.state()
557
+ run_id = state.get_run_id_by_token(token)
558
+ if run_id is None or not state.verify_token(run_id, token):
559
+ context.abort(
560
+ grpc.StatusCode.PERMISSION_DENIED,
561
+ "Invalid token.",
562
+ )
563
+ raise RuntimeError("This line should never be reached.")
564
+ return run_id
565
+
519
566
 
520
567
  def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
521
568
  """Raise a `ValueError` with a detailed message if a validation error occurs."""
@@ -26,8 +26,8 @@ from flwr.common.logger import log
26
26
  from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
27
27
  add_SimulationIoServicer_to_server,
28
28
  )
29
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
30
29
  from flwr.server.superlink.linkstate import LinkStateFactory
30
+ from flwr.supercore.ffs import FfsFactory
31
31
 
32
32
  from .simulationio_servicer import SimulationIoServicer
33
33