flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +162 -99
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +6 -6
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/logger.py +2 -2
- flwr/common/message.py +327 -102
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +56 -1
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +47 -18
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +38 -18
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
- flwr/server/superlink/linkstate/utils.py +93 -27
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +48 -57
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/exec_user_auth_interceptor.py +18 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,35 +12,37 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower in-memory
|
|
15
|
+
"""Flower in-memory Grid."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
import warnings
|
|
20
19
|
from collections.abc import Iterable
|
|
21
20
|
from typing import Optional, cast
|
|
22
21
|
from uuid import UUID
|
|
23
22
|
|
|
24
|
-
from flwr.common import
|
|
23
|
+
from flwr.common import Message, RecordDict
|
|
25
24
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
25
|
+
from flwr.common.logger import warn_deprecated_feature
|
|
26
26
|
from flwr.common.typing import Run
|
|
27
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
28
28
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
29
29
|
|
|
30
|
-
from .
|
|
30
|
+
from .grid import Grid
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
class
|
|
34
|
-
"""`
|
|
33
|
+
class InMemoryGrid(Grid):
|
|
34
|
+
"""`InMemoryGrid` class provides an interface to the ServerAppIo API.
|
|
35
35
|
|
|
36
36
|
Parameters
|
|
37
37
|
----------
|
|
38
38
|
state_factory : StateFactory
|
|
39
|
-
A StateFactory embedding a state that this
|
|
39
|
+
A StateFactory embedding a state that this grid can interface with.
|
|
40
40
|
pull_interval : float (default=0.1)
|
|
41
41
|
Sleep duration between calls to `pull_messages`.
|
|
42
42
|
"""
|
|
43
43
|
|
|
44
|
+
_deprecation_warning_logged = False
|
|
45
|
+
|
|
44
46
|
def __init__(
|
|
45
47
|
self,
|
|
46
48
|
state_factory: LinkStateFactory,
|
|
@@ -54,10 +56,8 @@ class InMemoryDriver(Driver):
|
|
|
54
56
|
def _check_message(self, message: Message) -> None:
|
|
55
57
|
# Check if the message is valid
|
|
56
58
|
if not (
|
|
57
|
-
message.metadata.
|
|
58
|
-
and message.metadata.
|
|
59
|
-
and message.metadata.message_id == ""
|
|
60
|
-
and message.metadata.reply_to_message == ""
|
|
59
|
+
message.metadata.message_id == ""
|
|
60
|
+
and message.metadata.reply_to_message_id == ""
|
|
61
61
|
and message.metadata.ttl > 0
|
|
62
62
|
and message.metadata.delivered_at == ""
|
|
63
63
|
):
|
|
@@ -77,7 +77,7 @@ class InMemoryDriver(Driver):
|
|
|
77
77
|
|
|
78
78
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
79
79
|
self,
|
|
80
|
-
content:
|
|
80
|
+
content: RecordDict,
|
|
81
81
|
message_type: str,
|
|
82
82
|
dst_node_id: int,
|
|
83
83
|
group_id: str,
|
|
@@ -88,26 +88,13 @@ class InMemoryDriver(Driver):
|
|
|
88
88
|
This method constructs a new `Message` with given content and metadata.
|
|
89
89
|
The `run_id` and `src_node_id` will be set automatically.
|
|
90
90
|
"""
|
|
91
|
-
if
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
"
|
|
95
|
-
"
|
|
96
|
-
stacklevel=2,
|
|
91
|
+
if not InMemoryGrid._deprecation_warning_logged:
|
|
92
|
+
InMemoryGrid._deprecation_warning_logged = True
|
|
93
|
+
warn_deprecated_feature(
|
|
94
|
+
"`Driver.create_message` / `Grid.create_message` is deprecated."
|
|
95
|
+
"Use `Message` constructor instead."
|
|
97
96
|
)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
metadata = Metadata(
|
|
101
|
-
run_id=cast(Run, self._run).run_id,
|
|
102
|
-
message_id="", # Will be set by the server
|
|
103
|
-
src_node_id=self.node.node_id,
|
|
104
|
-
dst_node_id=dst_node_id,
|
|
105
|
-
reply_to_message="",
|
|
106
|
-
group_id=group_id,
|
|
107
|
-
ttl=ttl_,
|
|
108
|
-
message_type=message_type,
|
|
109
|
-
)
|
|
110
|
-
return Message(metadata=metadata, content=content)
|
|
97
|
+
return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
|
|
111
98
|
|
|
112
99
|
def get_node_ids(self) -> Iterable[int]:
|
|
113
100
|
"""Get node IDs."""
|
|
@@ -121,6 +108,9 @@ class InMemoryDriver(Driver):
|
|
|
121
108
|
"""
|
|
122
109
|
msg_ids: list[str] = []
|
|
123
110
|
for msg in messages:
|
|
111
|
+
# Populate metadata
|
|
112
|
+
msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
|
|
113
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
|
124
114
|
# Check message
|
|
125
115
|
self._check_message(msg)
|
|
126
116
|
# Store in state
|
|
@@ -141,7 +131,7 @@ class InMemoryDriver(Driver):
|
|
|
141
131
|
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
|
142
132
|
# Get IDs of Messages these replies are for
|
|
143
133
|
message_ins_ids_to_delete = {
|
|
144
|
-
UUID(msg_res.metadata.
|
|
134
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in message_res_list
|
|
145
135
|
}
|
|
146
136
|
# Delete
|
|
147
137
|
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
|
@@ -170,7 +160,7 @@ class InMemoryDriver(Driver):
|
|
|
170
160
|
res_msgs = self.pull_messages(msg_ids)
|
|
171
161
|
ret.extend(res_msgs)
|
|
172
162
|
msg_ids.difference_update(
|
|
173
|
-
{msg.metadata.
|
|
163
|
+
{msg.metadata.reply_to_message_id for msg in res_msgs}
|
|
174
164
|
)
|
|
175
165
|
if len(msg_ids) == 0:
|
|
176
166
|
break
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -22,18 +22,18 @@ from flwr.common import Context
|
|
|
22
22
|
from flwr.common.logger import log
|
|
23
23
|
from flwr.common.object_ref import load_app
|
|
24
24
|
|
|
25
|
-
from .
|
|
25
|
+
from .grid import Grid
|
|
26
26
|
from .server_app import LoadServerAppError, ServerApp
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def run(
|
|
30
|
-
|
|
30
|
+
grid: Grid,
|
|
31
31
|
context: Context,
|
|
32
32
|
server_app_dir: str,
|
|
33
33
|
server_app_attr: Optional[str] = None,
|
|
34
34
|
loaded_server_app: Optional[ServerApp] = None,
|
|
35
35
|
) -> Context:
|
|
36
|
-
"""Run ServerApp with a given
|
|
36
|
+
"""Run ServerApp with a given Grid."""
|
|
37
37
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
|
38
38
|
raise ValueError(
|
|
39
39
|
"Either `server_app_attr` or `loaded_server_app` should be set "
|
|
@@ -59,7 +59,7 @@ def run(
|
|
|
59
59
|
server_app = _load()
|
|
60
60
|
|
|
61
61
|
# Call ServerApp
|
|
62
|
-
server_app(
|
|
62
|
+
server_app(grid=grid, context=context)
|
|
63
63
|
|
|
64
64
|
log(DEBUG, "ServerApp finished running.")
|
|
65
65
|
return context
|
flwr/server/server_app.py
CHANGED
|
@@ -15,20 +15,18 @@
|
|
|
15
15
|
"""Flower ServerApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import inspect
|
|
18
19
|
from collections.abc import Iterator
|
|
19
20
|
from contextlib import contextmanager
|
|
20
21
|
from typing import Callable, Optional
|
|
21
22
|
|
|
22
23
|
from flwr.common import Context
|
|
23
|
-
from flwr.common.logger import
|
|
24
|
-
warn_deprecated_feature_with_example,
|
|
25
|
-
warn_preview_feature,
|
|
26
|
-
)
|
|
24
|
+
from flwr.common.logger import warn_deprecated_feature_with_example
|
|
27
25
|
from flwr.server.strategy import Strategy
|
|
28
26
|
|
|
29
27
|
from .client_manager import ClientManager
|
|
30
|
-
from .compat import
|
|
31
|
-
from .
|
|
28
|
+
from .compat import start_grid
|
|
29
|
+
from .grid import Driver, Grid
|
|
32
30
|
from .server import Server
|
|
33
31
|
from .server_config import ServerConfig
|
|
34
32
|
from .typing import ServerAppCallable, ServerFn
|
|
@@ -46,6 +44,21 @@ SERVER_FN_USAGE_EXAMPLE = """
|
|
|
46
44
|
app = ServerApp(server_fn=server_fn)
|
|
47
45
|
"""
|
|
48
46
|
|
|
47
|
+
GRID_USAGE_EXAMPLE = """
|
|
48
|
+
app = ServerApp()
|
|
49
|
+
|
|
50
|
+
@app.main()
|
|
51
|
+
def main(grid: Grid, context: Context) -> None:
|
|
52
|
+
# Your existing ServerApp code ...
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
DRIVER_DEPRECATION_MSG = """
|
|
56
|
+
The `Driver` class is deprecated, it will be removed in a future release.
|
|
57
|
+
"""
|
|
58
|
+
DRIVER_EXAMPLE_MSG = """
|
|
59
|
+
Instead, use `Grid` in the signature of your `ServerApp`. For example:
|
|
60
|
+
"""
|
|
61
|
+
|
|
49
62
|
|
|
50
63
|
@contextmanager
|
|
51
64
|
def _empty_lifespan(_: Context) -> Iterator[None]:
|
|
@@ -57,7 +70,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
57
70
|
|
|
58
71
|
Examples
|
|
59
72
|
--------
|
|
60
|
-
Use the
|
|
73
|
+
Use the ``ServerApp`` with an existing ``Strategy``:
|
|
61
74
|
|
|
62
75
|
>>> def server_fn(context: Context):
|
|
63
76
|
>>> server_config = ServerConfig(num_rounds=3)
|
|
@@ -69,12 +82,12 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
69
82
|
>>>
|
|
70
83
|
>>> app = ServerApp(server_fn=server_fn)
|
|
71
84
|
|
|
72
|
-
Use the
|
|
85
|
+
Use the ``ServerApp`` with a custom main function:
|
|
73
86
|
|
|
74
87
|
>>> app = ServerApp()
|
|
75
88
|
>>>
|
|
76
89
|
>>> @app.main()
|
|
77
|
-
>>> def main(
|
|
90
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
|
78
91
|
>>> print("ServerApp running")
|
|
79
92
|
"""
|
|
80
93
|
|
|
@@ -114,7 +127,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
114
127
|
self._main: Optional[ServerAppCallable] = None
|
|
115
128
|
self._lifespan = _empty_lifespan
|
|
116
129
|
|
|
117
|
-
def __call__(self,
|
|
130
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
|
118
131
|
"""Execute `ServerApp`."""
|
|
119
132
|
with self._lifespan(context):
|
|
120
133
|
# Compatibility mode
|
|
@@ -126,17 +139,17 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
126
139
|
self._config = components.config
|
|
127
140
|
self._strategy = components.strategy
|
|
128
141
|
self._client_manager = components.client_manager
|
|
129
|
-
|
|
142
|
+
start_grid(
|
|
130
143
|
server=self._server,
|
|
131
144
|
config=self._config,
|
|
132
145
|
strategy=self._strategy,
|
|
133
146
|
client_manager=self._client_manager,
|
|
134
|
-
|
|
147
|
+
grid=grid,
|
|
135
148
|
)
|
|
136
149
|
return
|
|
137
150
|
|
|
138
151
|
# New execution mode
|
|
139
|
-
self._main(
|
|
152
|
+
self._main(grid, context)
|
|
140
153
|
|
|
141
154
|
def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
|
|
142
155
|
"""Return a decorator that registers the main fn with the server app.
|
|
@@ -146,7 +159,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
146
159
|
>>> app = ServerApp()
|
|
147
160
|
>>>
|
|
148
161
|
>>> @app.main()
|
|
149
|
-
>>> def main(
|
|
162
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
|
150
163
|
>>> print("ServerApp running")
|
|
151
164
|
"""
|
|
152
165
|
|
|
@@ -171,12 +184,20 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
171
184
|
>>> app = ServerApp()
|
|
172
185
|
>>>
|
|
173
186
|
>>> @app.main()
|
|
174
|
-
>>> def main(
|
|
187
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
|
175
188
|
>>> print("ServerApp running")
|
|
176
189
|
""",
|
|
177
190
|
)
|
|
178
191
|
|
|
179
|
-
|
|
192
|
+
sig = inspect.signature(main_fn)
|
|
193
|
+
param = list(sig.parameters.values())[0]
|
|
194
|
+
# Check if parameter name or the annotation should be updated
|
|
195
|
+
if param.name == "driver" or param.annotation is Driver:
|
|
196
|
+
warn_deprecated_feature_with_example(
|
|
197
|
+
deprecation_message=DRIVER_DEPRECATION_MSG,
|
|
198
|
+
example_message=DRIVER_EXAMPLE_MSG,
|
|
199
|
+
code_example=GRID_USAGE_EXAMPLE,
|
|
200
|
+
)
|
|
180
201
|
|
|
181
202
|
# Register provided function with the ServerApp object
|
|
182
203
|
self._main = main_fn
|
|
@@ -212,10 +233,9 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
212
233
|
"""
|
|
213
234
|
|
|
214
235
|
def lifespan_decorator(
|
|
215
|
-
lifespan_fn: Callable[[Context], Iterator[None]]
|
|
236
|
+
lifespan_fn: Callable[[Context], Iterator[None]],
|
|
216
237
|
) -> Callable[[Context], Iterator[None]]:
|
|
217
238
|
"""Register the lifespan fn with the ServerApp object."""
|
|
218
|
-
warn_preview_feature("ServerApp-register-lifespan-function")
|
|
219
239
|
|
|
220
240
|
@contextmanager
|
|
221
241
|
def decorated_lifespan(context: Context) -> Iterator[None]:
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
|
60
60
|
PullServerAppInputsResponse,
|
|
61
61
|
PushServerAppOutputsRequest,
|
|
62
62
|
)
|
|
63
|
-
from flwr.server.
|
|
63
|
+
from flwr.server.grid.grpc_grid import GrpcGrid
|
|
64
64
|
from flwr.server.run_serverapp import run as run_
|
|
65
65
|
|
|
66
66
|
|
|
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
106
106
|
certificates: Optional[bytes] = None,
|
|
107
107
|
) -> None:
|
|
108
108
|
"""Run Flower ServerApp process."""
|
|
109
|
-
|
|
109
|
+
grid = GrpcGrid(
|
|
110
110
|
serverappio_service_address=serverappio_api_address,
|
|
111
111
|
root_certificates=certificates,
|
|
112
112
|
)
|
|
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
123
123
|
# Pull ServerAppInputs from LinkState
|
|
124
124
|
req = PullServerAppInputsRequest()
|
|
125
125
|
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
|
126
|
-
res: PullServerAppInputsResponse =
|
|
126
|
+
res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
|
|
127
127
|
if not res.HasField("run"):
|
|
128
128
|
sleep(3)
|
|
129
129
|
run_status = None
|
|
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
135
135
|
|
|
136
136
|
hash_run_id = get_sha256_hash(run.run_id)
|
|
137
137
|
|
|
138
|
-
|
|
138
|
+
grid.set_run(run.run_id)
|
|
139
139
|
|
|
140
140
|
# Start log uploader for this run
|
|
141
141
|
log_uploader = start_log_uploader(
|
|
142
142
|
log_queue=log_queue,
|
|
143
143
|
node_id=0,
|
|
144
144
|
run_id=run.run_id,
|
|
145
|
-
stub=
|
|
145
|
+
stub=grid._stub,
|
|
146
146
|
)
|
|
147
147
|
|
|
148
148
|
log(DEBUG, "[flwr-serverapp] Start FAB installation.")
|
|
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
173
173
|
|
|
174
174
|
# Change status to Running
|
|
175
175
|
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
|
176
|
-
|
|
176
|
+
grid._stub.UpdateRunStatus(
|
|
177
177
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
178
178
|
)
|
|
179
179
|
|
|
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
182
182
|
event_details={"run-id-hash": hash_run_id},
|
|
183
183
|
)
|
|
184
184
|
|
|
185
|
-
# Load and run the ServerApp with the
|
|
185
|
+
# Load and run the ServerApp with the Grid
|
|
186
186
|
updated_context = run_(
|
|
187
|
-
|
|
187
|
+
grid=grid,
|
|
188
188
|
server_app_dir=app_path,
|
|
189
189
|
server_app_attr=server_app_attr,
|
|
190
190
|
context=context,
|
|
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
196
196
|
out_req = PushServerAppOutputsRequest(
|
|
197
197
|
run_id=run.run_id, context=context_proto
|
|
198
198
|
)
|
|
199
|
-
_ =
|
|
199
|
+
_ = grid._stub.PushServerAppOutputs(out_req)
|
|
200
200
|
|
|
201
201
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
|
202
202
|
except RunNotRunningException:
|
|
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
221
221
|
# Update run status
|
|
222
222
|
if run_status:
|
|
223
223
|
run_status_proto = run_status_to_proto(run_status)
|
|
224
|
-
|
|
224
|
+
grid._stub.UpdateRunStatus(
|
|
225
225
|
UpdateRunStatusRequest(
|
|
226
226
|
run_id=run.run_id, run_status=run_status_proto
|
|
227
227
|
)
|
|
@@ -21,9 +21,9 @@ from typing import Callable
|
|
|
21
21
|
from flwr.client.client_app import ClientApp
|
|
22
22
|
from flwr.common.context import Context
|
|
23
23
|
from flwr.common.message import Message
|
|
24
|
-
from flwr.common.typing import
|
|
24
|
+
from flwr.common.typing import ConfigRecordValues
|
|
25
25
|
|
|
26
|
-
BackendConfig = dict[str, dict[str,
|
|
26
|
+
BackendConfig = dict[str, dict[str, ConfigRecordValues]]
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class Backend(ABC):
|
|
@@ -26,7 +26,7 @@ from flwr.common.constant import PARTITION_ID_KEY
|
|
|
26
26
|
from flwr.common.context import Context
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.message import Message
|
|
29
|
-
from flwr.common.typing import
|
|
29
|
+
from flwr.common.typing import ConfigRecordValues
|
|
30
30
|
from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
|
|
31
31
|
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
|
32
32
|
|
|
@@ -104,7 +104,7 @@ class RayBackend(Backend):
|
|
|
104
104
|
if not ray.is_initialized():
|
|
105
105
|
ray_init_args: dict[
|
|
106
106
|
str,
|
|
107
|
-
|
|
107
|
+
ConfigRecordValues,
|
|
108
108
|
] = {}
|
|
109
109
|
|
|
110
110
|
if backend_config.get(self.init_args_key):
|
|
@@ -130,9 +130,7 @@ def worker(
|
|
|
130
130
|
e_code = ErrorCode.UNKNOWN
|
|
131
131
|
|
|
132
132
|
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
133
|
-
out_mssg = message
|
|
134
|
-
error=Error(code=e_code, reason=reason)
|
|
135
|
-
)
|
|
133
|
+
out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
|
|
136
134
|
|
|
137
135
|
finally:
|
|
138
136
|
if out_mssg:
|
|
@@ -27,16 +27,18 @@ from flwr.common import Context, Message, log, now
|
|
|
27
27
|
from flwr.common.constant import (
|
|
28
28
|
MESSAGE_TTL_TOLERANCE,
|
|
29
29
|
NODE_ID_NUM_BYTES,
|
|
30
|
+
PING_PATIENCE,
|
|
30
31
|
RUN_ID_NUM_BYTES,
|
|
31
32
|
SUPERLINK_NODE_ID,
|
|
32
33
|
Status,
|
|
33
34
|
)
|
|
34
|
-
from flwr.common.record import
|
|
35
|
+
from flwr.common.record import ConfigRecord
|
|
35
36
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
36
37
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
37
38
|
from flwr.server.utils import validate_message
|
|
38
39
|
|
|
39
40
|
from .utils import (
|
|
41
|
+
check_node_availability_for_in_message,
|
|
40
42
|
generate_rand_int_from_bytes,
|
|
41
43
|
has_valid_sub_status,
|
|
42
44
|
is_valid_transition,
|
|
@@ -67,7 +69,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
67
69
|
# Map run_id to RunRecord
|
|
68
70
|
self.run_ids: dict[int, RunRecord] = {}
|
|
69
71
|
self.contexts: dict[int, Context] = {}
|
|
70
|
-
self.federation_options: dict[int,
|
|
72
|
+
self.federation_options: dict[int, ConfigRecord] = {}
|
|
71
73
|
self.message_ins_store: dict[UUID, Message] = {}
|
|
72
74
|
self.message_res_store: dict[UUID, Message] = {}
|
|
73
75
|
self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
|
|
@@ -156,7 +158,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
156
158
|
res_metadata = message.metadata
|
|
157
159
|
with self.lock:
|
|
158
160
|
# Check if the Message it is replying to exists and is valid
|
|
159
|
-
msg_ins_id = res_metadata.
|
|
161
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
|
160
162
|
msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
|
|
161
163
|
|
|
162
164
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
|
@@ -232,13 +234,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
232
234
|
with self.lock:
|
|
233
235
|
current = time.time()
|
|
234
236
|
|
|
235
|
-
# Verify
|
|
237
|
+
# Verify Message IDs
|
|
236
238
|
ret = verify_message_ids(
|
|
237
239
|
inquired_message_ids=message_ids,
|
|
238
240
|
found_message_ins_dict=self.message_ins_store,
|
|
239
241
|
current_time=current,
|
|
240
242
|
)
|
|
241
243
|
|
|
244
|
+
# Check node availability
|
|
245
|
+
dst_node_ids = {
|
|
246
|
+
self.message_ins_store[message_id].metadata.dst_node_id
|
|
247
|
+
for message_id in message_ids
|
|
248
|
+
}
|
|
249
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
|
250
|
+
inquired_in_message_ids=message_ids,
|
|
251
|
+
found_in_message_dict=self.message_ins_store,
|
|
252
|
+
node_id_to_online_until={
|
|
253
|
+
node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
|
|
254
|
+
},
|
|
255
|
+
current_time=current,
|
|
256
|
+
)
|
|
257
|
+
ret.update(tmp_ret_dict)
|
|
258
|
+
|
|
242
259
|
# Find all reply Messages
|
|
243
260
|
message_res_found: list[Message] = []
|
|
244
261
|
for message_id in message_ids:
|
|
@@ -317,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
317
334
|
log(ERROR, "Unexpected node registration failure.")
|
|
318
335
|
return 0
|
|
319
336
|
|
|
337
|
+
# Mark the node online util time.time() + ping_interval
|
|
320
338
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
321
339
|
return node_id
|
|
322
340
|
|
|
@@ -381,7 +399,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
381
399
|
fab_version: Optional[str],
|
|
382
400
|
fab_hash: Optional[str],
|
|
383
401
|
override_config: UserConfig,
|
|
384
|
-
federation_options:
|
|
402
|
+
federation_options: ConfigRecord,
|
|
385
403
|
) -> int:
|
|
386
404
|
"""Create a new run for the specified `fab_hash`."""
|
|
387
405
|
# Sample a random int64 as run_id
|
|
@@ -510,7 +528,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
510
528
|
|
|
511
529
|
return pending_run_id
|
|
512
530
|
|
|
513
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
|
531
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
|
514
532
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
515
533
|
with self.lock:
|
|
516
534
|
if run_id not in self.run_ids:
|
|
@@ -519,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
519
537
|
return self.federation_options[run_id]
|
|
520
538
|
|
|
521
539
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
522
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
|
540
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
|
541
|
+
|
|
542
|
+
It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
|
|
543
|
+
marking the node as offline, where PING_PATIENCE = 2 in default.
|
|
544
|
+
"""
|
|
523
545
|
with self.lock:
|
|
524
546
|
if node_id in self.node_ids:
|
|
525
|
-
self.node_ids[node_id] = (
|
|
547
|
+
self.node_ids[node_id] = (
|
|
548
|
+
time.time() + PING_PATIENCE * ping_interval,
|
|
549
|
+
ping_interval,
|
|
550
|
+
)
|
|
526
551
|
return True
|
|
527
552
|
return False
|
|
528
553
|
|
|
@@ -20,7 +20,7 @@ from typing import Optional
|
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
22
|
from flwr.common import Context, Message
|
|
23
|
-
from flwr.common.record import
|
|
23
|
+
from flwr.common.record import ConfigRecord
|
|
24
24
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
25
25
|
|
|
26
26
|
|
|
@@ -164,7 +164,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
164
164
|
fab_version: Optional[str],
|
|
165
165
|
fab_hash: Optional[str],
|
|
166
166
|
override_config: UserConfig,
|
|
167
|
-
federation_options:
|
|
167
|
+
federation_options: ConfigRecord,
|
|
168
168
|
) -> int:
|
|
169
169
|
"""Create a new run for the specified `fab_hash`."""
|
|
170
170
|
|
|
@@ -236,7 +236,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
236
236
|
"""
|
|
237
237
|
|
|
238
238
|
@abc.abstractmethod
|
|
239
|
-
def get_federation_options(self, run_id: int) -> Optional[
|
|
239
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
|
240
240
|
"""Retrieve the federation options for the specified `run_id`.
|
|
241
241
|
|
|
242
242
|
Parameters
|
|
@@ -246,7 +246,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
246
246
|
|
|
247
247
|
Returns
|
|
248
248
|
-------
|
|
249
|
-
Optional[
|
|
249
|
+
Optional[ConfigRecord]
|
|
250
250
|
The federation options for the run if it exists; None otherwise.
|
|
251
251
|
"""
|
|
252
252
|
|