flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +94 -59
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/new.py +12 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +48 -49
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +38 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +15 -8
- flwr/client/grpc_rere_client/connection.py +142 -97
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +176 -103
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +39 -8
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit_code.py +16 -1
- flwr/common/exit_handlers.py +30 -0
- flwr/common/grpc.py +12 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_protobuf_utils.py +141 -0
- flwr/common/inflatable_utils.py +508 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +402 -0
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -211
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +28 -185
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/appio_pb2.py +43 -0
- flwr/proto/appio_pb2.pyi +151 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +12 -19
- flwr/proto/clientappio_pb2.pyi +23 -101
- flwr/proto/clientappio_pb2_grpc.py +269 -28
- flwr/proto/clientappio_pb2_grpc.pyi +114 -20
- flwr/proto/fleet_pb2.py +24 -27
- flwr/proto/fleet_pb2.pyi +19 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +9 -23
- flwr/proto/serverappio_pb2.pyi +0 -110
- flwr/proto/serverappio_pb2_grpc.py +177 -72
- flwr/proto/serverappio_pb2_grpc.pyi +75 -33
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +69 -187
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +6 -2
- flwr/server/grid/grpc_grid.py +148 -41
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +45 -17
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +9 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +25 -0
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +170 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/supercore/object_store/utils.py +43 -0
- flwr/supercore/scheduler/__init__.py +22 -0
- flwr/supercore/scheduler/plugin.py +71 -0
- flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
- flwr/superexec/deployment.py +7 -4
- flwr/superexec/exec_event_log_interceptor.py +8 -4
- flwr/superexec/exec_grpc.py +25 -5
- flwr/superexec/exec_license_interceptor.py +82 -0
- flwr/superexec/exec_servicer.py +135 -24
- flwr/superexec/exec_user_auth_interceptor.py +45 -8
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -3
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/supernode/cli/__init__.py +24 -0
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
- flwr/supernode/cli/flwr_clientapp.py +88 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
- flwr/supernode/nodestate/nodestate.py +227 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
- flwr/supernode/scheduler/__init__.py +22 -0
- flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +22 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
- flwr/supernode/start_client_internal.py +589 -0
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
- flwr/client/clientapp/clientappio_servicer.py +0 -244
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
|
@@ -14,12 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Fleet API message handlers."""
|
|
16
16
|
|
|
17
|
-
|
|
17
|
+
from logging import ERROR
|
|
18
18
|
from typing import Optional
|
|
19
|
-
from uuid import UUID
|
|
20
19
|
|
|
21
|
-
from flwr.common import Message
|
|
20
|
+
from flwr.common import Message, log
|
|
22
21
|
from flwr.common.constant import Status
|
|
22
|
+
from flwr.common.inflatable import UnexpectedObjectContentError
|
|
23
23
|
from flwr.common.serde import (
|
|
24
24
|
fab_to_proto,
|
|
25
25
|
message_from_proto,
|
|
@@ -33,23 +33,35 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
33
33
|
CreateNodeResponse,
|
|
34
34
|
DeleteNodeRequest,
|
|
35
35
|
DeleteNodeResponse,
|
|
36
|
-
PingRequest,
|
|
37
|
-
PingResponse,
|
|
38
36
|
PullMessagesRequest,
|
|
39
37
|
PullMessagesResponse,
|
|
40
38
|
PushMessagesRequest,
|
|
41
39
|
PushMessagesResponse,
|
|
42
40
|
Reconnect,
|
|
43
41
|
)
|
|
42
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
43
|
+
SendNodeHeartbeatRequest,
|
|
44
|
+
SendNodeHeartbeatResponse,
|
|
45
|
+
)
|
|
46
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
47
|
+
ConfirmMessageReceivedRequest,
|
|
48
|
+
ConfirmMessageReceivedResponse,
|
|
49
|
+
PullObjectRequest,
|
|
50
|
+
PullObjectResponse,
|
|
51
|
+
PushObjectRequest,
|
|
52
|
+
PushObjectResponse,
|
|
53
|
+
)
|
|
44
54
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
45
55
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
46
56
|
GetRunRequest,
|
|
47
57
|
GetRunResponse,
|
|
48
58
|
Run,
|
|
49
59
|
)
|
|
50
|
-
from flwr.server.superlink.ffs.ffs import Ffs
|
|
51
60
|
from flwr.server.superlink.linkstate import LinkState
|
|
52
61
|
from flwr.server.superlink.utils import check_abort
|
|
62
|
+
from flwr.supercore.ffs import Ffs
|
|
63
|
+
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
|
|
64
|
+
from flwr.supercore.object_store.utils import store_mapping_and_register_objects
|
|
53
65
|
|
|
54
66
|
|
|
55
67
|
def create_node(
|
|
@@ -58,7 +70,7 @@ def create_node(
|
|
|
58
70
|
) -> CreateNodeResponse:
|
|
59
71
|
"""."""
|
|
60
72
|
# Create node
|
|
61
|
-
node_id = state.create_node(
|
|
73
|
+
node_id = state.create_node(heartbeat_interval=request.heartbeat_interval)
|
|
62
74
|
return CreateNodeResponse(node=Node(node_id=node_id))
|
|
63
75
|
|
|
64
76
|
|
|
@@ -73,17 +85,21 @@ def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeRespo
|
|
|
73
85
|
return DeleteNodeResponse()
|
|
74
86
|
|
|
75
87
|
|
|
76
|
-
def
|
|
77
|
-
request:
|
|
88
|
+
def send_node_heartbeat(
|
|
89
|
+
request: SendNodeHeartbeatRequest, # pylint: disable=unused-argument
|
|
78
90
|
state: LinkState, # pylint: disable=unused-argument
|
|
79
|
-
) ->
|
|
91
|
+
) -> SendNodeHeartbeatResponse:
|
|
80
92
|
"""."""
|
|
81
|
-
res = state.
|
|
82
|
-
|
|
93
|
+
res = state.acknowledge_node_heartbeat(
|
|
94
|
+
request.node.node_id, request.heartbeat_interval
|
|
95
|
+
)
|
|
96
|
+
return SendNodeHeartbeatResponse(success=res)
|
|
83
97
|
|
|
84
98
|
|
|
85
99
|
def pull_messages(
|
|
86
|
-
request: PullMessagesRequest,
|
|
100
|
+
request: PullMessagesRequest,
|
|
101
|
+
state: LinkState,
|
|
102
|
+
store: ObjectStore,
|
|
87
103
|
) -> PullMessagesResponse:
|
|
88
104
|
"""Pull Messages handler."""
|
|
89
105
|
# Get node_id if client node is not anonymous
|
|
@@ -95,14 +111,28 @@ def pull_messages(
|
|
|
95
111
|
|
|
96
112
|
# Convert to Messages
|
|
97
113
|
msg_proto = []
|
|
114
|
+
trees = []
|
|
98
115
|
for msg in message_list:
|
|
99
|
-
|
|
116
|
+
try:
|
|
117
|
+
# Retrieve Message object tree from ObjectStore
|
|
118
|
+
msg_object_id = msg.metadata.message_id
|
|
119
|
+
obj_tree = store.get_object_tree(msg_object_id)
|
|
120
|
+
|
|
121
|
+
# Add Message and its object tree to the response
|
|
122
|
+
msg_proto.append(message_to_proto(msg))
|
|
123
|
+
trees.append(obj_tree)
|
|
124
|
+
except NoObjectInStoreError as e:
|
|
125
|
+
log(ERROR, e.message)
|
|
126
|
+
# Delete message ins from state
|
|
127
|
+
state.delete_messages(message_ins_ids={msg_object_id})
|
|
100
128
|
|
|
101
|
-
return PullMessagesResponse(messages_list=msg_proto)
|
|
129
|
+
return PullMessagesResponse(messages_list=msg_proto, message_object_trees=trees)
|
|
102
130
|
|
|
103
131
|
|
|
104
132
|
def push_messages(
|
|
105
|
-
request: PushMessagesRequest,
|
|
133
|
+
request: PushMessagesRequest,
|
|
134
|
+
state: LinkState,
|
|
135
|
+
store: ObjectStore,
|
|
106
136
|
) -> PushMessagesResponse:
|
|
107
137
|
"""Push Messages handler."""
|
|
108
138
|
# Convert Message from proto
|
|
@@ -113,22 +143,29 @@ def push_messages(
|
|
|
113
143
|
msg.metadata.run_id,
|
|
114
144
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
115
145
|
state,
|
|
146
|
+
store,
|
|
116
147
|
)
|
|
117
148
|
if abort_msg:
|
|
118
149
|
raise InvalidRunStatusException(abort_msg)
|
|
119
150
|
|
|
120
151
|
# Store Message in State
|
|
121
|
-
message_id: Optional[
|
|
152
|
+
message_id: Optional[str] = state.store_message_res(message=msg)
|
|
153
|
+
|
|
154
|
+
# Store Message object to descendants mapping and preregister objects
|
|
155
|
+
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
|
122
156
|
|
|
123
157
|
# Build response
|
|
124
158
|
response = PushMessagesResponse(
|
|
125
159
|
reconnect=Reconnect(reconnect=5),
|
|
126
160
|
results={str(message_id): 0},
|
|
161
|
+
objects_to_push=objects_to_push,
|
|
127
162
|
)
|
|
128
163
|
return response
|
|
129
164
|
|
|
130
165
|
|
|
131
|
-
def get_run(
|
|
166
|
+
def get_run(
|
|
167
|
+
request: GetRunRequest, state: LinkState, store: ObjectStore
|
|
168
|
+
) -> GetRunResponse:
|
|
132
169
|
"""Get run information."""
|
|
133
170
|
run = state.get_run(request.run_id)
|
|
134
171
|
|
|
@@ -140,6 +177,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
|
140
177
|
request.run_id,
|
|
141
178
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
142
179
|
state,
|
|
180
|
+
store,
|
|
143
181
|
)
|
|
144
182
|
if abort_msg:
|
|
145
183
|
raise InvalidRunStatusException(abort_msg)
|
|
@@ -156,7 +194,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
|
156
194
|
|
|
157
195
|
|
|
158
196
|
def get_fab(
|
|
159
|
-
request: GetFabRequest, ffs: Ffs, state: LinkState
|
|
197
|
+
request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
|
|
160
198
|
) -> GetFabResponse:
|
|
161
199
|
"""Get FAB."""
|
|
162
200
|
# Abort if the run is not running
|
|
@@ -164,6 +202,7 @@ def get_fab(
|
|
|
164
202
|
request.run_id,
|
|
165
203
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
166
204
|
state,
|
|
205
|
+
store,
|
|
167
206
|
)
|
|
168
207
|
if abort_msg:
|
|
169
208
|
raise InvalidRunStatusException(abort_msg)
|
|
@@ -173,3 +212,75 @@ def get_fab(
|
|
|
173
212
|
return GetFabResponse(fab=fab_to_proto(fab))
|
|
174
213
|
|
|
175
214
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def push_object(
|
|
218
|
+
request: PushObjectRequest, state: LinkState, store: ObjectStore
|
|
219
|
+
) -> PushObjectResponse:
|
|
220
|
+
"""Push Object."""
|
|
221
|
+
abort_msg = check_abort(
|
|
222
|
+
request.run_id,
|
|
223
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
224
|
+
state,
|
|
225
|
+
store,
|
|
226
|
+
)
|
|
227
|
+
if abort_msg:
|
|
228
|
+
raise InvalidRunStatusException(abort_msg)
|
|
229
|
+
|
|
230
|
+
stored = False
|
|
231
|
+
try:
|
|
232
|
+
store.put(request.object_id, request.object_content)
|
|
233
|
+
stored = True
|
|
234
|
+
except (NoObjectInStoreError, ValueError) as e:
|
|
235
|
+
log(ERROR, str(e))
|
|
236
|
+
except UnexpectedObjectContentError as e:
|
|
237
|
+
# Object content is not valid
|
|
238
|
+
log(ERROR, str(e))
|
|
239
|
+
raise
|
|
240
|
+
return PushObjectResponse(stored=stored)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def pull_object(
|
|
244
|
+
request: PullObjectRequest, state: LinkState, store: ObjectStore
|
|
245
|
+
) -> PullObjectResponse:
|
|
246
|
+
"""Pull Object."""
|
|
247
|
+
abort_msg = check_abort(
|
|
248
|
+
request.run_id,
|
|
249
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
250
|
+
state,
|
|
251
|
+
store,
|
|
252
|
+
)
|
|
253
|
+
if abort_msg:
|
|
254
|
+
raise InvalidRunStatusException(abort_msg)
|
|
255
|
+
|
|
256
|
+
# Fetch from store
|
|
257
|
+
content = store.get(request.object_id)
|
|
258
|
+
if content is not None:
|
|
259
|
+
object_available = content != b""
|
|
260
|
+
return PullObjectResponse(
|
|
261
|
+
object_found=True,
|
|
262
|
+
object_available=object_available,
|
|
263
|
+
object_content=content,
|
|
264
|
+
)
|
|
265
|
+
return PullObjectResponse(object_found=False, object_available=False)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def confirm_message_received(
|
|
269
|
+
request: ConfirmMessageReceivedRequest,
|
|
270
|
+
state: LinkState,
|
|
271
|
+
store: ObjectStore,
|
|
272
|
+
) -> ConfirmMessageReceivedResponse:
|
|
273
|
+
"""Confirm message received handler."""
|
|
274
|
+
abort_msg = check_abort(
|
|
275
|
+
request.run_id,
|
|
276
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
277
|
+
state,
|
|
278
|
+
store,
|
|
279
|
+
)
|
|
280
|
+
if abort_msg:
|
|
281
|
+
raise InvalidRunStatusException(abort_msg)
|
|
282
|
+
|
|
283
|
+
# Delete the message object
|
|
284
|
+
store.delete(request.message_object_id)
|
|
285
|
+
|
|
286
|
+
return ConfirmMessageReceivedResponse()
|
|
@@ -29,18 +29,28 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
29
29
|
CreateNodeResponse,
|
|
30
30
|
DeleteNodeRequest,
|
|
31
31
|
DeleteNodeResponse,
|
|
32
|
-
PingRequest,
|
|
33
|
-
PingResponse,
|
|
34
32
|
PullMessagesRequest,
|
|
35
33
|
PullMessagesResponse,
|
|
36
34
|
PushMessagesRequest,
|
|
37
35
|
PushMessagesResponse,
|
|
38
36
|
)
|
|
37
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
38
|
+
SendNodeHeartbeatRequest,
|
|
39
|
+
SendNodeHeartbeatResponse,
|
|
40
|
+
)
|
|
41
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
42
|
+
ConfirmMessageReceivedRequest,
|
|
43
|
+
ConfirmMessageReceivedResponse,
|
|
44
|
+
PullObjectRequest,
|
|
45
|
+
PullObjectResponse,
|
|
46
|
+
PushObjectRequest,
|
|
47
|
+
PushObjectResponse,
|
|
48
|
+
)
|
|
39
49
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
|
-
from flwr.server.superlink.ffs.ffs import Ffs
|
|
41
|
-
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
42
50
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
43
51
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
52
|
+
from flwr.supercore.ffs import Ffs, FfsFactory
|
|
53
|
+
from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
|
|
44
54
|
|
|
45
55
|
try:
|
|
46
56
|
from starlette.applications import Starlette
|
|
@@ -111,9 +121,10 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
|
|
111
121
|
"""Pull PullMessages."""
|
|
112
122
|
# Get state from app
|
|
113
123
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
124
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
|
114
125
|
|
|
115
126
|
# Handle message
|
|
116
|
-
return message_handler.pull_messages(request=request, state=state)
|
|
127
|
+
return message_handler.pull_messages(request=request, state=state, store=store)
|
|
117
128
|
|
|
118
129
|
|
|
119
130
|
@rest_request_response(PushMessagesRequest)
|
|
@@ -121,19 +132,44 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
|
|
121
132
|
"""Pull PushMessages."""
|
|
122
133
|
# Get state from app
|
|
123
134
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
135
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
|
136
|
+
|
|
137
|
+
# Handle message
|
|
138
|
+
return message_handler.push_messages(request=request, state=state, store=store)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@rest_request_response(PullObjectRequest)
|
|
142
|
+
async def pull_object(request: PullObjectRequest) -> PullObjectResponse:
|
|
143
|
+
"""Pull PullObject."""
|
|
144
|
+
# Get state from app
|
|
145
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
146
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
|
147
|
+
|
|
148
|
+
# Handle message
|
|
149
|
+
return message_handler.pull_object(request=request, state=state, store=store)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@rest_request_response(PushObjectRequest)
|
|
153
|
+
async def push_object(request: PushObjectRequest) -> PushObjectResponse:
|
|
154
|
+
"""Pull PushObject."""
|
|
155
|
+
# Get state from app
|
|
156
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
157
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
|
124
158
|
|
|
125
159
|
# Handle message
|
|
126
|
-
return message_handler.
|
|
160
|
+
return message_handler.push_object(request=request, state=state, store=store)
|
|
127
161
|
|
|
128
162
|
|
|
129
|
-
@rest_request_response(
|
|
130
|
-
async def
|
|
131
|
-
|
|
163
|
+
@rest_request_response(SendNodeHeartbeatRequest)
|
|
164
|
+
async def send_node_heartbeat(
|
|
165
|
+
request: SendNodeHeartbeatRequest,
|
|
166
|
+
) -> SendNodeHeartbeatResponse:
|
|
167
|
+
"""Send node heartbeat."""
|
|
132
168
|
# Get state from app
|
|
133
169
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
134
170
|
|
|
135
171
|
# Handle message
|
|
136
|
-
return message_handler.
|
|
172
|
+
return message_handler.send_node_heartbeat(request=request, state=state)
|
|
137
173
|
|
|
138
174
|
|
|
139
175
|
@rest_request_response(GetRunRequest)
|
|
@@ -141,9 +177,10 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
|
141
177
|
"""GetRun."""
|
|
142
178
|
# Get state from app
|
|
143
179
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
180
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
|
144
181
|
|
|
145
182
|
# Handle message
|
|
146
|
-
return message_handler.get_run(request=request, state=state)
|
|
183
|
+
return message_handler.get_run(request=request, state=state, store=store)
|
|
147
184
|
|
|
148
185
|
|
|
149
186
|
@rest_request_response(GetFabRequest)
|
|
@@ -154,9 +191,25 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
|
154
191
|
|
|
155
192
|
# Get state from app
|
|
156
193
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
194
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
|
195
|
+
|
|
196
|
+
# Handle message
|
|
197
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state, store=store)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@rest_request_response(ConfirmMessageReceivedRequest)
|
|
201
|
+
async def confirm_message_received(
|
|
202
|
+
request: ConfirmMessageReceivedRequest,
|
|
203
|
+
) -> ConfirmMessageReceivedResponse:
|
|
204
|
+
"""Confirm message received."""
|
|
205
|
+
# Get state from app
|
|
206
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
207
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
|
157
208
|
|
|
158
209
|
# Handle message
|
|
159
|
-
return message_handler.
|
|
210
|
+
return message_handler.confirm_message_received(
|
|
211
|
+
request=request, state=state, store=store
|
|
212
|
+
)
|
|
160
213
|
|
|
161
214
|
|
|
162
215
|
routes = [
|
|
@@ -164,9 +217,16 @@ routes = [
|
|
|
164
217
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
|
165
218
|
Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
|
|
166
219
|
Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
|
|
167
|
-
Route("/api/v0/fleet/
|
|
220
|
+
Route("/api/v0/fleet/pull-object", pull_object, methods=["POST"]),
|
|
221
|
+
Route("/api/v0/fleet/push-object", push_object, methods=["POST"]),
|
|
222
|
+
Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
|
|
168
223
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
|
169
224
|
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
|
225
|
+
Route(
|
|
226
|
+
"/api/v0/fleet/confirm-message-received",
|
|
227
|
+
confirm_message_received,
|
|
228
|
+
methods=["POST"],
|
|
229
|
+
),
|
|
170
230
|
]
|
|
171
231
|
|
|
172
232
|
app: Starlette = Starlette(
|
|
@@ -25,19 +25,20 @@ from pathlib import Path
|
|
|
25
25
|
from queue import Empty, Queue
|
|
26
26
|
from time import sleep
|
|
27
27
|
from typing import Callable, Optional
|
|
28
|
+
from uuid import uuid4
|
|
28
29
|
|
|
30
|
+
from flwr.app.error import Error
|
|
29
31
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
30
32
|
from flwr.client.clientapp.utils import get_load_client_app_fn
|
|
31
33
|
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
32
34
|
from flwr.common import Message
|
|
33
35
|
from flwr.common.constant import (
|
|
36
|
+
HEARTBEAT_MAX_INTERVAL,
|
|
34
37
|
NUM_PARTITIONS_KEY,
|
|
35
38
|
PARTITION_ID_KEY,
|
|
36
|
-
PING_MAX_INTERVAL,
|
|
37
39
|
ErrorCode,
|
|
38
40
|
)
|
|
39
41
|
from flwr.common.logger import log
|
|
40
|
-
from flwr.common.message import Error
|
|
41
42
|
from flwr.common.typing import Run
|
|
42
43
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
43
44
|
|
|
@@ -53,7 +54,7 @@ def _register_nodes(
|
|
|
53
54
|
nodes_mapping: NodeToPartitionMapping = {}
|
|
54
55
|
state = state_factory.state()
|
|
55
56
|
for i in range(num_nodes):
|
|
56
|
-
node_id = state.create_node(
|
|
57
|
+
node_id = state.create_node(heartbeat_interval=HEARTBEAT_MAX_INTERVAL)
|
|
57
58
|
nodes_mapping[node_id] = i
|
|
58
59
|
log(DEBUG, "Registered %i nodes", len(nodes_mapping))
|
|
59
60
|
return nodes_mapping
|
|
@@ -134,6 +135,8 @@ def worker(
|
|
|
134
135
|
|
|
135
136
|
finally:
|
|
136
137
|
if out_mssg:
|
|
138
|
+
# Assign a message_id
|
|
139
|
+
out_mssg.metadata.__dict__["_message_id"] = str(uuid4())
|
|
137
140
|
# Store reply Messages in state
|
|
138
141
|
messageres_queue.put(out_mssg)
|
|
139
142
|
|