flwr-nightly 1.19.0.dev20250611__py3-none-any.whl → 1.19.0.dev20250613__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/ls.py +12 -33
- flwr/cli/utils.py +18 -1
- flwr/client/grpc_rere_client/connection.py +47 -29
- flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
- flwr/client/rest_client/connection.py +70 -51
- flwr/common/constant.py +4 -0
- flwr/common/inflatable.py +24 -0
- flwr/common/serde.py +2 -0
- flwr/common/typing.py +2 -0
- flwr/proto/fleet_pb2.py +12 -16
- flwr/proto/fleet_pb2.pyi +4 -19
- flwr/proto/fleet_pb2_grpc.py +34 -0
- flwr/proto/fleet_pb2_grpc.pyi +13 -0
- flwr/proto/message_pb2.py +15 -9
- flwr/proto/message_pb2.pyi +41 -0
- flwr/proto/run_pb2.py +24 -24
- flwr/proto/run_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +22 -26
- flwr/proto/serverappio_pb2.pyi +4 -19
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +1 -0
- flwr/server/grid/grpc_grid.py +20 -9
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +26 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +20 -3
- flwr/server/superlink/linkstate/linkstate.py +6 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
- flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
- flwr/server/superlink/utils.py +23 -10
- flwr/supercore/object_store/in_memory_object_store.py +160 -33
- flwr/supercore/object_store/object_store.py +54 -7
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_grpc.py +3 -0
- flwr/superexec/exec_servicer.py +125 -22
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/RECORD +43 -43
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/entry_points.txt +0 -0
@@ -15,44 +15,95 @@
|
|
15
15
|
"""Flower in-memory ObjectStore implementation."""
|
16
16
|
|
17
17
|
|
18
|
+
import threading
|
19
|
+
from dataclasses import dataclass
|
18
20
|
from typing import Optional
|
19
21
|
|
20
|
-
from flwr.common.inflatable import
|
22
|
+
from flwr.common.inflatable import (
|
23
|
+
get_object_children_ids_from_object_content,
|
24
|
+
get_object_id,
|
25
|
+
is_valid_sha256_hash,
|
26
|
+
iterate_object_tree,
|
27
|
+
)
|
21
28
|
from flwr.common.inflatable_utils import validate_object_content
|
29
|
+
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
22
30
|
|
23
31
|
from .object_store import NoObjectInStoreError, ObjectStore
|
24
32
|
|
25
33
|
|
34
|
+
@dataclass
|
35
|
+
class ObjectEntry:
|
36
|
+
"""Data class representing an object entry in the store."""
|
37
|
+
|
38
|
+
content: bytes
|
39
|
+
is_available: bool
|
40
|
+
ref_count: int # Number of references (direct parents) to this object
|
41
|
+
runs: set[int] # Set of run IDs that used this object
|
42
|
+
|
43
|
+
|
26
44
|
class InMemoryObjectStore(ObjectStore):
|
27
45
|
"""In-memory implementation of the ObjectStore interface."""
|
28
46
|
|
29
47
|
def __init__(self, verify: bool = True) -> None:
|
30
48
|
self.verify = verify
|
31
|
-
self.store: dict[str,
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
49
|
+
self.store: dict[str, ObjectEntry] = {}
|
50
|
+
self.lock_store = threading.RLock()
|
51
|
+
# Mapping the Object ID of a message to the list of descendant object IDs
|
52
|
+
self.msg_descendant_objects_mapping: dict[str, list[str]] = {}
|
53
|
+
self.lock_msg_mapping = threading.RLock()
|
54
|
+
# Mapping each run ID to a set of object IDs that are used in that run
|
55
|
+
self.run_objects_mapping: dict[int, set[str]] = {}
|
56
|
+
|
57
|
+
def preregister(self, run_id: int, object_tree: ObjectTree) -> list[str]:
|
36
58
|
"""Identify and preregister missing objects."""
|
37
59
|
new_objects = []
|
38
|
-
|
60
|
+
if run_id not in self.run_objects_mapping:
|
61
|
+
self.run_objects_mapping[run_id] = set()
|
62
|
+
|
63
|
+
for tree_node in iterate_object_tree(object_tree):
|
64
|
+
obj_id = tree_node.object_id
|
39
65
|
# Verify object ID format (must be a valid sha256 hash)
|
40
66
|
if not is_valid_sha256_hash(obj_id):
|
41
67
|
raise ValueError(f"Invalid object ID format: {obj_id}")
|
42
|
-
|
43
|
-
self.store
|
44
|
-
|
68
|
+
with self.lock_store:
|
69
|
+
if obj_id not in self.store:
|
70
|
+
self.store[obj_id] = ObjectEntry(
|
71
|
+
content=b"", # Initially empty content
|
72
|
+
is_available=False, # Initially not available
|
73
|
+
ref_count=0, # Reference count starts at 0
|
74
|
+
runs={run_id}, # Start with the current run ID
|
75
|
+
)
|
76
|
+
|
77
|
+
# Increment the reference count for all its children
|
78
|
+
# Post-order traversal ensures that children are registered
|
79
|
+
# before parents
|
80
|
+
for child_node in tree_node.children:
|
81
|
+
child_id = child_node.object_id
|
82
|
+
self.store[child_id].ref_count += 1
|
83
|
+
|
84
|
+
# Add the object ID to the run's mapping
|
85
|
+
self.run_objects_mapping[run_id].add(obj_id)
|
86
|
+
|
87
|
+
# Add to the list of new objects
|
88
|
+
new_objects.append(obj_id)
|
89
|
+
else:
|
90
|
+
# Object is in store, retrieve it
|
91
|
+
obj_entry = self.store[obj_id]
|
92
|
+
|
93
|
+
# Add to the list of new objects if not available
|
94
|
+
if not obj_entry.is_available:
|
95
|
+
new_objects.append(obj_id)
|
96
|
+
|
97
|
+
# If the object is already registered but not in this run,
|
98
|
+
# add the run ID to its runs
|
99
|
+
if obj_id not in self.run_objects_mapping[run_id]:
|
100
|
+
obj_entry.runs.add(run_id)
|
101
|
+
self.run_objects_mapping[run_id].add(obj_id)
|
45
102
|
|
46
103
|
return new_objects
|
47
104
|
|
48
105
|
def put(self, object_id: str, object_content: bytes) -> None:
|
49
106
|
"""Put an object into the store."""
|
50
|
-
# Only allow adding the object if it has been preregistered
|
51
|
-
if object_id not in self.store:
|
52
|
-
raise NoObjectInStoreError(
|
53
|
-
f"Object with ID '{object_id}' was not pre-registered."
|
54
|
-
)
|
55
|
-
|
56
107
|
if self.verify:
|
57
108
|
# Verify object_id and object_content match
|
58
109
|
object_id_from_content = get_object_id(object_content)
|
@@ -62,41 +113,117 @@ class InMemoryObjectStore(ObjectStore):
|
|
62
113
|
# Validate object content
|
63
114
|
validate_object_content(content=object_content)
|
64
115
|
|
65
|
-
|
66
|
-
|
67
|
-
|
116
|
+
with self.lock_store:
|
117
|
+
# Only allow adding the object if it has been preregistered
|
118
|
+
if object_id not in self.store:
|
119
|
+
raise NoObjectInStoreError(
|
120
|
+
f"Object with ID '{object_id}' was not pre-registered."
|
121
|
+
)
|
122
|
+
|
123
|
+
# Return if object is already present in the store
|
124
|
+
if self.store[object_id].is_available:
|
125
|
+
return
|
68
126
|
|
69
|
-
|
127
|
+
# Update the object entry in the store
|
128
|
+
self.store[object_id].content = object_content
|
129
|
+
self.store[object_id].is_available = True
|
70
130
|
|
71
131
|
def set_message_descendant_ids(
|
72
132
|
self, msg_object_id: str, descendant_ids: list[str]
|
73
133
|
) -> None:
|
74
134
|
"""Store the mapping from a ``Message`` object ID to the object IDs of its
|
75
135
|
descendants."""
|
76
|
-
self.
|
136
|
+
with self.lock_msg_mapping:
|
137
|
+
self.msg_descendant_objects_mapping[msg_object_id] = descendant_ids
|
77
138
|
|
78
139
|
def get_message_descendant_ids(self, msg_object_id: str) -> list[str]:
|
79
140
|
"""Retrieve the object IDs of all descendants of a given Message."""
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
141
|
+
with self.lock_msg_mapping:
|
142
|
+
if msg_object_id not in self.msg_descendant_objects_mapping:
|
143
|
+
raise NoObjectInStoreError(
|
144
|
+
f"No message registered in Object Store with ID '{msg_object_id}'. "
|
145
|
+
"Mapping to descendants could not be found."
|
146
|
+
)
|
147
|
+
return self.msg_descendant_objects_mapping[msg_object_id]
|
148
|
+
|
149
|
+
def delete_message_descendant_ids(self, msg_object_id: str) -> None:
|
150
|
+
"""Delete the mapping from a ``Message`` object ID to its descendants."""
|
151
|
+
with self.lock_msg_mapping:
|
152
|
+
self.msg_descendant_objects_mapping.pop(msg_object_id, None)
|
86
153
|
|
87
154
|
def get(self, object_id: str) -> Optional[bytes]:
|
88
155
|
"""Get an object from the store."""
|
89
|
-
|
156
|
+
with self.lock_store:
|
157
|
+
# Check if the object ID is pre-registered
|
158
|
+
if object_id not in self.store:
|
159
|
+
return None
|
160
|
+
|
161
|
+
# Return content (if not yet available, it will b"")
|
162
|
+
return self.store[object_id].content
|
90
163
|
|
91
164
|
def delete(self, object_id: str) -> None:
|
92
|
-
"""Delete an object from the store."""
|
93
|
-
|
94
|
-
|
165
|
+
"""Delete an object and its unreferenced descendants from the store."""
|
166
|
+
with self.lock_store:
|
167
|
+
# If the object is not in the store, nothing to delete
|
168
|
+
if (object_entry := self.store.get(object_id)) is None:
|
169
|
+
return
|
170
|
+
|
171
|
+
# Delete the object if it has no references left
|
172
|
+
if object_entry.ref_count == 0:
|
173
|
+
del self.store[object_id]
|
174
|
+
|
175
|
+
# Remove the object from the run's mapping
|
176
|
+
for run_id in object_entry.runs:
|
177
|
+
self.run_objects_mapping[run_id].discard(object_id)
|
178
|
+
|
179
|
+
# Decrease the reference count of its children
|
180
|
+
children_ids = get_object_children_ids_from_object_content(
|
181
|
+
object_entry.content
|
182
|
+
)
|
183
|
+
for child_id in children_ids:
|
184
|
+
self.store[child_id].ref_count -= 1
|
185
|
+
|
186
|
+
# Recursively try to delete the child object
|
187
|
+
self.delete(child_id)
|
188
|
+
|
189
|
+
def delete_objects_in_run(self, run_id: int) -> None:
|
190
|
+
"""Delete all objects that were registered in a specific run."""
|
191
|
+
with self.lock_store:
|
192
|
+
if run_id not in self.run_objects_mapping:
|
193
|
+
return
|
194
|
+
for object_id in list(self.run_objects_mapping[run_id]):
|
195
|
+
# Check if the object is still in the store
|
196
|
+
if (object_entry := self.store.get(object_id)) is None:
|
197
|
+
continue
|
198
|
+
|
199
|
+
# Remove the run ID from the object's runs
|
200
|
+
object_entry.runs.discard(run_id)
|
201
|
+
|
202
|
+
# Only message objects are allowed to have a `ref_count` of 0,
|
203
|
+
# and every message object must have a `ref_count` of 0
|
204
|
+
if object_entry.ref_count == 0:
|
205
|
+
# Delete the message object and its unreferenced descendants
|
206
|
+
self.delete(object_id)
|
207
|
+
|
208
|
+
# Delete the message's descendants mapping
|
209
|
+
self.delete_message_descendant_ids(object_id)
|
210
|
+
|
211
|
+
# Remove the run from the mapping
|
212
|
+
del self.run_objects_mapping[run_id]
|
95
213
|
|
96
214
|
def clear(self) -> None:
|
97
215
|
"""Clear the store."""
|
98
|
-
self.
|
216
|
+
with self.lock_store:
|
217
|
+
self.store.clear()
|
218
|
+
self.msg_descendant_objects_mapping.clear()
|
219
|
+
self.run_objects_mapping.clear()
|
99
220
|
|
100
221
|
def __contains__(self, object_id: str) -> bool:
|
101
222
|
"""Check if an object_id is in the store."""
|
102
|
-
|
223
|
+
with self.lock_store:
|
224
|
+
return object_id in self.store
|
225
|
+
|
226
|
+
def __len__(self) -> int:
|
227
|
+
"""Get the number of objects in the store."""
|
228
|
+
with self.lock_store:
|
229
|
+
return len(self.store)
|
@@ -18,6 +18,8 @@
|
|
18
18
|
import abc
|
19
19
|
from typing import Optional
|
20
20
|
|
21
|
+
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
22
|
+
|
21
23
|
|
22
24
|
class NoObjectInStoreError(Exception):
|
23
25
|
"""Error when trying to access an element in the ObjectStore that does not exist."""
|
@@ -39,20 +41,23 @@ class ObjectStore(abc.ABC):
|
|
39
41
|
"""
|
40
42
|
|
41
43
|
@abc.abstractmethod
|
42
|
-
def preregister(self,
|
44
|
+
def preregister(self, run_id: int, object_tree: ObjectTree) -> list[str]:
|
43
45
|
"""Identify and preregister missing objects in the `ObjectStore`.
|
44
46
|
|
45
47
|
Parameters
|
46
48
|
----------
|
47
|
-
|
48
|
-
|
49
|
-
|
49
|
+
run_id : int
|
50
|
+
The ID of the run for which to preregister objects.
|
51
|
+
object_tree : ObjectTree
|
52
|
+
The object tree containing the IDs of objects to preregister.
|
53
|
+
This tree should contain all objects that are expected to be
|
54
|
+
stored in the `ObjectStore`.
|
50
55
|
|
51
56
|
Returns
|
52
57
|
-------
|
53
58
|
list[str]
|
54
|
-
A list of object IDs that were not
|
55
|
-
|
59
|
+
A list of object IDs that were either not previously preregistered
|
60
|
+
in the `ObjectStore`, or were preregistered but are not yet available.
|
56
61
|
"""
|
57
62
|
|
58
63
|
@abc.abstractmethod
|
@@ -84,12 +89,34 @@ class ObjectStore(abc.ABC):
|
|
84
89
|
|
85
90
|
@abc.abstractmethod
|
86
91
|
def delete(self, object_id: str) -> None:
|
87
|
-
"""Delete an object from the store.
|
92
|
+
"""Delete an object and its unreferenced descendants from the store.
|
93
|
+
|
94
|
+
This method attempts to recursively delete the specified object and its
|
95
|
+
descendants, if they are not referenced by any other object.
|
88
96
|
|
89
97
|
Parameters
|
90
98
|
----------
|
91
99
|
object_id : str
|
92
100
|
The object_id under which the object is stored.
|
101
|
+
|
102
|
+
Notes
|
103
|
+
-----
|
104
|
+
The object of the given object_id will NOT be deleted if it is still referenced
|
105
|
+
by any other object in the store.
|
106
|
+
"""
|
107
|
+
|
108
|
+
@abc.abstractmethod
|
109
|
+
def delete_objects_in_run(self, run_id: int) -> None:
|
110
|
+
"""Delete all objects that were registered in a specific run.
|
111
|
+
|
112
|
+
Parameters
|
113
|
+
----------
|
114
|
+
run_id : int
|
115
|
+
The ID of the run for which to delete objects.
|
116
|
+
|
117
|
+
Notes
|
118
|
+
-----
|
119
|
+
Objects that are still registered in other runs will NOT be deleted.
|
93
120
|
"""
|
94
121
|
|
95
122
|
@abc.abstractmethod
|
@@ -129,6 +156,16 @@ class ObjectStore(abc.ABC):
|
|
129
156
|
A list of object IDs of all descendant objects of the ``Message``.
|
130
157
|
"""
|
131
158
|
|
159
|
+
@abc.abstractmethod
|
160
|
+
def delete_message_descendant_ids(self, msg_object_id: str) -> None:
|
161
|
+
"""Delete the mapping from a ``Message`` object ID to its descendants.
|
162
|
+
|
163
|
+
Parameters
|
164
|
+
----------
|
165
|
+
msg_object_id : str
|
166
|
+
The object ID of the ``Message``.
|
167
|
+
"""
|
168
|
+
|
132
169
|
@abc.abstractmethod
|
133
170
|
def __contains__(self, object_id: str) -> bool:
|
134
171
|
"""Check if an object_id is in the store.
|
@@ -143,3 +180,13 @@ class ObjectStore(abc.ABC):
|
|
143
180
|
bool
|
144
181
|
True if the object_id is in the store, False otherwise.
|
145
182
|
"""
|
183
|
+
|
184
|
+
@abc.abstractmethod
|
185
|
+
def __len__(self) -> int:
|
186
|
+
"""Return the number of objects in the store.
|
187
|
+
|
188
|
+
Returns
|
189
|
+
-------
|
190
|
+
int
|
191
|
+
The number of objects currently stored.
|
192
|
+
"""
|
flwr/superexec/deployment.py
CHANGED
@@ -132,6 +132,7 @@ class DeploymentEngine(Executor):
|
|
132
132
|
self,
|
133
133
|
fab: Fab,
|
134
134
|
override_config: UserConfig,
|
135
|
+
flwr_aid: Optional[str],
|
135
136
|
) -> int:
|
136
137
|
fab_hash = self.ffs.put(fab.content, {})
|
137
138
|
if fab_hash != fab.hash_str:
|
@@ -141,7 +142,7 @@ class DeploymentEngine(Executor):
|
|
141
142
|
fab_id, fab_version = get_fab_metadata(fab.content)
|
142
143
|
|
143
144
|
run_id = self.linkstate.create_run(
|
144
|
-
fab_id, fab_version, fab_hash, override_config, ConfigRecord()
|
145
|
+
fab_id, fab_version, fab_hash, override_config, ConfigRecord(), flwr_aid
|
145
146
|
)
|
146
147
|
return run_id
|
147
148
|
|
@@ -161,6 +162,7 @@ class DeploymentEngine(Executor):
|
|
161
162
|
fab_file: bytes,
|
162
163
|
override_config: UserConfig,
|
163
164
|
federation_options: ConfigRecord,
|
165
|
+
flwr_aid: Optional[str],
|
164
166
|
) -> Optional[int]:
|
165
167
|
"""Start run using the Flower Deployment Engine."""
|
166
168
|
run_id = None
|
@@ -168,7 +170,9 @@ class DeploymentEngine(Executor):
|
|
168
170
|
|
169
171
|
# Call SuperLink to create run
|
170
172
|
run_id = self._create_run(
|
171
|
-
Fab(hashlib.sha256(fab_file).hexdigest(), fab_file),
|
173
|
+
Fab(hashlib.sha256(fab_file).hexdigest(), fab_file),
|
174
|
+
override_config,
|
175
|
+
flwr_aid,
|
172
176
|
)
|
173
177
|
|
174
178
|
# Register context for the Run
|
flwr/superexec/exec_grpc.py
CHANGED
@@ -29,6 +29,7 @@ from flwr.common.typing import UserConfig
|
|
29
29
|
from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
|
30
30
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
31
31
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
32
|
+
from flwr.supercore.object_store import ObjectStoreFactory
|
32
33
|
from flwr.superexec.exec_event_log_interceptor import ExecEventLogInterceptor
|
33
34
|
from flwr.superexec.exec_user_auth_interceptor import ExecUserAuthInterceptor
|
34
35
|
|
@@ -42,6 +43,7 @@ def run_exec_api_grpc(
|
|
42
43
|
executor: Executor,
|
43
44
|
state_factory: LinkStateFactory,
|
44
45
|
ffs_factory: FfsFactory,
|
46
|
+
objectstore_factory: ObjectStoreFactory,
|
45
47
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
46
48
|
config: UserConfig,
|
47
49
|
auth_plugin: Optional[ExecAuthPlugin] = None,
|
@@ -54,6 +56,7 @@ def run_exec_api_grpc(
|
|
54
56
|
exec_servicer: grpc.Server = ExecServicer(
|
55
57
|
linkstate_factory=state_factory,
|
56
58
|
ffs_factory=ffs_factory,
|
59
|
+
objectstore_factory=objectstore_factory,
|
57
60
|
executor=executor,
|
58
61
|
auth_plugin=auth_plugin,
|
59
62
|
)
|
flwr/superexec/exec_servicer.py
CHANGED
@@ -18,20 +18,25 @@
|
|
18
18
|
import time
|
19
19
|
from collections.abc import Generator
|
20
20
|
from logging import ERROR, INFO
|
21
|
-
from typing import Any, Optional
|
21
|
+
from typing import Any, Optional, cast
|
22
22
|
|
23
23
|
import grpc
|
24
24
|
|
25
25
|
from flwr.common import now
|
26
26
|
from flwr.common.auth_plugin import ExecAuthPlugin
|
27
|
-
from flwr.common.constant import
|
27
|
+
from flwr.common.constant import (
|
28
|
+
LOG_STREAM_INTERVAL,
|
29
|
+
RUN_ID_NOT_FOUND_MESSAGE,
|
30
|
+
Status,
|
31
|
+
SubStatus,
|
32
|
+
)
|
28
33
|
from flwr.common.logger import log
|
29
34
|
from flwr.common.serde import (
|
30
35
|
config_record_from_proto,
|
31
36
|
run_to_proto,
|
32
37
|
user_config_from_proto,
|
33
38
|
)
|
34
|
-
from flwr.common.typing import RunStatus
|
39
|
+
from flwr.common.typing import Run, RunStatus
|
35
40
|
from flwr.proto import exec_pb2_grpc # pylint: disable=E0611
|
36
41
|
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
|
37
42
|
GetAuthTokensRequest,
|
@@ -49,22 +54,26 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
|
|
49
54
|
)
|
50
55
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
51
56
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
57
|
+
from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
|
52
58
|
|
59
|
+
from .exec_user_auth_interceptor import shared_account_info
|
53
60
|
from .executor import Executor
|
54
61
|
|
55
62
|
|
56
63
|
class ExecServicer(exec_pb2_grpc.ExecServicer):
|
57
64
|
"""SuperExec API servicer."""
|
58
65
|
|
59
|
-
def __init__(
|
66
|
+
def __init__( # pylint: disable=R0913, R0917
|
60
67
|
self,
|
61
68
|
linkstate_factory: LinkStateFactory,
|
62
69
|
ffs_factory: FfsFactory,
|
70
|
+
objectstore_factory: ObjectStoreFactory,
|
63
71
|
executor: Executor,
|
64
72
|
auth_plugin: Optional[ExecAuthPlugin] = None,
|
65
73
|
) -> None:
|
66
74
|
self.linkstate_factory = linkstate_factory
|
67
75
|
self.ffs_factory = ffs_factory
|
76
|
+
self.objectstore_factory = objectstore_factory
|
68
77
|
self.executor = executor
|
69
78
|
self.executor.initialize(linkstate_factory, ffs_factory)
|
70
79
|
self.auth_plugin = auth_plugin
|
@@ -75,10 +84,12 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
75
84
|
"""Create run ID."""
|
76
85
|
log(INFO, "ExecServicer.StartRun")
|
77
86
|
|
87
|
+
flwr_aid = shared_account_info.get().flwr_aid if self.auth_plugin else None
|
78
88
|
run_id = self.executor.start_run(
|
79
89
|
request.fab.content,
|
80
90
|
user_config_from_proto(request.override_config),
|
81
91
|
config_record_from_proto(request.federation_options),
|
92
|
+
flwr_aid,
|
82
93
|
)
|
83
94
|
|
84
95
|
if run_id is None:
|
@@ -94,12 +105,20 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
94
105
|
log(INFO, "ExecServicer.StreamLogs")
|
95
106
|
state = self.linkstate_factory.state()
|
96
107
|
|
97
|
-
# Retrieve run ID
|
108
|
+
# Retrieve run ID and run
|
98
109
|
run_id = request.run_id
|
110
|
+
run = state.get_run(run_id)
|
99
111
|
|
100
112
|
# Exit if `run_id` not found
|
101
|
-
if not
|
102
|
-
context.abort(grpc.StatusCode.NOT_FOUND,
|
113
|
+
if not run:
|
114
|
+
context.abort(grpc.StatusCode.NOT_FOUND, RUN_ID_NOT_FOUND_MESSAGE)
|
115
|
+
|
116
|
+
# If user auth is enabled, check if `flwr_aid` matches the run's `flwr_aid`
|
117
|
+
if self.auth_plugin:
|
118
|
+
flwr_aid = shared_account_info.get().flwr_aid
|
119
|
+
_check_flwr_aid_in_run(
|
120
|
+
flwr_aid=flwr_aid, run=cast(Run, run), context=context
|
121
|
+
)
|
103
122
|
|
104
123
|
after_timestamp = request.after_timestamp + 1e-6
|
105
124
|
while context.is_active():
|
@@ -118,7 +137,10 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
118
137
|
# is returned at this point and the server ends the stream.
|
119
138
|
run_status = state.get_run_status({run_id})[run_id]
|
120
139
|
if run_status.status == Status.FINISHED:
|
121
|
-
log(INFO, "All logs for run ID `%s` returned",
|
140
|
+
log(INFO, "All logs for run ID `%s` returned", run_id)
|
141
|
+
|
142
|
+
# Delete objects of the run from the object store
|
143
|
+
self.objectstore_factory.store().delete_objects_in_run(run_id)
|
122
144
|
break
|
123
145
|
|
124
146
|
time.sleep(LOG_STREAM_INTERVAL) # Sleep briefly to avoid busy waiting
|
@@ -130,11 +152,44 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
130
152
|
log(INFO, "ExecServicer.List")
|
131
153
|
state = self.linkstate_factory.state()
|
132
154
|
|
133
|
-
#
|
155
|
+
# Build a set of run IDs for `flwr ls --runs`
|
134
156
|
if not request.HasField("run_id"):
|
135
|
-
|
136
|
-
|
137
|
-
|
157
|
+
if self.auth_plugin:
|
158
|
+
# If no `run_id` is specified and user auth is enabled,
|
159
|
+
# return run IDs for the authenticated user
|
160
|
+
flwr_aid = shared_account_info.get().flwr_aid
|
161
|
+
if flwr_aid is None:
|
162
|
+
context.abort(
|
163
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
164
|
+
"️⛔️ User authentication is enabled, but `flwr_aid` is None",
|
165
|
+
)
|
166
|
+
run_ids = state.get_run_ids(flwr_aid=flwr_aid)
|
167
|
+
else:
|
168
|
+
# If no `run_id` is specified and no user auth is enabled,
|
169
|
+
# return all run IDs
|
170
|
+
run_ids = state.get_run_ids(None)
|
171
|
+
# Build a set of run IDs for `flwr ls --run-id <run_id>`
|
172
|
+
else:
|
173
|
+
# Retrieve run ID and run
|
174
|
+
run_id = request.run_id
|
175
|
+
run = state.get_run(run_id)
|
176
|
+
|
177
|
+
# Exit if `run_id` not found
|
178
|
+
if not run:
|
179
|
+
context.abort(grpc.StatusCode.NOT_FOUND, RUN_ID_NOT_FOUND_MESSAGE)
|
180
|
+
|
181
|
+
# If user auth is enabled, check if `flwr_aid` matches the run's `flwr_aid`
|
182
|
+
if self.auth_plugin:
|
183
|
+
flwr_aid = shared_account_info.get().flwr_aid
|
184
|
+
_check_flwr_aid_in_run(
|
185
|
+
flwr_aid=flwr_aid, run=cast(Run, run), context=context
|
186
|
+
)
|
187
|
+
|
188
|
+
run_ids = {run_id}
|
189
|
+
|
190
|
+
# Init the object store
|
191
|
+
store = self.objectstore_factory.store()
|
192
|
+
return _create_list_runs_response(run_ids, state, store)
|
138
193
|
|
139
194
|
def StopRun(
|
140
195
|
self, request: StopRunRequest, context: grpc.ServicerContext
|
@@ -143,30 +198,42 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
143
198
|
log(INFO, "ExecServicer.StopRun")
|
144
199
|
state = self.linkstate_factory.state()
|
145
200
|
|
201
|
+
# Retrieve run ID and run
|
202
|
+
run_id = request.run_id
|
203
|
+
run = state.get_run(run_id)
|
204
|
+
|
146
205
|
# Exit if `run_id` not found
|
147
|
-
if not
|
148
|
-
context.abort(
|
149
|
-
|
206
|
+
if not run:
|
207
|
+
context.abort(grpc.StatusCode.NOT_FOUND, RUN_ID_NOT_FOUND_MESSAGE)
|
208
|
+
|
209
|
+
# If user auth is enabled, check if `flwr_aid` matches the run's `flwr_aid`
|
210
|
+
if self.auth_plugin:
|
211
|
+
flwr_aid = shared_account_info.get().flwr_aid
|
212
|
+
_check_flwr_aid_in_run(
|
213
|
+
flwr_aid=flwr_aid, run=cast(Run, run), context=context
|
150
214
|
)
|
151
215
|
|
152
|
-
run_status = state.get_run_status({
|
216
|
+
run_status = state.get_run_status({run_id})[run_id]
|
153
217
|
if run_status.status == Status.FINISHED:
|
154
218
|
context.abort(
|
155
219
|
grpc.StatusCode.FAILED_PRECONDITION,
|
156
|
-
f"Run ID {
|
220
|
+
f"Run ID {run_id} is already finished",
|
157
221
|
)
|
158
222
|
|
159
223
|
update_success = state.update_run_status(
|
160
|
-
run_id=
|
224
|
+
run_id=run_id,
|
161
225
|
new_status=RunStatus(Status.FINISHED, SubStatus.STOPPED, ""),
|
162
226
|
)
|
163
227
|
|
164
228
|
if update_success:
|
165
|
-
message_ids: set[str] = state.get_message_ids_from_run_id(
|
229
|
+
message_ids: set[str] = state.get_message_ids_from_run_id(run_id)
|
166
230
|
|
167
231
|
# Delete Messages and their replies for the `run_id`
|
168
232
|
state.delete_messages(message_ids)
|
169
233
|
|
234
|
+
# Delete objects of the run from the object store
|
235
|
+
self.objectstore_factory.store().delete_objects_in_run(run_id)
|
236
|
+
|
170
237
|
return StopRunResponse(success=update_success)
|
171
238
|
|
172
239
|
def GetLoginDetails(
|
@@ -221,10 +288,46 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
221
288
|
)
|
222
289
|
|
223
290
|
|
224
|
-
def _create_list_runs_response(
|
291
|
+
def _create_list_runs_response(
|
292
|
+
run_ids: set[int], state: LinkState, store: ObjectStore
|
293
|
+
) -> ListRunsResponse:
|
225
294
|
"""Create response for `flwr ls --runs` and `flwr ls --run-id <run_id>`."""
|
226
|
-
run_dict = {run_id:
|
295
|
+
run_dict = {run_id: run for run_id in run_ids if (run := state.get_run(run_id))}
|
296
|
+
|
297
|
+
# Delete objects of finished runs from the object store
|
298
|
+
for run_id, run in run_dict.items():
|
299
|
+
if run.status.status == Status.FINISHED:
|
300
|
+
store.delete_objects_in_run(run_id)
|
301
|
+
|
227
302
|
return ListRunsResponse(
|
228
|
-
run_dict={run_id: run_to_proto(run) for run_id, run in run_dict.items()
|
303
|
+
run_dict={run_id: run_to_proto(run) for run_id, run in run_dict.items()},
|
229
304
|
now=now().isoformat(),
|
230
305
|
)
|
306
|
+
|
307
|
+
|
308
|
+
def _check_flwr_aid_in_run(
|
309
|
+
flwr_aid: Optional[str], run: Run, context: grpc.ServicerContext
|
310
|
+
) -> None:
|
311
|
+
"""Guard clause to check if `flwr_aid` matches the run's `flwr_aid`."""
|
312
|
+
# `flwr_aid` must not be None. Abort if it is None.
|
313
|
+
if flwr_aid is None:
|
314
|
+
context.abort(
|
315
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
316
|
+
"️⛔️ User authentication is enabled, but `flwr_aid` is None",
|
317
|
+
)
|
318
|
+
|
319
|
+
# `run.flwr_aid` must not be an empty string. Abort if it is empty.
|
320
|
+
run_flwr_aid = run.flwr_aid
|
321
|
+
if not run_flwr_aid:
|
322
|
+
context.abort(
|
323
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
324
|
+
"⛔️ User authentication is enabled, but the run is not associated "
|
325
|
+
"with a `flwr_aid`.",
|
326
|
+
)
|
327
|
+
|
328
|
+
# Exit if `flwr_aid` does not match the run's `flwr_aid`
|
329
|
+
if run_flwr_aid != flwr_aid:
|
330
|
+
context.abort(
|
331
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
332
|
+
"⛔️ Run ID does not belong to the user",
|
333
|
+
)
|
flwr/superexec/executor.py
CHANGED
@@ -74,6 +74,7 @@ class Executor(ABC):
|
|
74
74
|
fab_file: bytes,
|
75
75
|
override_config: UserConfig,
|
76
76
|
federation_options: ConfigRecord,
|
77
|
+
flwr_aid: Optional[str],
|
77
78
|
) -> Optional[int]:
|
78
79
|
"""Start a run using the given Flower FAB ID and version.
|
79
80
|
|
@@ -88,6 +89,9 @@ class Executor(ABC):
|
|
88
89
|
The config overrides dict sent by the user (using `flwr run`).
|
89
90
|
federation_options: ConfigRecord
|
90
91
|
The federation options sent by the user (using `flwr run`).
|
92
|
+
flwr_aid : Optional[str]
|
93
|
+
The Flower Account ID of the user starting the run, if authentication is
|
94
|
+
enabled.
|
91
95
|
|
92
96
|
Returns
|
93
97
|
-------
|