flwr 1.13.1__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +5 -0
- flwr/cli/auth_plugin/__init__.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +90 -0
- flwr/cli/config_utils.py +43 -149
- flwr/cli/constant.py +27 -0
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +2 -1
- flwr/cli/log.py +34 -37
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +116 -0
- flwr/cli/ls.py +214 -106
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/.gitignore.tpl +3 -0
- flwr/cli/new/templates/app/README.md.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -4
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +103 -43
- flwr/cli/stop.py +139 -0
- flwr/cli/utils.py +186 -8
- flwr/client/app.py +49 -50
- flwr/client/client.py +1 -32
- flwr/client/clientapp/app.py +23 -26
- flwr/client/clientapp/utils.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +2 -13
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +59 -43
- flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
- flwr/client/message_handler/message_handler.py +1 -2
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/numpy_client.py +0 -44
- flwr/client/rest_client/connection.py +37 -29
- flwr/client/supernode/app.py +20 -74
- flwr/common/address.py +1 -0
- flwr/common/args.py +26 -47
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +122 -0
- flwr/common/config.py +169 -17
- flwr/common/constant.py +38 -9
- flwr/common/differential_privacy.py +2 -1
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +167 -4
- flwr/common/logger.py +66 -7
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/record/recordset.py +1 -1
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/serde.py +6 -4
- flwr/common/telemetry.py +15 -4
- flwr/common/typing.py +32 -0
- flwr/common/version.py +1 -0
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/error_pb2.py +1 -1
- flwr/proto/exec_pb2.py +27 -15
- flwr/proto/exec_pb2.pyi +80 -2
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +5 -5
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/fleet_pb2.py +31 -31
- flwr/proto/fleet_pb2.pyi +23 -23
- flwr/proto/fleet_pb2_grpc.py +30 -30
- flwr/proto/fleet_pb2_grpc.pyi +20 -20
- flwr/proto/grpcadapter_pb2.py +1 -1
- flwr/proto/log_pb2.py +1 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +3 -3
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/recordset_pb2.py +1 -1
- flwr/proto/run_pb2.py +1 -1
- flwr/proto/serverappio_pb2.py +24 -25
- flwr/proto/serverappio_pb2.pyi +32 -32
- flwr/proto/serverappio_pb2_grpc.py +62 -28
- flwr/proto/serverappio_pb2_grpc.pyi +29 -16
- flwr/proto/simulationio_pb2.py +3 -3
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +1 -1
- flwr/server/app.py +152 -112
- flwr/server/compat/app_utils.py +7 -2
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +38 -85
- flwr/server/driver/inmemory_driver.py +7 -2
- flwr/server/run_serverapp.py +8 -9
- flwr/server/serverapp/app.py +37 -13
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +2 -1
- flwr/server/superlink/driver/serverappio_servicer.py +148 -63
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -87
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +56 -35
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +99 -169
- flwr/server/superlink/fleet/message_handler/message_handler.py +69 -29
- flwr/server/superlink/fleet/rest_rere/rest_api.py +20 -19
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +60 -99
- flwr/server/superlink/linkstate/linkstate.py +30 -36
- flwr/server/superlink/linkstate/sqlite_linkstate.py +105 -188
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +20 -10
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +36 -22
- flwr/simulation/simulationio_connection.py +5 -1
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +20 -2
- flwr/superexec/exec_servicer.py +97 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/METADATA +14 -13
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/RECORD +150 -144
- flwr/proto/common_pb2.py +0 -36
- flwr/proto/common_pb2.pyi +0 -121
- flwr/proto/common_pb2_grpc.py +0 -4
- flwr/proto/common_pb2_grpc.pyi +0 -4
- flwr/proto/control_pb2.py +0 -27
- flwr/proto/control_pb2.pyi +0 -7
- flwr/proto/control_pb2_grpc.py +0 -135
- flwr/proto/control_pb2_grpc.pyi +0 -53
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
|
@@ -16,14 +16,13 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
|
-
import time
|
|
20
19
|
from logging import DEBUG, INFO
|
|
21
20
|
from typing import Optional
|
|
22
21
|
from uuid import UUID
|
|
23
22
|
|
|
24
23
|
import grpc
|
|
25
24
|
|
|
26
|
-
from flwr.common import ConfigsRecord
|
|
25
|
+
from flwr.common import ConfigsRecord, now
|
|
27
26
|
from flwr.common.constant import Status
|
|
28
27
|
from flwr.common.logger import log
|
|
29
28
|
from flwr.common.serde import (
|
|
@@ -31,7 +30,12 @@ from flwr.common.serde import (
|
|
|
31
30
|
context_to_proto,
|
|
32
31
|
fab_from_proto,
|
|
33
32
|
fab_to_proto,
|
|
33
|
+
message_from_proto,
|
|
34
|
+
message_from_taskres,
|
|
35
|
+
message_to_proto,
|
|
36
|
+
message_to_taskins,
|
|
34
37
|
run_status_from_proto,
|
|
38
|
+
run_status_to_proto,
|
|
35
39
|
run_to_proto,
|
|
36
40
|
user_config_from_proto,
|
|
37
41
|
)
|
|
@@ -48,25 +52,28 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
48
52
|
CreateRunResponse,
|
|
49
53
|
GetRunRequest,
|
|
50
54
|
GetRunResponse,
|
|
55
|
+
GetRunStatusRequest,
|
|
56
|
+
GetRunStatusResponse,
|
|
51
57
|
UpdateRunStatusRequest,
|
|
52
58
|
UpdateRunStatusResponse,
|
|
53
59
|
)
|
|
54
60
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
55
61
|
GetNodesRequest,
|
|
56
62
|
GetNodesResponse,
|
|
63
|
+
PullResMessagesRequest,
|
|
64
|
+
PullResMessagesResponse,
|
|
57
65
|
PullServerAppInputsRequest,
|
|
58
66
|
PullServerAppInputsResponse,
|
|
59
|
-
|
|
60
|
-
|
|
67
|
+
PushInsMessagesRequest,
|
|
68
|
+
PushInsMessagesResponse,
|
|
61
69
|
PushServerAppOutputsRequest,
|
|
62
70
|
PushServerAppOutputsResponse,
|
|
63
|
-
PushTaskInsRequest,
|
|
64
|
-
PushTaskInsResponse,
|
|
65
71
|
)
|
|
66
72
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
67
73
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
68
74
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
69
75
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
76
|
+
from flwr.server.superlink.utils import abort_if
|
|
70
77
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
71
78
|
|
|
72
79
|
|
|
@@ -85,11 +92,20 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
85
92
|
) -> GetNodesResponse:
|
|
86
93
|
"""Get available nodes."""
|
|
87
94
|
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
|
95
|
+
|
|
96
|
+
# Init state
|
|
88
97
|
state: LinkState = self.state_factory.state()
|
|
98
|
+
|
|
99
|
+
# Abort if the run is not running
|
|
100
|
+
abort_if(
|
|
101
|
+
request.run_id,
|
|
102
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
103
|
+
state,
|
|
104
|
+
context,
|
|
105
|
+
)
|
|
106
|
+
|
|
89
107
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
90
|
-
nodes: list[Node] = [
|
|
91
|
-
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
92
|
-
]
|
|
108
|
+
nodes: list[Node] = [Node(node_id=node_id) for node_id in all_ids]
|
|
93
109
|
return GetNodesResponse(nodes=nodes)
|
|
94
110
|
|
|
95
111
|
def CreateRun(
|
|
@@ -103,8 +119,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
103
119
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
104
120
|
fab_hash = ffs.put(fab.content, {})
|
|
105
121
|
_raise_if(
|
|
106
|
-
fab_hash != fab.hash_str,
|
|
107
|
-
|
|
122
|
+
validation_error=fab_hash != fab.hash_str,
|
|
123
|
+
request_name="CreateRun",
|
|
124
|
+
detail=f"FAB ({fab.hash_str}) hash from request doesn't match contents",
|
|
108
125
|
)
|
|
109
126
|
else:
|
|
110
127
|
fab_hash = ""
|
|
@@ -117,70 +134,104 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
117
134
|
)
|
|
118
135
|
return CreateRunResponse(run_id=run_id)
|
|
119
136
|
|
|
120
|
-
def
|
|
121
|
-
self, request:
|
|
122
|
-
) ->
|
|
123
|
-
"""Push a set of
|
|
124
|
-
log(DEBUG, "ServerAppIoServicer.
|
|
137
|
+
def PushMessages(
|
|
138
|
+
self, request: PushInsMessagesRequest, context: grpc.ServicerContext
|
|
139
|
+
) -> PushInsMessagesResponse:
|
|
140
|
+
"""Push a set of Messages."""
|
|
141
|
+
log(DEBUG, "ServerAppIoServicer.PushMessages")
|
|
142
|
+
|
|
143
|
+
# Init state
|
|
144
|
+
state: LinkState = self.state_factory.state()
|
|
145
|
+
|
|
146
|
+
# Abort if the run is not running
|
|
147
|
+
abort_if(
|
|
148
|
+
request.run_id,
|
|
149
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
150
|
+
state,
|
|
151
|
+
context,
|
|
152
|
+
)
|
|
125
153
|
|
|
126
154
|
# Set pushed_at (timestamp in seconds)
|
|
127
|
-
pushed_at =
|
|
128
|
-
for task_ins in request.task_ins_list:
|
|
129
|
-
task_ins.task.pushed_at = pushed_at
|
|
155
|
+
pushed_at = now().timestamp()
|
|
130
156
|
|
|
131
|
-
# Validate request
|
|
132
|
-
_raise_if(
|
|
133
|
-
|
|
157
|
+
# Validate request and insert in State
|
|
158
|
+
_raise_if(
|
|
159
|
+
validation_error=len(request.messages_list) == 0,
|
|
160
|
+
request_name="PushMessages",
|
|
161
|
+
detail="`messages_list` must not be empty",
|
|
162
|
+
)
|
|
163
|
+
message_ids: list[Optional[UUID]] = []
|
|
164
|
+
while request.messages_list:
|
|
165
|
+
message_proto = request.messages_list.pop(0)
|
|
166
|
+
message = message_from_proto(message_proto=message_proto)
|
|
167
|
+
task_ins = message_to_taskins(message=message)
|
|
168
|
+
task_ins.task.pushed_at = pushed_at
|
|
134
169
|
validation_errors = validate_task_ins_or_res(task_ins)
|
|
135
|
-
_raise_if(
|
|
170
|
+
_raise_if(
|
|
171
|
+
validation_error=bool(validation_errors),
|
|
172
|
+
request_name="PushMessages",
|
|
173
|
+
detail=", ".join(validation_errors),
|
|
174
|
+
)
|
|
175
|
+
_raise_if(
|
|
176
|
+
validation_error=request.run_id != task_ins.run_id,
|
|
177
|
+
request_name="PushMessages",
|
|
178
|
+
detail="`task_ins` has mismatched `run_id`",
|
|
179
|
+
)
|
|
180
|
+
# Store
|
|
181
|
+
message_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
|
|
182
|
+
message_ids.append(message_id)
|
|
183
|
+
|
|
184
|
+
return PushInsMessagesResponse(
|
|
185
|
+
message_ids=[
|
|
186
|
+
str(message_id) if message_id else "" for message_id in message_ids
|
|
187
|
+
]
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
def PullMessages(
|
|
191
|
+
self, request: PullResMessagesRequest, context: grpc.ServicerContext
|
|
192
|
+
) -> PullResMessagesResponse:
|
|
193
|
+
"""Pull a set of Messages."""
|
|
194
|
+
log(DEBUG, "ServerAppIoServicer.PullMessages")
|
|
136
195
|
|
|
137
196
|
# Init state
|
|
138
197
|
state: LinkState = self.state_factory.state()
|
|
139
198
|
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
return PushTaskInsResponse(
|
|
147
|
-
task_ids=[str(task_id) if task_id else "" for task_id in task_ids]
|
|
199
|
+
# Abort if the run is not running
|
|
200
|
+
abort_if(
|
|
201
|
+
request.run_id,
|
|
202
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
203
|
+
state,
|
|
204
|
+
context,
|
|
148
205
|
)
|
|
149
206
|
|
|
150
|
-
def PullTaskRes(
|
|
151
|
-
self, request: PullTaskResRequest, context: grpc.ServicerContext
|
|
152
|
-
) -> PullTaskResResponse:
|
|
153
|
-
"""Pull a set of TaskRes."""
|
|
154
|
-
log(DEBUG, "ServerAppIoServicer.PullTaskRes")
|
|
155
|
-
|
|
156
207
|
# Convert each task_id str to UUID
|
|
157
|
-
|
|
208
|
+
message_ids: set[UUID] = {
|
|
209
|
+
UUID(message_id) for message_id in request.message_ids
|
|
210
|
+
}
|
|
158
211
|
|
|
159
|
-
#
|
|
160
|
-
|
|
212
|
+
# Read from state
|
|
213
|
+
task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids)
|
|
161
214
|
|
|
162
|
-
#
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
215
|
+
# Convert to Messages
|
|
216
|
+
messages_list = []
|
|
217
|
+
while task_res_list:
|
|
218
|
+
task_res = task_res_list.pop(0)
|
|
219
|
+
_raise_if(
|
|
220
|
+
validation_error=request.run_id != task_res.run_id,
|
|
221
|
+
request_name="PullMessages",
|
|
222
|
+
detail="`task_res` has mismatched `run_id`",
|
|
167
223
|
)
|
|
224
|
+
message = message_from_taskres(taskres=task_res)
|
|
225
|
+
messages_list.append(message_to_proto(message))
|
|
168
226
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
227
|
+
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
228
|
+
task_ins_ids_to_delete = {
|
|
229
|
+
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
230
|
+
}
|
|
173
231
|
|
|
174
|
-
|
|
175
|
-
state.delete_tasks(task_ids=task_ids)
|
|
232
|
+
state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
|
|
176
233
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
# Read from state
|
|
180
|
-
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
|
|
181
|
-
|
|
182
|
-
context.set_code(grpc.StatusCode.OK)
|
|
183
|
-
return PullTaskResResponse(task_res_list=task_res_list)
|
|
234
|
+
return PullResMessagesResponse(messages_list=messages_list)
|
|
184
235
|
|
|
185
236
|
def GetRun(
|
|
186
237
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
@@ -217,9 +268,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
217
268
|
) -> PullServerAppInputsResponse:
|
|
218
269
|
"""Pull ServerApp process inputs."""
|
|
219
270
|
log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
|
|
220
|
-
# Init access to LinkState
|
|
271
|
+
# Init access to LinkState
|
|
221
272
|
state = self.state_factory.state()
|
|
222
|
-
ffs = self.ffs_factory.ffs()
|
|
223
273
|
|
|
224
274
|
# Lock access to LinkState, preventing obtaining the same pending run_id
|
|
225
275
|
with self.lock:
|
|
@@ -229,6 +279,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
229
279
|
if run_id is None:
|
|
230
280
|
return PullServerAppInputsResponse()
|
|
231
281
|
|
|
282
|
+
# Init access to Ffs
|
|
283
|
+
ffs = self.ffs_factory.ffs()
|
|
284
|
+
|
|
232
285
|
# Retrieve Context, Run and Fab for the run_id
|
|
233
286
|
serverapp_ctxt = state.get_serverapp_context(run_id)
|
|
234
287
|
run = state.get_run(run_id)
|
|
@@ -255,7 +308,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
255
308
|
) -> PushServerAppOutputsResponse:
|
|
256
309
|
"""Push ServerApp process outputs."""
|
|
257
310
|
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
|
|
311
|
+
|
|
312
|
+
# Init state
|
|
258
313
|
state = self.state_factory.state()
|
|
314
|
+
|
|
315
|
+
# Abort if the run is not running
|
|
316
|
+
abort_if(
|
|
317
|
+
request.run_id,
|
|
318
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
319
|
+
state,
|
|
320
|
+
context,
|
|
321
|
+
)
|
|
322
|
+
|
|
259
323
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
260
324
|
return PushServerAppOutputsResponse()
|
|
261
325
|
|
|
@@ -263,9 +327,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
263
327
|
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
|
|
264
328
|
) -> UpdateRunStatusResponse:
|
|
265
329
|
"""Update the status of a run."""
|
|
266
|
-
log(DEBUG, "
|
|
330
|
+
log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
|
|
331
|
+
|
|
332
|
+
# Init state
|
|
267
333
|
state = self.state_factory.state()
|
|
268
334
|
|
|
335
|
+
# Abort if the run is finished
|
|
336
|
+
abort_if(request.run_id, [Status.FINISHED], state, context)
|
|
337
|
+
|
|
269
338
|
# Update the run status
|
|
270
339
|
state.update_run_status(
|
|
271
340
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
@@ -284,7 +353,23 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
284
353
|
state.add_serverapp_log(request.run_id, merged_logs)
|
|
285
354
|
return PushLogsResponse()
|
|
286
355
|
|
|
356
|
+
def GetRunStatus(
|
|
357
|
+
self, request: GetRunStatusRequest, context: grpc.ServicerContext
|
|
358
|
+
) -> GetRunStatusResponse:
|
|
359
|
+
"""Get the status of a run."""
|
|
360
|
+
log(DEBUG, "ServerAppIoServicer.GetRunStatus")
|
|
361
|
+
state = self.state_factory.state()
|
|
362
|
+
|
|
363
|
+
# Get run status from LinkState
|
|
364
|
+
run_statuses = state.get_run_status(set(request.run_ids))
|
|
365
|
+
run_status_dict = {
|
|
366
|
+
run_id: run_status_to_proto(run_status)
|
|
367
|
+
for run_id, run_status in run_statuses.items()
|
|
368
|
+
}
|
|
369
|
+
return GetRunStatusResponse(run_status_dict=run_status_dict)
|
|
370
|
+
|
|
287
371
|
|
|
288
|
-
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
372
|
+
def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
|
|
373
|
+
"""Raise a `ValueError` with a detailed message if a validation error occurs."""
|
|
289
374
|
if validation_error:
|
|
290
|
-
raise ValueError(f"Malformed
|
|
375
|
+
raise ValueError(f"Malformed {request_name}: {detail}")
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Fleet API gRPC adapter servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from logging import DEBUG
|
|
18
|
+
from logging import DEBUG
|
|
19
19
|
from typing import Callable, TypeVar
|
|
20
20
|
|
|
21
21
|
import grpc
|
|
@@ -31,35 +31,30 @@ from flwr.common.constant import (
|
|
|
31
31
|
from flwr.common.logger import log
|
|
32
32
|
from flwr.common.version import package_name, package_version
|
|
33
33
|
from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
|
|
34
|
-
from flwr.proto.fab_pb2 import GetFabRequest
|
|
34
|
+
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
|
|
35
35
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
36
36
|
CreateNodeRequest,
|
|
37
|
-
CreateNodeResponse,
|
|
38
37
|
DeleteNodeRequest,
|
|
39
|
-
DeleteNodeResponse,
|
|
40
38
|
PingRequest,
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
PullTaskInsResponse,
|
|
44
|
-
PushTaskResRequest,
|
|
45
|
-
PushTaskResResponse,
|
|
39
|
+
PullMessagesRequest,
|
|
40
|
+
PushMessagesRequest,
|
|
46
41
|
)
|
|
47
42
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
48
|
-
from flwr.proto.run_pb2 import GetRunRequest
|
|
49
|
-
|
|
50
|
-
from
|
|
51
|
-
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
43
|
+
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
44
|
+
|
|
45
|
+
from ..grpc_rere.fleet_servicer import FleetServicer
|
|
52
46
|
|
|
53
47
|
T = TypeVar("T", bound=GrpcMessage)
|
|
54
48
|
|
|
55
49
|
|
|
56
50
|
def _handle(
|
|
57
51
|
msg_container: MessageContainer,
|
|
52
|
+
context: grpc.ServicerContext,
|
|
58
53
|
request_type: type[T],
|
|
59
|
-
handler: Callable[[T], GrpcMessage],
|
|
54
|
+
handler: Callable[[T, grpc.ServicerContext], GrpcMessage],
|
|
60
55
|
) -> MessageContainer:
|
|
61
56
|
req = request_type.FromString(msg_container.grpc_message_content)
|
|
62
|
-
res = handler(req)
|
|
57
|
+
res = handler(req, context)
|
|
63
58
|
res_cls = res.__class__
|
|
64
59
|
return MessageContainer(
|
|
65
60
|
metadata={
|
|
@@ -74,88 +69,26 @@ def _handle(
|
|
|
74
69
|
)
|
|
75
70
|
|
|
76
71
|
|
|
77
|
-
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
|
72
|
+
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetServicer):
|
|
78
73
|
"""Fleet API via GrpcAdapter servicer."""
|
|
79
74
|
|
|
80
|
-
def __init__(
|
|
81
|
-
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
82
|
-
) -> None:
|
|
83
|
-
self.state_factory = state_factory
|
|
84
|
-
self.ffs_factory = ffs_factory
|
|
85
|
-
|
|
86
75
|
def SendReceive( # pylint: disable=too-many-return-statements
|
|
87
76
|
self, request: MessageContainer, context: grpc.ServicerContext
|
|
88
77
|
) -> MessageContainer:
|
|
89
78
|
"""."""
|
|
90
79
|
log(DEBUG, "GrpcAdapterServicer.SendReceive")
|
|
91
80
|
if request.grpc_message_name == CreateNodeRequest.__qualname__:
|
|
92
|
-
return _handle(request, CreateNodeRequest, self.
|
|
81
|
+
return _handle(request, context, CreateNodeRequest, self.CreateNode)
|
|
93
82
|
if request.grpc_message_name == DeleteNodeRequest.__qualname__:
|
|
94
|
-
return _handle(request, DeleteNodeRequest, self.
|
|
83
|
+
return _handle(request, context, DeleteNodeRequest, self.DeleteNode)
|
|
95
84
|
if request.grpc_message_name == PingRequest.__qualname__:
|
|
96
|
-
return _handle(request, PingRequest, self.
|
|
97
|
-
if request.grpc_message_name == PullTaskInsRequest.__qualname__:
|
|
98
|
-
return _handle(request, PullTaskInsRequest, self._pull_task_ins)
|
|
99
|
-
if request.grpc_message_name == PushTaskResRequest.__qualname__:
|
|
100
|
-
return _handle(request, PushTaskResRequest, self._push_task_res)
|
|
85
|
+
return _handle(request, context, PingRequest, self.Ping)
|
|
101
86
|
if request.grpc_message_name == GetRunRequest.__qualname__:
|
|
102
|
-
return _handle(request, GetRunRequest, self.
|
|
87
|
+
return _handle(request, context, GetRunRequest, self.GetRun)
|
|
103
88
|
if request.grpc_message_name == GetFabRequest.__qualname__:
|
|
104
|
-
return _handle(request, GetFabRequest, self.
|
|
89
|
+
return _handle(request, context, GetFabRequest, self.GetFab)
|
|
90
|
+
if request.grpc_message_name == PullMessagesRequest.__qualname__:
|
|
91
|
+
return _handle(request, context, PullMessagesRequest, self.PullMessages)
|
|
92
|
+
if request.grpc_message_name == PushMessagesRequest.__qualname__:
|
|
93
|
+
return _handle(request, context, PushMessagesRequest, self.PushMessages)
|
|
105
94
|
raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
|
|
106
|
-
|
|
107
|
-
def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
|
|
108
|
-
"""."""
|
|
109
|
-
log(INFO, "GrpcAdapter.CreateNode")
|
|
110
|
-
return message_handler.create_node(
|
|
111
|
-
request=request,
|
|
112
|
-
state=self.state_factory.state(),
|
|
113
|
-
)
|
|
114
|
-
|
|
115
|
-
def _delete_node(self, request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
116
|
-
"""."""
|
|
117
|
-
log(INFO, "GrpcAdapter.DeleteNode")
|
|
118
|
-
return message_handler.delete_node(
|
|
119
|
-
request=request,
|
|
120
|
-
state=self.state_factory.state(),
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
def _ping(self, request: PingRequest) -> PingResponse:
|
|
124
|
-
"""."""
|
|
125
|
-
log(DEBUG, "GrpcAdapter.Ping")
|
|
126
|
-
return message_handler.ping(
|
|
127
|
-
request=request,
|
|
128
|
-
state=self.state_factory.state(),
|
|
129
|
-
)
|
|
130
|
-
|
|
131
|
-
def _pull_task_ins(self, request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
132
|
-
"""Pull TaskIns."""
|
|
133
|
-
log(INFO, "GrpcAdapter.PullTaskIns")
|
|
134
|
-
return message_handler.pull_task_ins(
|
|
135
|
-
request=request,
|
|
136
|
-
state=self.state_factory.state(),
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
def _push_task_res(self, request: PushTaskResRequest) -> PushTaskResResponse:
|
|
140
|
-
"""Push TaskRes."""
|
|
141
|
-
log(INFO, "GrpcAdapter.PushTaskRes")
|
|
142
|
-
return message_handler.push_task_res(
|
|
143
|
-
request=request,
|
|
144
|
-
state=self.state_factory.state(),
|
|
145
|
-
)
|
|
146
|
-
|
|
147
|
-
def _get_run(self, request: GetRunRequest) -> GetRunResponse:
|
|
148
|
-
"""Get run information."""
|
|
149
|
-
log(INFO, "GrpcAdapter.GetRun")
|
|
150
|
-
return message_handler.get_run(
|
|
151
|
-
request=request,
|
|
152
|
-
state=self.state_factory.state(),
|
|
153
|
-
)
|
|
154
|
-
|
|
155
|
-
def _get_fab(self, request: GetFabRequest) -> GetFabResponse:
|
|
156
|
-
"""Get FAB."""
|
|
157
|
-
log(INFO, "GrpcAdapter.GetFab")
|
|
158
|
-
return message_handler.get_fab(
|
|
159
|
-
request=request,
|
|
160
|
-
ffs=self.ffs_factory.ffs(),
|
|
161
|
-
)
|
|
@@ -15,49 +15,19 @@
|
|
|
15
15
|
"""Implements utility function to create a gRPC server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import
|
|
19
|
-
import sys
|
|
20
|
-
from collections.abc import Sequence
|
|
21
|
-
from logging import ERROR
|
|
22
|
-
from typing import Any, Callable, Optional, Union
|
|
18
|
+
from typing import Optional
|
|
23
19
|
|
|
24
20
|
import grpc
|
|
25
21
|
|
|
26
22
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
|
-
from flwr.common.
|
|
28
|
-
from flwr.common.logger import log
|
|
23
|
+
from flwr.common.grpc import generic_create_grpc_server
|
|
29
24
|
from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
|
|
30
25
|
add_FlowerServiceServicer_to_server,
|
|
31
26
|
)
|
|
32
27
|
from flwr.server.client_manager import ClientManager
|
|
33
|
-
from flwr.server.superlink.driver.serverappio_servicer import ServerAppIoServicer
|
|
34
|
-
from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
|
|
35
|
-
GrpcAdapterServicer,
|
|
36
|
-
)
|
|
37
28
|
from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import (
|
|
38
29
|
FlowerServiceServicer,
|
|
39
30
|
)
|
|
40
|
-
from flwr.server.superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
41
|
-
|
|
42
|
-
INVALID_CERTIFICATES_ERR_MSG = """
|
|
43
|
-
When setting any of root_certificate, certificate, or private_key,
|
|
44
|
-
all of them need to be set.
|
|
45
|
-
"""
|
|
46
|
-
|
|
47
|
-
AddServicerToServerFn = Callable[..., Any]
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool:
|
|
51
|
-
"""Validate certificates tuple."""
|
|
52
|
-
is_valid = (
|
|
53
|
-
all(isinstance(certificate, bytes) for certificate in certificates)
|
|
54
|
-
and len(certificates) == 3
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
if not is_valid:
|
|
58
|
-
log(ERROR, INVALID_CERTIFICATES_ERR_MSG)
|
|
59
|
-
|
|
60
|
-
return is_valid
|
|
61
31
|
|
|
62
32
|
|
|
63
33
|
def start_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
@@ -154,136 +124,3 @@ def start_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
|
154
124
|
server.start()
|
|
155
125
|
|
|
156
126
|
return server
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
160
|
-
servicer_and_add_fn: Union[
|
|
161
|
-
tuple[FleetServicer, AddServicerToServerFn],
|
|
162
|
-
tuple[GrpcAdapterServicer, AddServicerToServerFn],
|
|
163
|
-
tuple[FlowerServiceServicer, AddServicerToServerFn],
|
|
164
|
-
tuple[ServerAppIoServicer, AddServicerToServerFn],
|
|
165
|
-
],
|
|
166
|
-
server_address: str,
|
|
167
|
-
max_concurrent_workers: int = 1000,
|
|
168
|
-
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
169
|
-
keepalive_time_ms: int = 210000,
|
|
170
|
-
certificates: Optional[tuple[bytes, bytes, bytes]] = None,
|
|
171
|
-
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
172
|
-
) -> grpc.Server:
|
|
173
|
-
"""Create a gRPC server with a single servicer.
|
|
174
|
-
|
|
175
|
-
Parameters
|
|
176
|
-
----------
|
|
177
|
-
servicer_and_add_fn : tuple
|
|
178
|
-
A tuple holding a servicer implementation and a matching
|
|
179
|
-
add_Servicer_to_server function.
|
|
180
|
-
server_address : str
|
|
181
|
-
Server address in the form of HOST:PORT e.g. "[::]:8080"
|
|
182
|
-
max_concurrent_workers : int
|
|
183
|
-
Maximum number of clients the server can process before returning
|
|
184
|
-
RESOURCE_EXHAUSTED status (default: 1000)
|
|
185
|
-
max_message_length : int
|
|
186
|
-
Maximum message length that the server can send or receive.
|
|
187
|
-
Int valued in bytes. -1 means unlimited. (default: GRPC_MAX_MESSAGE_LENGTH)
|
|
188
|
-
keepalive_time_ms : int
|
|
189
|
-
Flower uses a default gRPC keepalive time of 210000ms (3 minutes 30 seconds)
|
|
190
|
-
because some cloud providers (for example, Azure) agressively clean up idle
|
|
191
|
-
TCP connections by terminating them after some time (4 minutes in the case
|
|
192
|
-
of Azure). Flower does not use application-level keepalive signals and relies
|
|
193
|
-
on the assumption that the transport layer will fail in cases where the
|
|
194
|
-
connection is no longer active. `keepalive_time_ms` can be used to customize
|
|
195
|
-
the keepalive interval for specific environments. The default Flower gRPC
|
|
196
|
-
keepalive of 210000 ms (3 minutes 30 seconds) ensures that Flower can keep
|
|
197
|
-
the long running streaming connection alive in most environments. The actual
|
|
198
|
-
gRPC default of this setting is 7200000 (2 hours), which results in dropped
|
|
199
|
-
connections in some cloud environments.
|
|
200
|
-
|
|
201
|
-
These settings are related to the issue described here:
|
|
202
|
-
- https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md
|
|
203
|
-
- https://github.com/grpc/grpc/blob/master/doc/keepalive.md
|
|
204
|
-
- https://grpc.io/docs/guides/performance/
|
|
205
|
-
|
|
206
|
-
Mobile Flower clients may choose to increase this value if their server
|
|
207
|
-
environment allows long-running idle TCP connections.
|
|
208
|
-
(default: 210000)
|
|
209
|
-
certificates : Tuple[bytes, bytes, bytes] (default: None)
|
|
210
|
-
Tuple containing root certificate, server certificate, and private key to
|
|
211
|
-
start a secure SSL-enabled server. The tuple is expected to have three bytes
|
|
212
|
-
elements in the following order:
|
|
213
|
-
|
|
214
|
-
* CA certificate.
|
|
215
|
-
* server certificate.
|
|
216
|
-
* server private key.
|
|
217
|
-
interceptors : Optional[Sequence[grpc.ServerInterceptor]] (default: None)
|
|
218
|
-
A list of gRPC interceptors.
|
|
219
|
-
|
|
220
|
-
Returns
|
|
221
|
-
-------
|
|
222
|
-
server : grpc.Server
|
|
223
|
-
A non-running instance of a gRPC server.
|
|
224
|
-
"""
|
|
225
|
-
# Check if port is in use
|
|
226
|
-
if is_port_in_use(server_address):
|
|
227
|
-
sys.exit(f"Port in server address {server_address} is already in use.")
|
|
228
|
-
|
|
229
|
-
# Deconstruct tuple into servicer and function
|
|
230
|
-
servicer, add_servicer_to_server_fn = servicer_and_add_fn
|
|
231
|
-
|
|
232
|
-
# Possible options:
|
|
233
|
-
# https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h
|
|
234
|
-
options = [
|
|
235
|
-
# Maximum number of concurrent incoming streams to allow on a http2
|
|
236
|
-
# connection. Int valued.
|
|
237
|
-
("grpc.max_concurrent_streams", max(100, max_concurrent_workers)),
|
|
238
|
-
# Maximum message length that the channel can send.
|
|
239
|
-
# Int valued, bytes. -1 means unlimited.
|
|
240
|
-
("grpc.max_send_message_length", max_message_length),
|
|
241
|
-
# Maximum message length that the channel can receive.
|
|
242
|
-
# Int valued, bytes. -1 means unlimited.
|
|
243
|
-
("grpc.max_receive_message_length", max_message_length),
|
|
244
|
-
# The gRPC default for this setting is 7200000 (2 hours). Flower uses a
|
|
245
|
-
# customized default of 210000 (3 minutes and 30 seconds) to improve
|
|
246
|
-
# compatibility with popular cloud providers. Mobile Flower clients may
|
|
247
|
-
# choose to increase this value if their server environment allows
|
|
248
|
-
# long-running idle TCP connections.
|
|
249
|
-
("grpc.keepalive_time_ms", keepalive_time_ms),
|
|
250
|
-
# Setting this to zero will allow sending unlimited keepalive pings in between
|
|
251
|
-
# sending actual data frames.
|
|
252
|
-
("grpc.http2.max_pings_without_data", 0),
|
|
253
|
-
# Is it permissible to send keepalive pings from the client without
|
|
254
|
-
# any outstanding streams. More explanation here:
|
|
255
|
-
# https://github.com/adap/flower/pull/2197
|
|
256
|
-
("grpc.keepalive_permit_without_calls", 0),
|
|
257
|
-
]
|
|
258
|
-
|
|
259
|
-
server = grpc.server(
|
|
260
|
-
concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_workers),
|
|
261
|
-
# Set the maximum number of concurrent RPCs this server will service before
|
|
262
|
-
# returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
|
|
263
|
-
maximum_concurrent_rpcs=max_concurrent_workers,
|
|
264
|
-
options=options,
|
|
265
|
-
interceptors=interceptors,
|
|
266
|
-
)
|
|
267
|
-
add_servicer_to_server_fn(servicer, server)
|
|
268
|
-
|
|
269
|
-
if certificates is not None:
|
|
270
|
-
if not valid_certificates(certificates):
|
|
271
|
-
sys.exit(1)
|
|
272
|
-
|
|
273
|
-
root_certificate_b, certificate_b, private_key_b = certificates
|
|
274
|
-
|
|
275
|
-
server_credentials = grpc.ssl_server_credentials(
|
|
276
|
-
((private_key_b, certificate_b),),
|
|
277
|
-
root_certificates=root_certificate_b,
|
|
278
|
-
# A boolean indicating whether or not to require clients to be
|
|
279
|
-
# authenticated. May only be True if root_certificates is not None.
|
|
280
|
-
# We are explicitly setting the current gRPC default to document
|
|
281
|
-
# the option. For further reference see:
|
|
282
|
-
# https://grpc.github.io/grpc/python/grpc.html#create-server-credentials
|
|
283
|
-
require_client_auth=False,
|
|
284
|
-
)
|
|
285
|
-
server.add_secure_port(server_address, server_credentials)
|
|
286
|
-
else:
|
|
287
|
-
server.add_insecure_port(server_address)
|
|
288
|
-
|
|
289
|
-
return server
|