flwr-nightly 1.15.0.dev20250127__py3-none-any.whl → 1.15.0.dev20250129__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/utils.py +5 -8
- flwr/client/clientapp/app.py +15 -19
- flwr/client/grpc_client/connection.py +1 -6
- flwr/client/grpc_rere_client/connection.py +12 -12
- flwr/client/grpc_rere_client/grpc_adapter.py +0 -16
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/supernode/app.py +0 -24
- flwr/common/differential_privacy.py +2 -1
- flwr/common/grpc.py +6 -1
- flwr/proto/fleet_pb2.py +27 -40
- flwr/proto/fleet_pb2.pyi +0 -84
- flwr/proto/fleet_pb2_grpc.py +5 -93
- flwr/proto/fleet_pb2_grpc.pyi +12 -38
- flwr/server/app.py +6 -2
- flwr/server/driver/grpc_driver.py +5 -4
- flwr/server/serverapp/app.py +7 -4
- flwr/server/superlink/driver/serverappio_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +6 -43
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +0 -27
- flwr/simulation/simulationio_connection.py +2 -1
- {flwr_nightly-1.15.0.dev20250127.dist-info → flwr_nightly-1.15.0.dev20250129.dist-info}/METADATA +1 -1
- {flwr_nightly-1.15.0.dev20250127.dist-info → flwr_nightly-1.15.0.dev20250129.dist-info}/RECORD +27 -27
- {flwr_nightly-1.15.0.dev20250127.dist-info → flwr_nightly-1.15.0.dev20250129.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.15.0.dev20250127.dist-info → flwr_nightly-1.15.0.dev20250129.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.15.0.dev20250127.dist-info → flwr_nightly-1.15.0.dev20250129.dist-info}/entry_points.txt +0 -0
flwr/proto/fleet_pb2_grpc.py
CHANGED
@@ -51,21 +51,11 @@ class FleetStub(object):
|
|
51
51
|
request_serializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString,
|
52
52
|
response_deserializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString,
|
53
53
|
_registered_method=True)
|
54
|
-
self.PullTaskIns = channel.unary_unary(
|
55
|
-
'/flwr.proto.Fleet/PullTaskIns',
|
56
|
-
request_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString,
|
57
|
-
response_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.FromString,
|
58
|
-
_registered_method=True)
|
59
54
|
self.PullMessages = channel.unary_unary(
|
60
55
|
'/flwr.proto.Fleet/PullMessages',
|
61
56
|
request_serializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesRequest.SerializeToString,
|
62
57
|
response_deserializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesResponse.FromString,
|
63
58
|
_registered_method=True)
|
64
|
-
self.PushTaskRes = channel.unary_unary(
|
65
|
-
'/flwr.proto.Fleet/PushTaskRes',
|
66
|
-
request_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.SerializeToString,
|
67
|
-
response_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString,
|
68
|
-
_registered_method=True)
|
69
59
|
self.PushMessages = channel.unary_unary(
|
70
60
|
'/flwr.proto.Fleet/PushMessages',
|
71
61
|
request_serializer=flwr_dot_proto_dot_fleet__pb2.PushMessagesRequest.SerializeToString,
|
@@ -104,33 +94,19 @@ class FleetServicer(object):
|
|
104
94
|
context.set_details('Method not implemented!')
|
105
95
|
raise NotImplementedError('Method not implemented!')
|
106
96
|
|
107
|
-
def PullTaskIns(self, request, context):
|
108
|
-
"""Retrieve one or more tasks, if possible
|
109
|
-
|
110
|
-
HTTP API path: /api/v1/fleet/pull-task-ins
|
111
|
-
"""
|
112
|
-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
113
|
-
context.set_details('Method not implemented!')
|
114
|
-
raise NotImplementedError('Method not implemented!')
|
115
|
-
|
116
97
|
def PullMessages(self, request, context):
|
117
|
-
"""
|
118
|
-
"""
|
119
|
-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
120
|
-
context.set_details('Method not implemented!')
|
121
|
-
raise NotImplementedError('Method not implemented!')
|
98
|
+
"""Retrieve one or more messages, if possible
|
122
99
|
|
123
|
-
|
124
|
-
"""Complete one or more tasks, if possible
|
125
|
-
|
126
|
-
HTTP API path: /api/v1/fleet/push-task-res
|
100
|
+
HTTP API path: /api/v1/fleet/pull-messages
|
127
101
|
"""
|
128
102
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
129
103
|
context.set_details('Method not implemented!')
|
130
104
|
raise NotImplementedError('Method not implemented!')
|
131
105
|
|
132
106
|
def PushMessages(self, request, context):
|
133
|
-
"""
|
107
|
+
"""Complete one or more messages, if possible
|
108
|
+
|
109
|
+
HTTP API path: /api/v1/fleet/push-messages
|
134
110
|
"""
|
135
111
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
136
112
|
context.set_details('Method not implemented!')
|
@@ -167,21 +143,11 @@ def add_FleetServicer_to_server(servicer, server):
|
|
167
143
|
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.FromString,
|
168
144
|
response_serializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.SerializeToString,
|
169
145
|
),
|
170
|
-
'PullTaskIns': grpc.unary_unary_rpc_method_handler(
|
171
|
-
servicer.PullTaskIns,
|
172
|
-
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString,
|
173
|
-
response_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.SerializeToString,
|
174
|
-
),
|
175
146
|
'PullMessages': grpc.unary_unary_rpc_method_handler(
|
176
147
|
servicer.PullMessages,
|
177
148
|
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesRequest.FromString,
|
178
149
|
response_serializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesResponse.SerializeToString,
|
179
150
|
),
|
180
|
-
'PushTaskRes': grpc.unary_unary_rpc_method_handler(
|
181
|
-
servicer.PushTaskRes,
|
182
|
-
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString,
|
183
|
-
response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString,
|
184
|
-
),
|
185
151
|
'PushMessages': grpc.unary_unary_rpc_method_handler(
|
186
152
|
servicer.PushMessages,
|
187
153
|
request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushMessagesRequest.FromString,
|
@@ -289,33 +255,6 @@ class Fleet(object):
|
|
289
255
|
metadata,
|
290
256
|
_registered_method=True)
|
291
257
|
|
292
|
-
@staticmethod
|
293
|
-
def PullTaskIns(request,
|
294
|
-
target,
|
295
|
-
options=(),
|
296
|
-
channel_credentials=None,
|
297
|
-
call_credentials=None,
|
298
|
-
insecure=False,
|
299
|
-
compression=None,
|
300
|
-
wait_for_ready=None,
|
301
|
-
timeout=None,
|
302
|
-
metadata=None):
|
303
|
-
return grpc.experimental.unary_unary(
|
304
|
-
request,
|
305
|
-
target,
|
306
|
-
'/flwr.proto.Fleet/PullTaskIns',
|
307
|
-
flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString,
|
308
|
-
flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.FromString,
|
309
|
-
options,
|
310
|
-
channel_credentials,
|
311
|
-
insecure,
|
312
|
-
call_credentials,
|
313
|
-
compression,
|
314
|
-
wait_for_ready,
|
315
|
-
timeout,
|
316
|
-
metadata,
|
317
|
-
_registered_method=True)
|
318
|
-
|
319
258
|
@staticmethod
|
320
259
|
def PullMessages(request,
|
321
260
|
target,
|
@@ -343,33 +282,6 @@ class Fleet(object):
|
|
343
282
|
metadata,
|
344
283
|
_registered_method=True)
|
345
284
|
|
346
|
-
@staticmethod
|
347
|
-
def PushTaskRes(request,
|
348
|
-
target,
|
349
|
-
options=(),
|
350
|
-
channel_credentials=None,
|
351
|
-
call_credentials=None,
|
352
|
-
insecure=False,
|
353
|
-
compression=None,
|
354
|
-
wait_for_ready=None,
|
355
|
-
timeout=None,
|
356
|
-
metadata=None):
|
357
|
-
return grpc.experimental.unary_unary(
|
358
|
-
request,
|
359
|
-
target,
|
360
|
-
'/flwr.proto.Fleet/PushTaskRes',
|
361
|
-
flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.SerializeToString,
|
362
|
-
flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString,
|
363
|
-
options,
|
364
|
-
channel_credentials,
|
365
|
-
insecure,
|
366
|
-
call_credentials,
|
367
|
-
compression,
|
368
|
-
wait_for_ready,
|
369
|
-
timeout,
|
370
|
-
metadata,
|
371
|
-
_registered_method=True)
|
372
|
-
|
373
285
|
@staticmethod
|
374
286
|
def PushMessages(request,
|
375
287
|
target,
|
flwr/proto/fleet_pb2_grpc.pyi
CHANGED
@@ -22,31 +22,21 @@ class FleetStub:
|
|
22
22
|
flwr.proto.fleet_pb2.PingRequest,
|
23
23
|
flwr.proto.fleet_pb2.PingResponse]
|
24
24
|
|
25
|
-
PullTaskIns: grpc.UnaryUnaryMultiCallable[
|
26
|
-
flwr.proto.fleet_pb2.PullTaskInsRequest,
|
27
|
-
flwr.proto.fleet_pb2.PullTaskInsResponse]
|
28
|
-
"""Retrieve one or more tasks, if possible
|
29
|
-
|
30
|
-
HTTP API path: /api/v1/fleet/pull-task-ins
|
31
|
-
"""
|
32
|
-
|
33
25
|
PullMessages: grpc.UnaryUnaryMultiCallable[
|
34
26
|
flwr.proto.fleet_pb2.PullMessagesRequest,
|
35
27
|
flwr.proto.fleet_pb2.PullMessagesResponse]
|
36
|
-
"""
|
28
|
+
"""Retrieve one or more messages, if possible
|
37
29
|
|
38
|
-
|
39
|
-
flwr.proto.fleet_pb2.PushTaskResRequest,
|
40
|
-
flwr.proto.fleet_pb2.PushTaskResResponse]
|
41
|
-
"""Complete one or more tasks, if possible
|
42
|
-
|
43
|
-
HTTP API path: /api/v1/fleet/push-task-res
|
30
|
+
HTTP API path: /api/v1/fleet/pull-messages
|
44
31
|
"""
|
45
32
|
|
46
33
|
PushMessages: grpc.UnaryUnaryMultiCallable[
|
47
34
|
flwr.proto.fleet_pb2.PushMessagesRequest,
|
48
35
|
flwr.proto.fleet_pb2.PushMessagesResponse]
|
49
|
-
"""
|
36
|
+
"""Complete one or more messages, if possible
|
37
|
+
|
38
|
+
HTTP API path: /api/v1/fleet/push-messages
|
39
|
+
"""
|
50
40
|
|
51
41
|
GetRun: grpc.UnaryUnaryMultiCallable[
|
52
42
|
flwr.proto.run_pb2.GetRunRequest,
|
@@ -77,33 +67,14 @@ class FleetServicer(metaclass=abc.ABCMeta):
|
|
77
67
|
context: grpc.ServicerContext,
|
78
68
|
) -> flwr.proto.fleet_pb2.PingResponse: ...
|
79
69
|
|
80
|
-
@abc.abstractmethod
|
81
|
-
def PullTaskIns(self,
|
82
|
-
request: flwr.proto.fleet_pb2.PullTaskInsRequest,
|
83
|
-
context: grpc.ServicerContext,
|
84
|
-
) -> flwr.proto.fleet_pb2.PullTaskInsResponse:
|
85
|
-
"""Retrieve one or more tasks, if possible
|
86
|
-
|
87
|
-
HTTP API path: /api/v1/fleet/pull-task-ins
|
88
|
-
"""
|
89
|
-
pass
|
90
|
-
|
91
70
|
@abc.abstractmethod
|
92
71
|
def PullMessages(self,
|
93
72
|
request: flwr.proto.fleet_pb2.PullMessagesRequest,
|
94
73
|
context: grpc.ServicerContext,
|
95
74
|
) -> flwr.proto.fleet_pb2.PullMessagesResponse:
|
96
|
-
"""
|
97
|
-
pass
|
75
|
+
"""Retrieve one or more messages, if possible
|
98
76
|
|
99
|
-
|
100
|
-
def PushTaskRes(self,
|
101
|
-
request: flwr.proto.fleet_pb2.PushTaskResRequest,
|
102
|
-
context: grpc.ServicerContext,
|
103
|
-
) -> flwr.proto.fleet_pb2.PushTaskResResponse:
|
104
|
-
"""Complete one or more tasks, if possible
|
105
|
-
|
106
|
-
HTTP API path: /api/v1/fleet/push-task-res
|
77
|
+
HTTP API path: /api/v1/fleet/pull-messages
|
107
78
|
"""
|
108
79
|
pass
|
109
80
|
|
@@ -112,7 +83,10 @@ class FleetServicer(metaclass=abc.ABCMeta):
|
|
112
83
|
request: flwr.proto.fleet_pb2.PushMessagesRequest,
|
113
84
|
context: grpc.ServicerContext,
|
114
85
|
) -> flwr.proto.fleet_pb2.PushMessagesResponse:
|
115
|
-
"""
|
86
|
+
"""Complete one or more messages, if possible
|
87
|
+
|
88
|
+
HTTP API path: /api/v1/fleet/push-messages
|
89
|
+
"""
|
116
90
|
pass
|
117
91
|
|
118
92
|
@abc.abstractmethod
|
flwr/server/app.py
CHANGED
@@ -374,8 +374,9 @@ def run_superlink() -> None:
|
|
374
374
|
bckg_threads.append(fleet_thread)
|
375
375
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
376
376
|
node_public_keys = _try_load_public_keys_node_authentication(args)
|
377
|
-
|
377
|
+
auto_auth = True
|
378
378
|
if node_public_keys is not None:
|
379
|
+
auto_auth = False
|
379
380
|
state = state_factory.state()
|
380
381
|
state.clear_supernode_auth_keys()
|
381
382
|
state.store_node_public_keys(node_public_keys)
|
@@ -384,7 +385,10 @@ def run_superlink() -> None:
|
|
384
385
|
"Node authentication enabled with %d known public keys",
|
385
386
|
len(node_public_keys),
|
386
387
|
)
|
387
|
-
|
388
|
+
else:
|
389
|
+
log(DEBUG, "Automatic node authentication enabled")
|
390
|
+
|
391
|
+
interceptors = [AuthenticateServerInterceptor(state_factory, auto_auth)]
|
388
392
|
|
389
393
|
fleet_server = _run_fleet_api_grpc_rere(
|
390
394
|
address=fleet_address,
|
@@ -28,7 +28,7 @@ from flwr.common.constant import (
|
|
28
28
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
29
29
|
SUPERLINK_NODE_ID,
|
30
30
|
)
|
31
|
-
from flwr.common.grpc import create_channel
|
31
|
+
from flwr.common.grpc import create_channel, on_channel_state_change
|
32
32
|
from flwr.common.logger import log
|
33
33
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
34
34
|
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
@@ -49,7 +49,7 @@ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E
|
|
49
49
|
from .driver import Driver
|
50
50
|
|
51
51
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
52
|
-
[
|
52
|
+
[flwr-serverapp] Error: Not connected.
|
53
53
|
|
54
54
|
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
|
55
55
|
`GrpcDriverStub` methods.
|
@@ -100,9 +100,10 @@ class GrpcDriver(Driver):
|
|
100
100
|
insecure=(self._cert is None),
|
101
101
|
root_certificates=self._cert,
|
102
102
|
)
|
103
|
+
self._channel.subscribe(on_channel_state_change)
|
103
104
|
self._grpc_stub = ServerAppIoStub(self._channel)
|
104
105
|
_wrap_stub(self._grpc_stub, self._retry_invoker)
|
105
|
-
log(DEBUG, "[
|
106
|
+
log(DEBUG, "[flwr-serverapp] Connected to %s", self._addr)
|
106
107
|
|
107
108
|
def _disconnect(self) -> None:
|
108
109
|
"""Disconnect from the ServerAppIo API."""
|
@@ -113,7 +114,7 @@ class GrpcDriver(Driver):
|
|
113
114
|
self._channel = None
|
114
115
|
self._grpc_stub = None
|
115
116
|
channel.close()
|
116
|
-
log(DEBUG, "[
|
117
|
+
log(DEBUG, "[flwr-serverapp] Disconnected")
|
117
118
|
|
118
119
|
def set_run(self, run_id: int) -> None:
|
119
120
|
"""Set the run."""
|
flwr/server/serverapp/app.py
CHANGED
@@ -72,7 +72,7 @@ def flwr_serverapp() -> None:
|
|
72
72
|
|
73
73
|
args = _parse_args_run_flwr_serverapp().parse_args()
|
74
74
|
|
75
|
-
log(INFO, "
|
75
|
+
log(INFO, "Start `flwr-serverapp` process")
|
76
76
|
|
77
77
|
if not args.insecure:
|
78
78
|
flwr_exit(
|
@@ -82,7 +82,8 @@ def flwr_serverapp() -> None:
|
|
82
82
|
|
83
83
|
log(
|
84
84
|
DEBUG,
|
85
|
-
"
|
85
|
+
"`flwr-serverapp` will attempt to connect to SuperLink's "
|
86
|
+
"ServerAppIo API at %s",
|
86
87
|
args.serverappio_api_address,
|
87
88
|
)
|
88
89
|
run_serverapp(
|
@@ -121,6 +122,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
121
122
|
try:
|
122
123
|
# Pull ServerAppInputs from LinkState
|
123
124
|
req = PullServerAppInputsRequest()
|
125
|
+
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
124
126
|
res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
|
125
127
|
if not res.HasField("run"):
|
126
128
|
sleep(3)
|
@@ -143,7 +145,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
143
145
|
stub=driver._stub,
|
144
146
|
)
|
145
147
|
|
146
|
-
log(DEBUG, "
|
148
|
+
log(DEBUG, "[flwr-serverapp] Start FAB installation.")
|
147
149
|
install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
|
148
150
|
|
149
151
|
fab_id, fab_version = get_fab_metadata(fab.content)
|
@@ -164,7 +166,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
164
166
|
|
165
167
|
log(
|
166
168
|
DEBUG,
|
167
|
-
"
|
169
|
+
"[flwr-serverapp] Will load ServerApp `%s` in %s",
|
168
170
|
server_app_attr,
|
169
171
|
app_path,
|
170
172
|
)
|
@@ -190,6 +192,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
190
192
|
|
191
193
|
# Send resulting context
|
192
194
|
context_proto = context_to_proto(updated_context)
|
195
|
+
log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
|
193
196
|
out_req = PushServerAppOutputsRequest(
|
194
197
|
run_id=run.run_id, context=context_proto
|
195
198
|
)
|
@@ -362,9 +362,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
362
362
|
) -> PullServerAppInputsResponse:
|
363
363
|
"""Pull ServerApp process inputs."""
|
364
364
|
log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
|
365
|
-
# Init access to LinkState
|
365
|
+
# Init access to LinkState
|
366
366
|
state = self.state_factory.state()
|
367
|
-
ffs = self.ffs_factory.ffs()
|
368
367
|
|
369
368
|
# Lock access to LinkState, preventing obtaining the same pending run_id
|
370
369
|
with self.lock:
|
@@ -374,6 +373,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
374
373
|
if run_id is None:
|
375
374
|
return PullServerAppInputsResponse()
|
376
375
|
|
376
|
+
# Init access to Ffs
|
377
|
+
ffs = self.ffs_factory.ffs()
|
378
|
+
|
377
379
|
# Retrieve Context, Run and Fab for the run_id
|
378
380
|
serverapp_ctxt = state.get_serverapp_context(run_id)
|
379
381
|
run = state.get_run(run_id)
|
@@ -18,6 +18,7 @@
|
|
18
18
|
from logging import DEBUG, INFO
|
19
19
|
|
20
20
|
import grpc
|
21
|
+
from google.protobuf.json_format import MessageToDict
|
21
22
|
|
22
23
|
from flwr.common.logger import log
|
23
24
|
from flwr.common.typing import InvalidRunStatusException
|
@@ -32,12 +33,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
33
|
PingResponse,
|
33
34
|
PullMessagesRequest,
|
34
35
|
PullMessagesResponse,
|
35
|
-
PullTaskInsRequest,
|
36
|
-
PullTaskInsResponse,
|
37
36
|
PushMessagesRequest,
|
38
37
|
PushMessagesResponse,
|
39
|
-
PushTaskResRequest,
|
40
|
-
PushTaskResResponse,
|
41
38
|
)
|
42
39
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
43
40
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
@@ -60,13 +57,13 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
60
57
|
) -> CreateNodeResponse:
|
61
58
|
"""."""
|
62
59
|
log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
|
63
|
-
log(DEBUG, "[Fleet.CreateNode] Request: %s", request)
|
60
|
+
log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
|
64
61
|
response = message_handler.create_node(
|
65
62
|
request=request,
|
66
63
|
state=self.state_factory.state(),
|
67
64
|
)
|
68
65
|
log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
|
69
|
-
log(DEBUG, "[Fleet.CreateNode] Response: %s", response)
|
66
|
+
log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
|
70
67
|
return response
|
71
68
|
|
72
69
|
def DeleteNode(
|
@@ -74,7 +71,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
74
71
|
) -> DeleteNodeResponse:
|
75
72
|
"""."""
|
76
73
|
log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
|
77
|
-
log(DEBUG, "[Fleet.DeleteNode] Request: %s", request)
|
74
|
+
log(DEBUG, "[Fleet.DeleteNode] Request: %s", MessageToDict(request))
|
78
75
|
return message_handler.delete_node(
|
79
76
|
request=request,
|
80
77
|
state=self.state_factory.state(),
|
@@ -82,57 +79,23 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
82
79
|
|
83
80
|
def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
|
84
81
|
"""."""
|
85
|
-
log(DEBUG, "[Fleet.Ping] Request: %s", request)
|
82
|
+
log(DEBUG, "[Fleet.Ping] Request: %s", MessageToDict(request))
|
86
83
|
return message_handler.ping(
|
87
84
|
request=request,
|
88
85
|
state=self.state_factory.state(),
|
89
86
|
)
|
90
87
|
|
91
|
-
def PullTaskIns(
|
92
|
-
self, request: PullTaskInsRequest, context: grpc.ServicerContext
|
93
|
-
) -> PullTaskInsResponse:
|
94
|
-
"""Pull TaskIns."""
|
95
|
-
log(INFO, "[Fleet.PullTaskIns] node_id=%s", request.node.node_id)
|
96
|
-
log(DEBUG, "[Fleet.PullTaskIns] Request: %s", request)
|
97
|
-
return message_handler.pull_task_ins(
|
98
|
-
request=request,
|
99
|
-
state=self.state_factory.state(),
|
100
|
-
)
|
101
|
-
|
102
88
|
def PullMessages(
|
103
89
|
self, request: PullMessagesRequest, context: grpc.ServicerContext
|
104
90
|
) -> PullMessagesResponse:
|
105
91
|
"""Pull Messages."""
|
106
92
|
log(INFO, "[Fleet.PullMessages] node_id=%s", request.node.node_id)
|
107
|
-
log(DEBUG, "[Fleet.PullMessages] Request: %s", request)
|
93
|
+
log(DEBUG, "[Fleet.PullMessages] Request: %s", MessageToDict(request))
|
108
94
|
return message_handler.pull_messages(
|
109
95
|
request=request,
|
110
96
|
state=self.state_factory.state(),
|
111
97
|
)
|
112
98
|
|
113
|
-
def PushTaskRes(
|
114
|
-
self, request: PushTaskResRequest, context: grpc.ServicerContext
|
115
|
-
) -> PushTaskResResponse:
|
116
|
-
"""Push TaskRes."""
|
117
|
-
if request.task_res_list:
|
118
|
-
log(
|
119
|
-
INFO,
|
120
|
-
"[Fleet.PushTaskRes] Push results from node_id=%s",
|
121
|
-
request.task_res_list[0].task.producer.node_id,
|
122
|
-
)
|
123
|
-
else:
|
124
|
-
log(INFO, "[Fleet.PushTaskRes] No task results to push")
|
125
|
-
|
126
|
-
try:
|
127
|
-
res = message_handler.push_task_res(
|
128
|
-
request=request,
|
129
|
-
state=self.state_factory.state(),
|
130
|
-
)
|
131
|
-
except InvalidRunStatusException as e:
|
132
|
-
abort_grpc_context(e.message, context)
|
133
|
-
|
134
|
-
return res
|
135
|
-
|
136
99
|
def PushMessages(
|
137
100
|
self, request: PushMessagesRequest, context: grpc.ServicerContext
|
138
101
|
) -> PushMessagesResponse:
|
@@ -106,6 +106,8 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
106
106
|
# Continue the RPC call
|
107
107
|
expected_node_id = state.get_node_id(node_pk_bytes)
|
108
108
|
if not handler_call_details.method.endswith("CreateNode"):
|
109
|
+
# All calls, except for `CreateNode`, must provide a public key that is
|
110
|
+
# already mapped to a `node_id` (in `LinkState`)
|
109
111
|
if expected_node_id is None:
|
110
112
|
return _unary_unary_rpc_terminator("Invalid node ID")
|
111
113
|
# One of the method handlers in
|
@@ -39,12 +39,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
39
39
|
PingResponse,
|
40
40
|
PullMessagesRequest,
|
41
41
|
PullMessagesResponse,
|
42
|
-
PullTaskInsRequest,
|
43
|
-
PullTaskInsResponse,
|
44
42
|
PushMessagesRequest,
|
45
43
|
PushMessagesResponse,
|
46
|
-
PushTaskResRequest,
|
47
|
-
PushTaskResResponse,
|
48
44
|
Reconnect,
|
49
45
|
)
|
50
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
@@ -53,7 +49,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
53
49
|
GetRunResponse,
|
54
50
|
Run,
|
55
51
|
)
|
56
|
-
from flwr.proto.task_pb2 import TaskIns
|
52
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
57
53
|
from flwr.server.superlink.ffs.ffs import Ffs
|
58
54
|
from flwr.server.superlink.linkstate import LinkState
|
59
55
|
from flwr.server.superlink.utils import check_abort
|
@@ -89,21 +85,6 @@ def ping(
|
|
89
85
|
return PingResponse(success=res)
|
90
86
|
|
91
87
|
|
92
|
-
def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
|
93
|
-
"""Pull TaskIns handler."""
|
94
|
-
node = request.node # pylint: disable=no-member
|
95
|
-
node_id: int = node.node_id
|
96
|
-
|
97
|
-
# Retrieve TaskIns from State
|
98
|
-
task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
|
99
|
-
|
100
|
-
# Build response
|
101
|
-
response = PullTaskInsResponse(
|
102
|
-
task_ins_list=task_ins_list,
|
103
|
-
)
|
104
|
-
return response
|
105
|
-
|
106
|
-
|
107
88
|
def pull_messages(
|
108
89
|
request: PullMessagesRequest, state: LinkState
|
109
90
|
) -> PullMessagesResponse:
|
@@ -124,35 +105,6 @@ def pull_messages(
|
|
124
105
|
return PullMessagesResponse(messages_list=msg_proto)
|
125
106
|
|
126
107
|
|
127
|
-
def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
|
128
|
-
"""Push TaskRes handler."""
|
129
|
-
# pylint: disable=no-member
|
130
|
-
task_res: TaskRes = request.task_res_list[0]
|
131
|
-
# pylint: enable=no-member
|
132
|
-
|
133
|
-
# Abort if the run is not running
|
134
|
-
abort_msg = check_abort(
|
135
|
-
task_res.run_id,
|
136
|
-
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
137
|
-
state,
|
138
|
-
)
|
139
|
-
if abort_msg:
|
140
|
-
raise InvalidRunStatusException(abort_msg)
|
141
|
-
|
142
|
-
# Set pushed_at (timestamp in seconds)
|
143
|
-
task_res.task.pushed_at = time.time()
|
144
|
-
|
145
|
-
# Store TaskRes in State
|
146
|
-
task_id: Optional[UUID] = state.store_task_res(task_res=task_res)
|
147
|
-
|
148
|
-
# Build response
|
149
|
-
response = PushTaskResResponse(
|
150
|
-
reconnect=Reconnect(reconnect=5),
|
151
|
-
results={str(task_id): 0},
|
152
|
-
)
|
153
|
-
return response
|
154
|
-
|
155
|
-
|
156
108
|
def push_messages(
|
157
109
|
request: PushMessagesRequest, state: LinkState
|
158
110
|
) -> PushMessagesResponse:
|
@@ -33,12 +33,8 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
33
33
|
PingResponse,
|
34
34
|
PullMessagesRequest,
|
35
35
|
PullMessagesResponse,
|
36
|
-
PullTaskInsRequest,
|
37
|
-
PullTaskInsResponse,
|
38
36
|
PushMessagesRequest,
|
39
37
|
PushMessagesResponse,
|
40
|
-
PushTaskResRequest,
|
41
|
-
PushTaskResResponse,
|
42
38
|
)
|
43
39
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
44
40
|
from flwr.server.superlink.ffs.ffs import Ffs
|
@@ -110,16 +106,6 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
110
106
|
return message_handler.delete_node(request=request, state=state)
|
111
107
|
|
112
108
|
|
113
|
-
@rest_request_response(PullTaskInsRequest)
|
114
|
-
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
115
|
-
"""Pull TaskIns."""
|
116
|
-
# Get state from app
|
117
|
-
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
118
|
-
|
119
|
-
# Handle message
|
120
|
-
return message_handler.pull_task_ins(request=request, state=state)
|
121
|
-
|
122
|
-
|
123
109
|
@rest_request_response(PullMessagesRequest)
|
124
110
|
async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
125
111
|
"""Pull PullMessages."""
|
@@ -130,17 +116,6 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
|
130
116
|
return message_handler.pull_messages(request=request, state=state)
|
131
117
|
|
132
118
|
|
133
|
-
# Check if token is needed here
|
134
|
-
@rest_request_response(PushTaskResRequest)
|
135
|
-
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
136
|
-
"""Push TaskRes."""
|
137
|
-
# Get state from app
|
138
|
-
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
139
|
-
|
140
|
-
# Handle message
|
141
|
-
return message_handler.push_task_res(request=request, state=state)
|
142
|
-
|
143
|
-
|
144
119
|
@rest_request_response(PushMessagesRequest)
|
145
120
|
async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
146
121
|
"""Pull PushMessages."""
|
@@ -187,9 +162,7 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
187
162
|
routes = [
|
188
163
|
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
|
189
164
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
190
|
-
Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
|
191
165
|
Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
|
192
|
-
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
|
193
166
|
Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
|
194
167
|
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
195
168
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
@@ -21,7 +21,7 @@ from typing import Optional, cast
|
|
21
21
|
import grpc
|
22
22
|
|
23
23
|
from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS
|
24
|
-
from flwr.common.grpc import create_channel
|
24
|
+
from flwr.common.grpc import create_channel, on_channel_state_change
|
25
25
|
from flwr.common.logger import log
|
26
26
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
27
27
|
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
|
@@ -73,6 +73,7 @@ class SimulationIoConnection:
|
|
73
73
|
insecure=(self._cert is None),
|
74
74
|
root_certificates=self._cert,
|
75
75
|
)
|
76
|
+
self._channel.subscribe(on_channel_state_change)
|
76
77
|
self._grpc_stub = SimulationIoStub(self._channel)
|
77
78
|
_wrap_stub(self._grpc_stub, self._retry_invoker)
|
78
79
|
log(DEBUG, "[SimulationIO] Connected to %s", self._addr)
|