flwr-nightly 1.13.0.dev20241023__py3-none-any.whl → 1.13.0.dev20241025__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/proto/driver_pb2.py +24 -15
- flwr/proto/driver_pb2.pyi +59 -0
- flwr/proto/driver_pb2_grpc.py +68 -0
- flwr/proto/driver_pb2_grpc.pyi +26 -0
- flwr/server/app.py +4 -2
- flwr/server/run_serverapp.py +13 -9
- flwr/server/superlink/driver/driver_servicer.py +65 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +12 -1
- flwr/server/superlink/linkstate/linkstate.py +29 -0
- flwr/server/superlink/linkstate/sqlite_linkstate.py +51 -6
- flwr/server/superlink/linkstate/utils.py +12 -1
- flwr/simulation/run_simulation.py +12 -4
- flwr/superexec/app.py +3 -138
- flwr/superexec/deployment.py +34 -25
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +11 -1
- flwr/superexec/executor.py +19 -0
- flwr/superexec/simulation.py +8 -0
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/RECORD +24 -25
- flwr/client/node_state_tests.py +0 -65
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.13.0.dev20241023.dist-info → flwr_nightly-1.13.0.dev20241025.dist-info}/entry_points.txt +0 -0
flwr/proto/driver_pb2.py
CHANGED
|
@@ -13,30 +13,39 @@ _sym_db = _symbol_database.Default()
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
|
16
|
+
from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
|
|
16
17
|
from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
|
|
17
18
|
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
18
19
|
from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
|
19
20
|
|
|
20
21
|
|
|
21
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.
|
|
22
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\",\n\x1aPullServerAppInputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9e\x05\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x62\x06proto3')
|
|
22
23
|
|
|
23
24
|
_globals = globals()
|
|
24
25
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
25
26
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals)
|
|
26
27
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
|
27
28
|
DESCRIPTOR._options = None
|
|
28
|
-
_globals['_GETNODESREQUEST']._serialized_start=
|
|
29
|
-
_globals['_GETNODESREQUEST']._serialized_end=
|
|
30
|
-
_globals['_GETNODESRESPONSE']._serialized_start=
|
|
31
|
-
_globals['_GETNODESRESPONSE']._serialized_end=
|
|
32
|
-
_globals['_PUSHTASKINSREQUEST']._serialized_start=
|
|
33
|
-
_globals['_PUSHTASKINSREQUEST']._serialized_end=
|
|
34
|
-
_globals['_PUSHTASKINSRESPONSE']._serialized_start=
|
|
35
|
-
_globals['_PUSHTASKINSRESPONSE']._serialized_end=
|
|
36
|
-
_globals['_PULLTASKRESREQUEST']._serialized_start=
|
|
37
|
-
_globals['_PULLTASKRESREQUEST']._serialized_end=
|
|
38
|
-
_globals['_PULLTASKRESRESPONSE']._serialized_start=
|
|
39
|
-
_globals['_PULLTASKRESRESPONSE']._serialized_end=
|
|
40
|
-
_globals['
|
|
41
|
-
_globals['
|
|
29
|
+
_globals['_GETNODESREQUEST']._serialized_start=155
|
|
30
|
+
_globals['_GETNODESREQUEST']._serialized_end=188
|
|
31
|
+
_globals['_GETNODESRESPONSE']._serialized_start=190
|
|
32
|
+
_globals['_GETNODESRESPONSE']._serialized_end=241
|
|
33
|
+
_globals['_PUSHTASKINSREQUEST']._serialized_start=243
|
|
34
|
+
_globals['_PUSHTASKINSREQUEST']._serialized_end=307
|
|
35
|
+
_globals['_PUSHTASKINSRESPONSE']._serialized_start=309
|
|
36
|
+
_globals['_PUSHTASKINSRESPONSE']._serialized_end=348
|
|
37
|
+
_globals['_PULLTASKRESREQUEST']._serialized_start=350
|
|
38
|
+
_globals['_PULLTASKRESREQUEST']._serialized_end=420
|
|
39
|
+
_globals['_PULLTASKRESRESPONSE']._serialized_start=422
|
|
40
|
+
_globals['_PULLTASKRESRESPONSE']._serialized_end=487
|
|
41
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=489
|
|
42
|
+
_globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=533
|
|
43
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=535
|
|
44
|
+
_globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=662
|
|
45
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=664
|
|
46
|
+
_globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=747
|
|
47
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=749
|
|
48
|
+
_globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=779
|
|
49
|
+
_globals['_DRIVER']._serialized_start=782
|
|
50
|
+
_globals['_DRIVER']._serialized_end=1452
|
|
42
51
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/driver_pb2.pyi
CHANGED
|
@@ -3,7 +3,10 @@
|
|
|
3
3
|
isort:skip_file
|
|
4
4
|
"""
|
|
5
5
|
import builtins
|
|
6
|
+
import flwr.proto.fab_pb2
|
|
7
|
+
import flwr.proto.message_pb2
|
|
6
8
|
import flwr.proto.node_pb2
|
|
9
|
+
import flwr.proto.run_pb2
|
|
7
10
|
import flwr.proto.task_pb2
|
|
8
11
|
import google.protobuf.descriptor
|
|
9
12
|
import google.protobuf.internal.containers
|
|
@@ -91,3 +94,59 @@ class PullTaskResResponse(google.protobuf.message.Message):
|
|
|
91
94
|
) -> None: ...
|
|
92
95
|
def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ...
|
|
93
96
|
global___PullTaskResResponse = PullTaskResResponse
|
|
97
|
+
|
|
98
|
+
class PullServerAppInputsRequest(google.protobuf.message.Message):
|
|
99
|
+
"""PullServerAppInputs messages"""
|
|
100
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
101
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
|
102
|
+
run_id: builtins.int
|
|
103
|
+
def __init__(self,
|
|
104
|
+
*,
|
|
105
|
+
run_id: builtins.int = ...,
|
|
106
|
+
) -> None: ...
|
|
107
|
+
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
|
|
108
|
+
global___PullServerAppInputsRequest = PullServerAppInputsRequest
|
|
109
|
+
|
|
110
|
+
class PullServerAppInputsResponse(google.protobuf.message.Message):
|
|
111
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
112
|
+
CONTEXT_FIELD_NUMBER: builtins.int
|
|
113
|
+
RUN_FIELD_NUMBER: builtins.int
|
|
114
|
+
FAB_FIELD_NUMBER: builtins.int
|
|
115
|
+
@property
|
|
116
|
+
def context(self) -> flwr.proto.message_pb2.Context: ...
|
|
117
|
+
@property
|
|
118
|
+
def run(self) -> flwr.proto.run_pb2.Run: ...
|
|
119
|
+
@property
|
|
120
|
+
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
|
|
121
|
+
def __init__(self,
|
|
122
|
+
*,
|
|
123
|
+
context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
|
|
124
|
+
run: typing.Optional[flwr.proto.run_pb2.Run] = ...,
|
|
125
|
+
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
|
|
126
|
+
) -> None: ...
|
|
127
|
+
def HasField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> builtins.bool: ...
|
|
128
|
+
def ClearField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> None: ...
|
|
129
|
+
global___PullServerAppInputsResponse = PullServerAppInputsResponse
|
|
130
|
+
|
|
131
|
+
class PushServerAppOutputsRequest(google.protobuf.message.Message):
|
|
132
|
+
"""PushServerAppOutputs messages"""
|
|
133
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
134
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
|
135
|
+
CONTEXT_FIELD_NUMBER: builtins.int
|
|
136
|
+
run_id: builtins.int
|
|
137
|
+
@property
|
|
138
|
+
def context(self) -> flwr.proto.message_pb2.Context: ...
|
|
139
|
+
def __init__(self,
|
|
140
|
+
*,
|
|
141
|
+
run_id: builtins.int = ...,
|
|
142
|
+
context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
|
|
143
|
+
) -> None: ...
|
|
144
|
+
def HasField(self, field_name: typing_extensions.Literal["context",b"context"]) -> builtins.bool: ...
|
|
145
|
+
def ClearField(self, field_name: typing_extensions.Literal["context",b"context","run_id",b"run_id"]) -> None: ...
|
|
146
|
+
global___PushServerAppOutputsRequest = PushServerAppOutputsRequest
|
|
147
|
+
|
|
148
|
+
class PushServerAppOutputsResponse(google.protobuf.message.Message):
|
|
149
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
150
|
+
def __init__(self,
|
|
151
|
+
) -> None: ...
|
|
152
|
+
global___PushServerAppOutputsResponse = PushServerAppOutputsResponse
|
flwr/proto/driver_pb2_grpc.py
CHANGED
|
@@ -46,6 +46,16 @@ class DriverStub(object):
|
|
|
46
46
|
request_serializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.SerializeToString,
|
|
47
47
|
response_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString,
|
|
48
48
|
)
|
|
49
|
+
self.PullServerAppInputs = channel.unary_unary(
|
|
50
|
+
'/flwr.proto.Driver/PullServerAppInputs',
|
|
51
|
+
request_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.SerializeToString,
|
|
52
|
+
response_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.FromString,
|
|
53
|
+
)
|
|
54
|
+
self.PushServerAppOutputs = channel.unary_unary(
|
|
55
|
+
'/flwr.proto.Driver/PushServerAppOutputs',
|
|
56
|
+
request_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString,
|
|
57
|
+
response_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
|
|
58
|
+
)
|
|
49
59
|
|
|
50
60
|
|
|
51
61
|
class DriverServicer(object):
|
|
@@ -93,6 +103,20 @@ class DriverServicer(object):
|
|
|
93
103
|
context.set_details('Method not implemented!')
|
|
94
104
|
raise NotImplementedError('Method not implemented!')
|
|
95
105
|
|
|
106
|
+
def PullServerAppInputs(self, request, context):
|
|
107
|
+
"""Pull ServerApp inputs
|
|
108
|
+
"""
|
|
109
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
110
|
+
context.set_details('Method not implemented!')
|
|
111
|
+
raise NotImplementedError('Method not implemented!')
|
|
112
|
+
|
|
113
|
+
def PushServerAppOutputs(self, request, context):
|
|
114
|
+
"""Push ServerApp outputs
|
|
115
|
+
"""
|
|
116
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
117
|
+
context.set_details('Method not implemented!')
|
|
118
|
+
raise NotImplementedError('Method not implemented!')
|
|
119
|
+
|
|
96
120
|
|
|
97
121
|
def add_DriverServicer_to_server(servicer, server):
|
|
98
122
|
rpc_method_handlers = {
|
|
@@ -126,6 +150,16 @@ def add_DriverServicer_to_server(servicer, server):
|
|
|
126
150
|
request_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.FromString,
|
|
127
151
|
response_serializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.SerializeToString,
|
|
128
152
|
),
|
|
153
|
+
'PullServerAppInputs': grpc.unary_unary_rpc_method_handler(
|
|
154
|
+
servicer.PullServerAppInputs,
|
|
155
|
+
request_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.FromString,
|
|
156
|
+
response_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.SerializeToString,
|
|
157
|
+
),
|
|
158
|
+
'PushServerAppOutputs': grpc.unary_unary_rpc_method_handler(
|
|
159
|
+
servicer.PushServerAppOutputs,
|
|
160
|
+
request_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.FromString,
|
|
161
|
+
response_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.SerializeToString,
|
|
162
|
+
),
|
|
129
163
|
}
|
|
130
164
|
generic_handler = grpc.method_handlers_generic_handler(
|
|
131
165
|
'flwr.proto.Driver', rpc_method_handlers)
|
|
@@ -237,3 +271,37 @@ class Driver(object):
|
|
|
237
271
|
flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString,
|
|
238
272
|
options, channel_credentials,
|
|
239
273
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
274
|
+
|
|
275
|
+
@staticmethod
|
|
276
|
+
def PullServerAppInputs(request,
|
|
277
|
+
target,
|
|
278
|
+
options=(),
|
|
279
|
+
channel_credentials=None,
|
|
280
|
+
call_credentials=None,
|
|
281
|
+
insecure=False,
|
|
282
|
+
compression=None,
|
|
283
|
+
wait_for_ready=None,
|
|
284
|
+
timeout=None,
|
|
285
|
+
metadata=None):
|
|
286
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PullServerAppInputs',
|
|
287
|
+
flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.SerializeToString,
|
|
288
|
+
flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.FromString,
|
|
289
|
+
options, channel_credentials,
|
|
290
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
291
|
+
|
|
292
|
+
@staticmethod
|
|
293
|
+
def PushServerAppOutputs(request,
|
|
294
|
+
target,
|
|
295
|
+
options=(),
|
|
296
|
+
channel_credentials=None,
|
|
297
|
+
call_credentials=None,
|
|
298
|
+
insecure=False,
|
|
299
|
+
compression=None,
|
|
300
|
+
wait_for_ready=None,
|
|
301
|
+
timeout=None,
|
|
302
|
+
metadata=None):
|
|
303
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PushServerAppOutputs',
|
|
304
|
+
flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString,
|
|
305
|
+
flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
|
|
306
|
+
options, channel_credentials,
|
|
307
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
flwr/proto/driver_pb2_grpc.pyi
CHANGED
|
@@ -40,6 +40,16 @@ class DriverStub:
|
|
|
40
40
|
flwr.proto.fab_pb2.GetFabResponse]
|
|
41
41
|
"""Get FAB"""
|
|
42
42
|
|
|
43
|
+
PullServerAppInputs: grpc.UnaryUnaryMultiCallable[
|
|
44
|
+
flwr.proto.driver_pb2.PullServerAppInputsRequest,
|
|
45
|
+
flwr.proto.driver_pb2.PullServerAppInputsResponse]
|
|
46
|
+
"""Pull ServerApp inputs"""
|
|
47
|
+
|
|
48
|
+
PushServerAppOutputs: grpc.UnaryUnaryMultiCallable[
|
|
49
|
+
flwr.proto.driver_pb2.PushServerAppOutputsRequest,
|
|
50
|
+
flwr.proto.driver_pb2.PushServerAppOutputsResponse]
|
|
51
|
+
"""Push ServerApp outputs"""
|
|
52
|
+
|
|
43
53
|
|
|
44
54
|
class DriverServicer(metaclass=abc.ABCMeta):
|
|
45
55
|
@abc.abstractmethod
|
|
@@ -90,5 +100,21 @@ class DriverServicer(metaclass=abc.ABCMeta):
|
|
|
90
100
|
"""Get FAB"""
|
|
91
101
|
pass
|
|
92
102
|
|
|
103
|
+
@abc.abstractmethod
|
|
104
|
+
def PullServerAppInputs(self,
|
|
105
|
+
request: flwr.proto.driver_pb2.PullServerAppInputsRequest,
|
|
106
|
+
context: grpc.ServicerContext,
|
|
107
|
+
) -> flwr.proto.driver_pb2.PullServerAppInputsResponse:
|
|
108
|
+
"""Pull ServerApp inputs"""
|
|
109
|
+
pass
|
|
110
|
+
|
|
111
|
+
@abc.abstractmethod
|
|
112
|
+
def PushServerAppOutputs(self,
|
|
113
|
+
request: flwr.proto.driver_pb2.PushServerAppOutputsRequest,
|
|
114
|
+
context: grpc.ServicerContext,
|
|
115
|
+
) -> flwr.proto.driver_pb2.PushServerAppOutputsResponse:
|
|
116
|
+
"""Push ServerApp outputs"""
|
|
117
|
+
pass
|
|
118
|
+
|
|
93
119
|
|
|
94
120
|
def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ...
|
flwr/server/app.py
CHANGED
|
@@ -64,7 +64,7 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
|
64
64
|
)
|
|
65
65
|
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
66
66
|
from flwr.superexec.app import load_executor
|
|
67
|
-
from flwr.superexec.exec_grpc import
|
|
67
|
+
from flwr.superexec.exec_grpc import run_exec_api_grpc
|
|
68
68
|
|
|
69
69
|
from .client_manager import ClientManager
|
|
70
70
|
from .history import History
|
|
@@ -329,8 +329,10 @@ def run_superlink() -> None:
|
|
|
329
329
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
330
330
|
|
|
331
331
|
# Start Exec API
|
|
332
|
-
exec_server: grpc.Server =
|
|
332
|
+
exec_server: grpc.Server = run_exec_api_grpc(
|
|
333
333
|
address=exec_address,
|
|
334
|
+
state_factory=state_factory,
|
|
335
|
+
ffs_factory=ffs_factory,
|
|
334
336
|
executor=load_executor(args),
|
|
335
337
|
certificates=certificates,
|
|
336
338
|
config=parse_config_args(
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -34,7 +34,6 @@ from flwr.common.config import (
|
|
|
34
34
|
from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
|
|
35
35
|
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
36
36
|
from flwr.common.object_ref import load_app
|
|
37
|
-
from flwr.common.typing import UserConfig
|
|
38
37
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
39
38
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
40
39
|
CreateRunRequest,
|
|
@@ -46,13 +45,14 @@ from .driver.grpc_driver import GrpcDriver
|
|
|
46
45
|
from .server_app import LoadServerAppError, ServerApp
|
|
47
46
|
|
|
48
47
|
|
|
48
|
+
# pylint: disable-next=too-many-arguments,too-many-positional-arguments
|
|
49
49
|
def run(
|
|
50
50
|
driver: Driver,
|
|
51
|
+
context: Context,
|
|
51
52
|
server_app_dir: str,
|
|
52
|
-
server_app_run_config: UserConfig,
|
|
53
53
|
server_app_attr: Optional[str] = None,
|
|
54
54
|
loaded_server_app: Optional[ServerApp] = None,
|
|
55
|
-
) ->
|
|
55
|
+
) -> Context:
|
|
56
56
|
"""Run ServerApp with a given Driver."""
|
|
57
57
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
|
58
58
|
raise ValueError(
|
|
@@ -78,15 +78,11 @@ def run(
|
|
|
78
78
|
|
|
79
79
|
server_app = _load()
|
|
80
80
|
|
|
81
|
-
# Initialize Context
|
|
82
|
-
context = Context(
|
|
83
|
-
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
|
|
84
|
-
)
|
|
85
|
-
|
|
86
81
|
# Call ServerApp
|
|
87
82
|
server_app(driver=driver, context=context)
|
|
88
83
|
|
|
89
84
|
log(DEBUG, "ServerApp finished running.")
|
|
85
|
+
return context
|
|
90
86
|
|
|
91
87
|
|
|
92
88
|
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
|
@@ -225,11 +221,19 @@ def run_server_app() -> None:
|
|
|
225
221
|
root_certificates,
|
|
226
222
|
)
|
|
227
223
|
|
|
224
|
+
# Initialize Context
|
|
225
|
+
context = Context(
|
|
226
|
+
node_id=0,
|
|
227
|
+
node_config={},
|
|
228
|
+
state=RecordSet(),
|
|
229
|
+
run_config=server_app_run_config,
|
|
230
|
+
)
|
|
231
|
+
|
|
228
232
|
# Run the ServerApp with the Driver
|
|
229
233
|
run(
|
|
230
234
|
driver=driver,
|
|
235
|
+
context=context,
|
|
231
236
|
server_app_dir=app_path,
|
|
232
|
-
server_app_run_config=server_app_run_config,
|
|
233
237
|
server_app_attr=server_app_attr,
|
|
234
238
|
)
|
|
235
239
|
|
|
@@ -15,27 +15,35 @@
|
|
|
15
15
|
"""Driver API servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import threading
|
|
18
19
|
import time
|
|
19
|
-
from logging import DEBUG
|
|
20
|
+
from logging import DEBUG, INFO
|
|
20
21
|
from typing import Optional
|
|
21
22
|
from uuid import UUID
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
26
|
+
from flwr.common.constant import Status
|
|
25
27
|
from flwr.common.logger import log
|
|
26
28
|
from flwr.common.serde import (
|
|
29
|
+
context_from_proto,
|
|
30
|
+
context_to_proto,
|
|
27
31
|
fab_from_proto,
|
|
28
32
|
fab_to_proto,
|
|
33
|
+
run_to_proto,
|
|
29
34
|
user_config_from_proto,
|
|
30
|
-
user_config_to_proto,
|
|
31
35
|
)
|
|
32
|
-
from flwr.common.typing import Fab
|
|
36
|
+
from flwr.common.typing import Fab, RunStatus
|
|
33
37
|
from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
|
|
34
38
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
35
39
|
GetNodesRequest,
|
|
36
40
|
GetNodesResponse,
|
|
41
|
+
PullServerAppInputsRequest,
|
|
42
|
+
PullServerAppInputsResponse,
|
|
37
43
|
PullTaskResRequest,
|
|
38
44
|
PullTaskResResponse,
|
|
45
|
+
PushServerAppOutputsRequest,
|
|
46
|
+
PushServerAppOutputsResponse,
|
|
39
47
|
PushTaskInsRequest,
|
|
40
48
|
PushTaskInsResponse,
|
|
41
49
|
)
|
|
@@ -46,7 +54,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
46
54
|
CreateRunResponse,
|
|
47
55
|
GetRunRequest,
|
|
48
56
|
GetRunResponse,
|
|
49
|
-
Run,
|
|
50
57
|
)
|
|
51
58
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
52
59
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
@@ -63,6 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
63
70
|
) -> None:
|
|
64
71
|
self.state_factory = state_factory
|
|
65
72
|
self.ffs_factory = ffs_factory
|
|
73
|
+
self.lock = threading.RLock()
|
|
66
74
|
|
|
67
75
|
def GetNodes(
|
|
68
76
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
|
@@ -177,15 +185,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
177
185
|
if run is None:
|
|
178
186
|
return GetRunResponse()
|
|
179
187
|
|
|
180
|
-
return GetRunResponse(
|
|
181
|
-
run=Run(
|
|
182
|
-
run_id=run.run_id,
|
|
183
|
-
fab_id=run.fab_id,
|
|
184
|
-
fab_version=run.fab_version,
|
|
185
|
-
override_config=user_config_to_proto(run.override_config),
|
|
186
|
-
fab_hash=run.fab_hash,
|
|
187
|
-
)
|
|
188
|
-
)
|
|
188
|
+
return GetRunResponse(run=run_to_proto(run))
|
|
189
189
|
|
|
190
190
|
def GetFab(
|
|
191
191
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
@@ -200,6 +200,58 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
200
200
|
|
|
201
201
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
202
202
|
|
|
203
|
+
def PullServerAppInputs(
|
|
204
|
+
self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
|
|
205
|
+
) -> PullServerAppInputsResponse:
|
|
206
|
+
"""Pull ServerApp process inputs."""
|
|
207
|
+
log(DEBUG, "DriverServicer.PullServerAppInputs")
|
|
208
|
+
# Init access to LinkState and Ffs
|
|
209
|
+
state = self.state_factory.state()
|
|
210
|
+
ffs = self.ffs_factory.ffs()
|
|
211
|
+
|
|
212
|
+
# Lock access to LinkState, preventing obtaining the same pending run_id
|
|
213
|
+
with self.lock:
|
|
214
|
+
# If run_id is provided, use it, otherwise use the pending run_id
|
|
215
|
+
if request.HasField("run_id"):
|
|
216
|
+
run_id: Optional[int] = request.run_id
|
|
217
|
+
else:
|
|
218
|
+
run_id = state.get_pending_run_id()
|
|
219
|
+
# If there's no pending run, return an empty response
|
|
220
|
+
if run_id is None:
|
|
221
|
+
return PullServerAppInputsResponse()
|
|
222
|
+
|
|
223
|
+
# Retrieve Context, Run and Fab for the run_id
|
|
224
|
+
serverapp_ctxt = state.get_serverapp_context(run_id)
|
|
225
|
+
run = state.get_run(run_id)
|
|
226
|
+
fab = None
|
|
227
|
+
if run and run.fab_hash:
|
|
228
|
+
if result := ffs.get(run.fab_hash):
|
|
229
|
+
fab = Fab(run.fab_hash, result[0])
|
|
230
|
+
if run and fab:
|
|
231
|
+
# Update run status to STARTING
|
|
232
|
+
if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
|
|
233
|
+
log(INFO, "Starting run %d", run_id)
|
|
234
|
+
return PullServerAppInputsResponse(
|
|
235
|
+
context=(
|
|
236
|
+
context_to_proto(serverapp_ctxt) if serverapp_ctxt else None
|
|
237
|
+
),
|
|
238
|
+
run=run_to_proto(run),
|
|
239
|
+
fab=fab_to_proto(fab),
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
# Raise an exception if the Run or Fab is not found,
|
|
243
|
+
# or if the status cannot be updated to STARTING
|
|
244
|
+
raise RuntimeError(f"Failed to start run {run_id}")
|
|
245
|
+
|
|
246
|
+
def PushServerAppOutputs(
|
|
247
|
+
self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
|
|
248
|
+
) -> PushServerAppOutputsResponse:
|
|
249
|
+
"""Push ServerApp process outputs."""
|
|
250
|
+
log(DEBUG, "DriverServicer.PushServerAppOutputs")
|
|
251
|
+
state = self.state_factory.state()
|
|
252
|
+
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
253
|
+
return PushServerAppOutputsResponse()
|
|
254
|
+
|
|
203
255
|
|
|
204
256
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
205
257
|
if validation_error:
|
|
@@ -22,7 +22,7 @@ from logging import ERROR, WARNING
|
|
|
22
22
|
from typing import Optional
|
|
23
23
|
from uuid import UUID, uuid4
|
|
24
24
|
|
|
25
|
-
from flwr.common import log, now
|
|
25
|
+
from flwr.common import Context, log, now
|
|
26
26
|
from flwr.common.constant import (
|
|
27
27
|
MESSAGE_TTL_TOLERANCE,
|
|
28
28
|
NODE_ID_NUM_BYTES,
|
|
@@ -65,6 +65,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
65
65
|
|
|
66
66
|
# Map run_id to RunRecord
|
|
67
67
|
self.run_ids: dict[int, RunRecord] = {}
|
|
68
|
+
self.contexts: dict[int, Context] = {}
|
|
68
69
|
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
69
70
|
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
70
71
|
|
|
@@ -500,3 +501,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
500
501
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
501
502
|
return True
|
|
502
503
|
return False
|
|
504
|
+
|
|
505
|
+
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
506
|
+
"""Get the context for the specified `run_id`."""
|
|
507
|
+
return self.contexts.get(run_id)
|
|
508
|
+
|
|
509
|
+
def set_serverapp_context(self, run_id: int, context: Context) -> None:
|
|
510
|
+
"""Set the context for the specified `run_id`."""
|
|
511
|
+
if run_id not in self.run_ids:
|
|
512
|
+
raise ValueError(f"Run {run_id} not found")
|
|
513
|
+
self.contexts[run_id] = context
|
|
@@ -19,6 +19,7 @@ import abc
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
+
from flwr.common import Context
|
|
22
23
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
23
24
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
24
25
|
|
|
@@ -270,3 +271,31 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
270
271
|
is_acknowledged : bool
|
|
271
272
|
True if the ping is successfully acknowledged; otherwise, False.
|
|
272
273
|
"""
|
|
274
|
+
|
|
275
|
+
@abc.abstractmethod
|
|
276
|
+
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
277
|
+
"""Get the context for the specified `run_id`.
|
|
278
|
+
|
|
279
|
+
Parameters
|
|
280
|
+
----------
|
|
281
|
+
run_id : int
|
|
282
|
+
The identifier of the run for which to retrieve the context.
|
|
283
|
+
|
|
284
|
+
Returns
|
|
285
|
+
-------
|
|
286
|
+
Optional[Context]
|
|
287
|
+
The context associated with the specified `run_id`, or `None` if no context
|
|
288
|
+
exists for the given `run_id`.
|
|
289
|
+
"""
|
|
290
|
+
|
|
291
|
+
@abc.abstractmethod
|
|
292
|
+
def set_serverapp_context(self, run_id: int, context: Context) -> None:
|
|
293
|
+
"""Set the context for the specified `run_id`.
|
|
294
|
+
|
|
295
|
+
Parameters
|
|
296
|
+
----------
|
|
297
|
+
run_id : int
|
|
298
|
+
The identifier of the run for which to set the context.
|
|
299
|
+
context : Context
|
|
300
|
+
The context to be associated with the specified `run_id`.
|
|
301
|
+
"""
|