flwr-nightly 1.12.0.dev20240916__py3-none-any.whl → 1.12.0.dev20241006__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/log.py +234 -0
- flwr/cli/new/new.py +1 -1
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/run/run.py +17 -1
- flwr/client/grpc_rere_client/client_interceptor.py +3 -0
- flwr/client/grpc_rere_client/connection.py +3 -3
- flwr/client/grpc_rere_client/grpc_adapter.py +14 -3
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +1 -0
- flwr/common/constant.py +6 -3
- flwr/common/secure_aggregation/secaggplus_utils.py +4 -4
- flwr/common/serde.py +22 -7
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/control_pb2.py +27 -0
- flwr/proto/control_pb2.pyi +7 -0
- flwr/proto/control_pb2_grpc.py +135 -0
- flwr/proto/control_pb2_grpc.pyi +53 -0
- flwr/proto/driver_pb2.py +15 -24
- flwr/proto/driver_pb2.pyi +0 -52
- flwr/proto/driver_pb2_grpc.py +6 -6
- flwr/proto/driver_pb2_grpc.pyi +4 -4
- flwr/proto/exec_pb2.py +1 -1
- flwr/proto/fab_pb2.py +8 -7
- flwr/proto/fab_pb2.pyi +7 -1
- flwr/proto/fleet_pb2.py +10 -10
- flwr/proto/fleet_pb2.pyi +6 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +1 -1
- flwr/proto/recordset_pb2.py +35 -33
- flwr/proto/recordset_pb2.pyi +40 -14
- flwr/proto/run_pb2.py +33 -9
- flwr/proto/run_pb2.pyi +150 -1
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +8 -8
- flwr/proto/transport_pb2.pyi +9 -6
- flwr/server/run_serverapp.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +2 -2
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +4 -0
- flwr/server/superlink/state/in_memory_state.py +17 -0
- flwr/server/superlink/state/sqlite_state.py +142 -24
- flwr/server/superlink/state/utils.py +98 -2
- flwr/server/utils/validator.py +6 -0
- flwr/superexec/deployment.py +3 -1
- flwr/superexec/exec_servicer.py +68 -3
- flwr/superexec/executor.py +2 -1
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/METADATA +4 -2
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/RECORD +53 -48
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20240916.dist-info → flwr_nightly-1.12.0.dev20241006.dist-info}/entry_points.txt +0 -0
flwr/cli/app.py
CHANGED
|
@@ -19,6 +19,7 @@ from typer.main import get_command
|
|
|
19
19
|
|
|
20
20
|
from .build import build
|
|
21
21
|
from .install import install
|
|
22
|
+
from .log import log
|
|
22
23
|
from .new import new
|
|
23
24
|
from .run import run
|
|
24
25
|
|
|
@@ -35,6 +36,7 @@ app.command()(new)
|
|
|
35
36
|
app.command()(run)
|
|
36
37
|
app.command()(build)
|
|
37
38
|
app.command()(install)
|
|
39
|
+
app.command()(log)
|
|
38
40
|
|
|
39
41
|
typer_click_object = get_command(app)
|
|
40
42
|
|
flwr/cli/log.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower command line interface `log` command."""
|
|
16
|
+
|
|
17
|
+
import sys
|
|
18
|
+
import time
|
|
19
|
+
from logging import DEBUG, ERROR, INFO
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Annotated, Optional
|
|
22
|
+
|
|
23
|
+
import grpc
|
|
24
|
+
import typer
|
|
25
|
+
|
|
26
|
+
from flwr.cli.config_utils import load_and_validate
|
|
27
|
+
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
28
|
+
from flwr.common.logger import log as logger
|
|
29
|
+
from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
|
|
30
|
+
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
31
|
+
|
|
32
|
+
CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def start_stream(
|
|
36
|
+
run_id: int, channel: grpc.Channel, refresh_period: int = CONN_REFRESH_PERIOD
|
|
37
|
+
) -> None:
|
|
38
|
+
"""Start log streaming for a given run ID."""
|
|
39
|
+
try:
|
|
40
|
+
while True:
|
|
41
|
+
logger(INFO, "Starting logstream for run_id `%s`", run_id)
|
|
42
|
+
stream_logs(run_id, channel, refresh_period)
|
|
43
|
+
time.sleep(2)
|
|
44
|
+
logger(DEBUG, "Reconnecting to logstream")
|
|
45
|
+
except KeyboardInterrupt:
|
|
46
|
+
logger(INFO, "Exiting logstream")
|
|
47
|
+
except grpc.RpcError as e:
|
|
48
|
+
# pylint: disable=E1101
|
|
49
|
+
if e.code() == grpc.StatusCode.NOT_FOUND:
|
|
50
|
+
logger(ERROR, "Invalid run_id `%s`, exiting", run_id)
|
|
51
|
+
if e.code() == grpc.StatusCode.CANCELLED:
|
|
52
|
+
pass
|
|
53
|
+
finally:
|
|
54
|
+
channel.close()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def stream_logs(run_id: int, channel: grpc.Channel, duration: int) -> None:
|
|
58
|
+
"""Stream logs from the beginning of a run with connection refresh."""
|
|
59
|
+
start_time = time.time()
|
|
60
|
+
stub = ExecStub(channel)
|
|
61
|
+
req = StreamLogsRequest(run_id=run_id)
|
|
62
|
+
|
|
63
|
+
for res in stub.StreamLogs(req):
|
|
64
|
+
print(res.log_output)
|
|
65
|
+
if time.time() - start_time > duration:
|
|
66
|
+
break
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
|
|
70
|
+
"""Print logs from the beginning of a run."""
|
|
71
|
+
stub = ExecStub(channel)
|
|
72
|
+
req = StreamLogsRequest(run_id=run_id)
|
|
73
|
+
|
|
74
|
+
try:
|
|
75
|
+
while True:
|
|
76
|
+
try:
|
|
77
|
+
# Enforce timeout for graceful exit
|
|
78
|
+
for res in stub.StreamLogs(req, timeout=timeout):
|
|
79
|
+
print(res.log_output)
|
|
80
|
+
except grpc.RpcError as e:
|
|
81
|
+
# pylint: disable=E1101
|
|
82
|
+
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
|
83
|
+
break
|
|
84
|
+
if e.code() == grpc.StatusCode.NOT_FOUND:
|
|
85
|
+
logger(ERROR, "Invalid run_id `%s`, exiting", run_id)
|
|
86
|
+
break
|
|
87
|
+
if e.code() == grpc.StatusCode.CANCELLED:
|
|
88
|
+
break
|
|
89
|
+
except KeyboardInterrupt:
|
|
90
|
+
logger(DEBUG, "Stream interrupted by user")
|
|
91
|
+
finally:
|
|
92
|
+
channel.close()
|
|
93
|
+
logger(DEBUG, "Channel closed")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
97
|
+
"""Log channel connectivity."""
|
|
98
|
+
logger(DEBUG, channel_connectivity)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def log(
|
|
102
|
+
run_id: Annotated[
|
|
103
|
+
int,
|
|
104
|
+
typer.Argument(help="The Flower run ID to query"),
|
|
105
|
+
],
|
|
106
|
+
app: Annotated[
|
|
107
|
+
Path,
|
|
108
|
+
typer.Argument(help="Path of the Flower project to run"),
|
|
109
|
+
] = Path("."),
|
|
110
|
+
federation: Annotated[
|
|
111
|
+
Optional[str],
|
|
112
|
+
typer.Argument(help="Name of the federation to run the app on"),
|
|
113
|
+
] = None,
|
|
114
|
+
stream: Annotated[
|
|
115
|
+
bool,
|
|
116
|
+
typer.Option(
|
|
117
|
+
"--stream/--show",
|
|
118
|
+
help="Flag to stream or print logs from the Flower run",
|
|
119
|
+
),
|
|
120
|
+
] = True,
|
|
121
|
+
) -> None:
|
|
122
|
+
"""Get logs from a Flower project run."""
|
|
123
|
+
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
|
124
|
+
|
|
125
|
+
pyproject_path = app / "pyproject.toml" if app else None
|
|
126
|
+
config, errors, warnings = load_and_validate(path=pyproject_path)
|
|
127
|
+
|
|
128
|
+
if config is None:
|
|
129
|
+
typer.secho(
|
|
130
|
+
"Project configuration could not be loaded.\n"
|
|
131
|
+
"pyproject.toml is invalid:\n"
|
|
132
|
+
+ "\n".join([f"- {line}" for line in errors]),
|
|
133
|
+
fg=typer.colors.RED,
|
|
134
|
+
bold=True,
|
|
135
|
+
)
|
|
136
|
+
sys.exit()
|
|
137
|
+
|
|
138
|
+
if warnings:
|
|
139
|
+
typer.secho(
|
|
140
|
+
"Project configuration is missing the following "
|
|
141
|
+
"recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]),
|
|
142
|
+
fg=typer.colors.RED,
|
|
143
|
+
bold=True,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
typer.secho("Success", fg=typer.colors.GREEN)
|
|
147
|
+
|
|
148
|
+
federation = federation or config["tool"]["flwr"]["federations"].get("default")
|
|
149
|
+
|
|
150
|
+
if federation is None:
|
|
151
|
+
typer.secho(
|
|
152
|
+
"❌ No federation name was provided and the project's `pyproject.toml` "
|
|
153
|
+
"doesn't declare a default federation (with a SuperExec address or an "
|
|
154
|
+
"`options.num-supernodes` value).",
|
|
155
|
+
fg=typer.colors.RED,
|
|
156
|
+
bold=True,
|
|
157
|
+
)
|
|
158
|
+
raise typer.Exit(code=1)
|
|
159
|
+
|
|
160
|
+
# Validate the federation exists in the configuration
|
|
161
|
+
federation_config = config["tool"]["flwr"]["federations"].get(federation)
|
|
162
|
+
if federation_config is None:
|
|
163
|
+
available_feds = {
|
|
164
|
+
fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
|
|
165
|
+
}
|
|
166
|
+
typer.secho(
|
|
167
|
+
f"❌ There is no `{federation}` federation declared in the "
|
|
168
|
+
"`pyproject.toml`.\n The following federations were found:\n\n"
|
|
169
|
+
+ "\n".join(available_feds),
|
|
170
|
+
fg=typer.colors.RED,
|
|
171
|
+
bold=True,
|
|
172
|
+
)
|
|
173
|
+
raise typer.Exit(code=1)
|
|
174
|
+
|
|
175
|
+
if "address" not in federation_config:
|
|
176
|
+
typer.secho(
|
|
177
|
+
"❌ `flwr log` currently works with `SuperExec`. Ensure that the correct"
|
|
178
|
+
"`SuperExec` address is provided in the `pyproject.toml`.",
|
|
179
|
+
fg=typer.colors.RED,
|
|
180
|
+
bold=True,
|
|
181
|
+
)
|
|
182
|
+
raise typer.Exit(code=1)
|
|
183
|
+
|
|
184
|
+
_log_with_superexec(federation_config, run_id, stream)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# pylint: disable-next=too-many-branches
|
|
188
|
+
def _log_with_superexec(
|
|
189
|
+
federation_config: dict[str, str],
|
|
190
|
+
run_id: int,
|
|
191
|
+
stream: bool,
|
|
192
|
+
) -> None:
|
|
193
|
+
insecure_str = federation_config.get("insecure")
|
|
194
|
+
if root_certificates := federation_config.get("root-certificates"):
|
|
195
|
+
root_certificates_bytes = Path(root_certificates).read_bytes()
|
|
196
|
+
if insecure := bool(insecure_str):
|
|
197
|
+
typer.secho(
|
|
198
|
+
"❌ `root_certificates` were provided but the `insecure` parameter"
|
|
199
|
+
"is set to `True`.",
|
|
200
|
+
fg=typer.colors.RED,
|
|
201
|
+
bold=True,
|
|
202
|
+
)
|
|
203
|
+
raise typer.Exit(code=1)
|
|
204
|
+
else:
|
|
205
|
+
root_certificates_bytes = None
|
|
206
|
+
if insecure_str is None:
|
|
207
|
+
typer.secho(
|
|
208
|
+
"❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
|
|
209
|
+
fg=typer.colors.RED,
|
|
210
|
+
bold=True,
|
|
211
|
+
)
|
|
212
|
+
raise typer.Exit(code=1)
|
|
213
|
+
if not (insecure := bool(insecure_str)):
|
|
214
|
+
typer.secho(
|
|
215
|
+
"❌ No certificate were given yet `insecure` is set to `False`.",
|
|
216
|
+
fg=typer.colors.RED,
|
|
217
|
+
bold=True,
|
|
218
|
+
)
|
|
219
|
+
raise typer.Exit(code=1)
|
|
220
|
+
|
|
221
|
+
channel = create_channel(
|
|
222
|
+
server_address=federation_config["address"],
|
|
223
|
+
insecure=insecure,
|
|
224
|
+
root_certificates=root_certificates_bytes,
|
|
225
|
+
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
226
|
+
interceptors=None,
|
|
227
|
+
)
|
|
228
|
+
channel.subscribe(on_channel_state_change)
|
|
229
|
+
|
|
230
|
+
if stream:
|
|
231
|
+
start_stream(run_id, channel, CONN_REFRESH_PERIOD)
|
|
232
|
+
else:
|
|
233
|
+
logger(INFO, "Printing logstream for run_id `%s`", run_id)
|
|
234
|
+
print_logs(run_id, channel, timeout=5)
|
flwr/cli/new/new.py
CHANGED
|
@@ -275,7 +275,7 @@ def new(
|
|
|
275
275
|
)
|
|
276
276
|
)
|
|
277
277
|
|
|
278
|
-
_add = " huggingface-cli login\n" if
|
|
278
|
+
_add = " huggingface-cli login\n" if llm_challenge_str else ""
|
|
279
279
|
print(
|
|
280
280
|
typer.style(
|
|
281
281
|
f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
|
|
@@ -55,7 +55,7 @@ We use Mistral-7B model with 4-bit quantization as default. The estimated VRAM c
|
|
|
55
55
|
| :--------: | :--------: | :--------: | :--------: | :--------: |
|
|
56
56
|
| VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |
|
|
57
57
|
|
|
58
|
-
You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which are specified with `options.backend.
|
|
58
|
+
You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which are specified with `options.backend.client-resources.num-cpus` and `options.backend.client-resources.num-gpus` under `[tool.flwr.federations.local-simulation]` entry in `pyproject.toml`.
|
|
59
59
|
|
|
60
60
|
|
|
61
61
|
## Model saving
|
flwr/cli/run/run.py
CHANGED
|
@@ -34,6 +34,10 @@ from flwr.common.typing import Fab
|
|
|
34
34
|
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
|
|
35
35
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
36
36
|
|
|
37
|
+
from ..log import start_stream
|
|
38
|
+
|
|
39
|
+
CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
|
|
40
|
+
|
|
37
41
|
|
|
38
42
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
39
43
|
"""Log channel connectivity."""
|
|
@@ -62,6 +66,14 @@ def run(
|
|
|
62
66
|
"inside the `pyproject.toml` in order to be properly overriden.",
|
|
63
67
|
),
|
|
64
68
|
] = None,
|
|
69
|
+
stream: Annotated[
|
|
70
|
+
bool,
|
|
71
|
+
typer.Option(
|
|
72
|
+
"--stream",
|
|
73
|
+
help="Use `--stream` with `flwr run` to display logs;\n "
|
|
74
|
+
"logs are not streamed by default.",
|
|
75
|
+
),
|
|
76
|
+
] = False,
|
|
65
77
|
) -> None:
|
|
66
78
|
"""Run Flower App."""
|
|
67
79
|
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
|
@@ -117,7 +129,7 @@ def run(
|
|
|
117
129
|
raise typer.Exit(code=1)
|
|
118
130
|
|
|
119
131
|
if "address" in federation_config:
|
|
120
|
-
_run_with_superexec(app, federation_config, config_overrides)
|
|
132
|
+
_run_with_superexec(app, federation_config, config_overrides, stream)
|
|
121
133
|
else:
|
|
122
134
|
_run_without_superexec(app, federation_config, config_overrides, federation)
|
|
123
135
|
|
|
@@ -126,6 +138,7 @@ def _run_with_superexec(
|
|
|
126
138
|
app: Path,
|
|
127
139
|
federation_config: dict[str, Any],
|
|
128
140
|
config_overrides: Optional[list[str]],
|
|
141
|
+
stream: bool,
|
|
129
142
|
) -> None:
|
|
130
143
|
|
|
131
144
|
insecure_str = federation_config.get("insecure")
|
|
@@ -183,6 +196,9 @@ def _run_with_superexec(
|
|
|
183
196
|
fab_path.unlink()
|
|
184
197
|
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
|
185
198
|
|
|
199
|
+
if stream:
|
|
200
|
+
start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)
|
|
201
|
+
|
|
186
202
|
|
|
187
203
|
def _run_without_superexec(
|
|
188
204
|
app: Optional[Path],
|
|
@@ -31,6 +31,7 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
|
31
31
|
generate_shared_key,
|
|
32
32
|
public_key_to_bytes,
|
|
33
33
|
)
|
|
34
|
+
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
|
|
34
35
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
35
36
|
CreateNodeRequest,
|
|
36
37
|
DeleteNodeRequest,
|
|
@@ -50,6 +51,7 @@ Request = Union[
|
|
|
50
51
|
PushTaskResRequest,
|
|
51
52
|
GetRunRequest,
|
|
52
53
|
PingRequest,
|
|
54
|
+
GetFabRequest,
|
|
53
55
|
]
|
|
54
56
|
|
|
55
57
|
|
|
@@ -126,6 +128,7 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
|
126
128
|
PushTaskResRequest,
|
|
127
129
|
GetRunRequest,
|
|
128
130
|
PingRequest,
|
|
131
|
+
GetFabRequest,
|
|
129
132
|
),
|
|
130
133
|
):
|
|
131
134
|
if self.shared_secret is None:
|
|
@@ -269,7 +269,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
269
269
|
task_res = message_to_taskres(message)
|
|
270
270
|
|
|
271
271
|
# Serialize ProtoBuf to bytes
|
|
272
|
-
request = PushTaskResRequest(task_res_list=[task_res])
|
|
272
|
+
request = PushTaskResRequest(node=node, task_res_list=[task_res])
|
|
273
273
|
_ = retry_invoker.invoke(stub.PushTaskRes, request)
|
|
274
274
|
|
|
275
275
|
# Cleanup
|
|
@@ -277,7 +277,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
277
277
|
|
|
278
278
|
def get_run(run_id: int) -> Run:
|
|
279
279
|
# Call FleetAPI
|
|
280
|
-
get_run_request = GetRunRequest(run_id=run_id)
|
|
280
|
+
get_run_request = GetRunRequest(node=node, run_id=run_id)
|
|
281
281
|
get_run_response: GetRunResponse = retry_invoker.invoke(
|
|
282
282
|
stub.GetRun,
|
|
283
283
|
request=get_run_request,
|
|
@@ -294,7 +294,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
294
294
|
|
|
295
295
|
def get_fab(fab_hash: str) -> Fab:
|
|
296
296
|
# Call FleetAPI
|
|
297
|
-
get_fab_request = GetFabRequest(hash_str=fab_hash)
|
|
297
|
+
get_fab_request = GetFabRequest(node=node, hash_str=fab_hash)
|
|
298
298
|
get_fab_response: GetFabResponse = retry_invoker.invoke(
|
|
299
299
|
stub.GetFab,
|
|
300
300
|
request=get_fab_request,
|
|
@@ -24,10 +24,14 @@ from google.protobuf.message import Message as GrpcMessage
|
|
|
24
24
|
|
|
25
25
|
from flwr.common import log
|
|
26
26
|
from flwr.common.constant import (
|
|
27
|
+
GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY,
|
|
28
|
+
GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY,
|
|
27
29
|
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY,
|
|
30
|
+
GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY,
|
|
31
|
+
GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY,
|
|
28
32
|
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY,
|
|
29
33
|
)
|
|
30
|
-
from flwr.common.version import package_version
|
|
34
|
+
from flwr.common.version import package_name, package_version
|
|
31
35
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
32
36
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
33
37
|
CreateNodeRequest,
|
|
@@ -62,9 +66,16 @@ class GrpcAdapter:
|
|
|
62
66
|
self, request: GrpcMessage, response_type: type[T], **kwargs: Any
|
|
63
67
|
) -> T:
|
|
64
68
|
# Serialize request
|
|
69
|
+
req_cls = request.__class__
|
|
65
70
|
container_req = MessageContainer(
|
|
66
|
-
metadata={
|
|
67
|
-
|
|
71
|
+
metadata={
|
|
72
|
+
GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY: package_name,
|
|
73
|
+
GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY: package_version,
|
|
74
|
+
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version,
|
|
75
|
+
GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY: req_cls.__module__,
|
|
76
|
+
GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY: req_cls.__qualname__,
|
|
77
|
+
},
|
|
78
|
+
grpc_message_name=req_cls.__qualname__,
|
|
68
79
|
grpc_message_content=request.SerializeToString(),
|
|
69
80
|
)
|
|
70
81
|
|
|
@@ -340,7 +340,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
340
340
|
task_res = message_to_taskres(message)
|
|
341
341
|
|
|
342
342
|
# Serialize ProtoBuf to bytes
|
|
343
|
-
req = PushTaskResRequest(task_res_list=[task_res])
|
|
343
|
+
req = PushTaskResRequest(node=node, task_res_list=[task_res])
|
|
344
344
|
|
|
345
345
|
# Send the request
|
|
346
346
|
res = _request(req, PushTaskResResponse, PATH_PUSH_TASK_RES)
|
|
@@ -356,7 +356,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
356
356
|
|
|
357
357
|
def get_run(run_id: int) -> Run:
|
|
358
358
|
# Construct the request
|
|
359
|
-
req = GetRunRequest(run_id=run_id)
|
|
359
|
+
req = GetRunRequest(node=node, run_id=run_id)
|
|
360
360
|
|
|
361
361
|
# Send the request
|
|
362
362
|
res = _request(req, GetRunResponse, PATH_GET_RUN)
|
|
@@ -373,7 +373,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
373
373
|
|
|
374
374
|
def get_fab(fab_hash: str) -> Fab:
|
|
375
375
|
# Construct the request
|
|
376
|
-
req = GetFabRequest(hash_str=fab_hash)
|
|
376
|
+
req = GetFabRequest(node=node, hash_str=fab_hash)
|
|
377
377
|
|
|
378
378
|
# Send the request
|
|
379
379
|
res = _request(req, GetFabResponse, PATH_GET_FAB)
|
flwr/client/supernode/app.py
CHANGED
flwr/common/constant.py
CHANGED
|
@@ -60,8 +60,6 @@ PING_MAX_INTERVAL = 1e300
|
|
|
60
60
|
# IDs
|
|
61
61
|
RUN_ID_NUM_BYTES = 8
|
|
62
62
|
NODE_ID_NUM_BYTES = 8
|
|
63
|
-
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
|
|
64
|
-
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
|
|
65
63
|
|
|
66
64
|
# Constants for FAB
|
|
67
65
|
APP_DIR = "apps"
|
|
@@ -72,8 +70,13 @@ FLWR_HOME = "FLWR_HOME"
|
|
|
72
70
|
PARTITION_ID_KEY = "partition-id"
|
|
73
71
|
NUM_PARTITIONS_KEY = "num-partitions"
|
|
74
72
|
|
|
75
|
-
|
|
73
|
+
# Constants for keys in `metadata` of `MessageContainer` in `grpc-adapter`
|
|
74
|
+
GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_NAME_KEY = "flower-package-name"
|
|
75
|
+
GRPC_ADAPTER_METADATA_FLOWER_PACKAGE_VERSION_KEY = "flower-package-version"
|
|
76
|
+
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version" # Deprecated
|
|
76
77
|
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
|
|
78
|
+
GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY = "grpc-message-module"
|
|
79
|
+
GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY = "grpc-message-qualname"
|
|
77
80
|
|
|
78
81
|
|
|
79
82
|
class MessageType:
|
|
@@ -43,8 +43,8 @@ def share_keys_plaintext_concat(
|
|
|
43
43
|
"""
|
|
44
44
|
return b"".join(
|
|
45
45
|
[
|
|
46
|
-
int.to_bytes(src_node_id, 8, "little", signed=
|
|
47
|
-
int.to_bytes(dst_node_id, 8, "little", signed=
|
|
46
|
+
int.to_bytes(src_node_id, 8, "little", signed=False),
|
|
47
|
+
int.to_bytes(dst_node_id, 8, "little", signed=False),
|
|
48
48
|
int.to_bytes(len(b_share), 4, "little"),
|
|
49
49
|
b_share,
|
|
50
50
|
sk_share,
|
|
@@ -72,8 +72,8 @@ def share_keys_plaintext_separate(plaintext: bytes) -> tuple[int, int, bytes, by
|
|
|
72
72
|
the secret key share of the source sent to the destination.
|
|
73
73
|
"""
|
|
74
74
|
src, dst, mark = (
|
|
75
|
-
int.from_bytes(plaintext[:8], "little", signed=
|
|
76
|
-
int.from_bytes(plaintext[8:16], "little", signed=
|
|
75
|
+
int.from_bytes(plaintext[:8], "little", signed=False),
|
|
76
|
+
int.from_bytes(plaintext[8:16], "little", signed=False),
|
|
77
77
|
int.from_bytes(plaintext[16:20], "little"),
|
|
78
78
|
)
|
|
79
79
|
ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :])
|
flwr/common/serde.py
CHANGED
|
@@ -38,7 +38,7 @@ from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
|
|
|
38
38
|
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
|
|
39
39
|
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
|
|
40
40
|
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
41
|
-
from flwr.proto.recordset_pb2 import
|
|
41
|
+
from flwr.proto.recordset_pb2 import SintList, StringList, UintList
|
|
42
42
|
from flwr.proto.run_pb2 import Run as ProtoRun
|
|
43
43
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
44
44
|
from flwr.proto.transport_pb2 import (
|
|
@@ -340,6 +340,7 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
|
|
|
340
340
|
|
|
341
341
|
|
|
342
342
|
# === Scalar messages ===
|
|
343
|
+
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
|
343
344
|
|
|
344
345
|
|
|
345
346
|
def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
|
|
@@ -354,6 +355,9 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
|
|
|
354
355
|
return Scalar(double=scalar)
|
|
355
356
|
|
|
356
357
|
if isinstance(scalar, int):
|
|
358
|
+
# Use uint64 for integers larger than the maximum value of sint64
|
|
359
|
+
if scalar > INT64_MAX_VALUE:
|
|
360
|
+
return Scalar(uint64=scalar)
|
|
357
361
|
return Scalar(sint64=scalar)
|
|
358
362
|
|
|
359
363
|
if isinstance(scalar, str):
|
|
@@ -374,16 +378,16 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
|
|
|
374
378
|
# === Record messages ===
|
|
375
379
|
|
|
376
380
|
|
|
377
|
-
_type_to_field = {
|
|
381
|
+
_type_to_field: dict[type, str] = {
|
|
378
382
|
float: "double",
|
|
379
383
|
int: "sint64",
|
|
380
384
|
bool: "bool",
|
|
381
385
|
str: "string",
|
|
382
386
|
bytes: "bytes",
|
|
383
387
|
}
|
|
384
|
-
_list_type_to_class_and_field = {
|
|
388
|
+
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
|
|
385
389
|
float: (DoubleList, "double_list"),
|
|
386
|
-
int: (
|
|
390
|
+
int: (SintList, "sint_list"),
|
|
387
391
|
bool: (BoolList, "bool_list"),
|
|
388
392
|
str: (StringList, "string_list"),
|
|
389
393
|
bytes: (BytesList, "bytes_list"),
|
|
@@ -391,6 +395,11 @@ _list_type_to_class_and_field = {
|
|
|
391
395
|
T = TypeVar("T")
|
|
392
396
|
|
|
393
397
|
|
|
398
|
+
def _is_uint64(value: Any) -> bool:
|
|
399
|
+
"""Check if a value is uint64."""
|
|
400
|
+
return isinstance(value, int) and value > INT64_MAX_VALUE
|
|
401
|
+
|
|
402
|
+
|
|
394
403
|
def _record_value_to_proto(
|
|
395
404
|
value: Any, allowed_types: list[type], proto_class: type[T]
|
|
396
405
|
) -> T:
|
|
@@ -403,12 +412,18 @@ def _record_value_to_proto(
|
|
|
403
412
|
# Single element
|
|
404
413
|
# Note: `isinstance(False, int) == True`.
|
|
405
414
|
if isinstance(value, t):
|
|
406
|
-
|
|
415
|
+
fld = _type_to_field[t]
|
|
416
|
+
if t is int and _is_uint64(value):
|
|
417
|
+
fld = "uint64"
|
|
418
|
+
arg[fld] = value
|
|
407
419
|
return proto_class(**arg)
|
|
408
420
|
# List
|
|
409
421
|
if isinstance(value, list) and all(isinstance(item, t) for item in value):
|
|
410
|
-
list_class,
|
|
411
|
-
|
|
422
|
+
list_class, fld = _list_type_to_class_and_field[t]
|
|
423
|
+
# Use UintList if any element is of type `uint64`.
|
|
424
|
+
if t is int and any(_is_uint64(v) for v in value):
|
|
425
|
+
list_class, fld = UintList, "uint_list"
|
|
426
|
+
arg[fld] = list_class(vals=value)
|
|
412
427
|
return proto_class(**arg)
|
|
413
428
|
# Invalid types
|
|
414
429
|
raise TypeError(
|
flwr/proto/clientappio_pb2.py
CHANGED
|
@@ -17,7 +17,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
|
17
17
|
from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/clientappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x18\x66lwr/proto/message.proto\"W\n\x15\x43lientAppOutputStatus\x12-\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1f.flwr.proto.ClientAppOutputCode\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x11\n\x0fGetTokenRequest\"!\n\x10GetTokenResponse\x12\r\n\x05token\x18\x01 \x01(\
|
|
20
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/clientappio.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x18\x66lwr/proto/message.proto\"W\n\x15\x43lientAppOutputStatus\x12-\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1f.flwr.proto.ClientAppOutputCode\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x11\n\x0fGetTokenRequest\"!\n\x10GetTokenResponse\x12\r\n\x05token\x18\x01 \x01(\x04\"+\n\x1aPullClientAppInputsRequest\x12\r\n\x05token\x18\x01 \x01(\x04\"\xa5\x01\n\x1bPullClientAppInputsResponse\x12$\n\x07message\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Message\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\"x\n\x1bPushClientAppOutputsRequest\x12\r\n\x05token\x18\x01 \x01(\x04\x12$\n\x07message\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Message\x12$\n\x07\x63ontext\x18\x03 \x01(\x0b\x32\x13.flwr.proto.Context\"Q\n\x1cPushClientAppOutputsResponse\x12\x31\n\x06status\x18\x01 \x01(\x0b\x32!.flwr.proto.ClientAppOutputStatus*L\n\x13\x43lientAppOutputCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x15\n\x11\x44\x45\x41\x44LINE_EXCEEDED\x10\x01\x12\x11\n\rUNKNOWN_ERROR\x10\x02\x32\xad\x02\n\x0b\x43lientAppIo\x12G\n\x08GetToken\x12\x1b.flwr.proto.GetTokenRequest\x1a\x1c.flwr.proto.GetTokenResponse\"\x00\x12h\n\x13PullClientAppInputs\x12&.flwr.proto.PullClientAppInputsRequest\x1a\'.flwr.proto.PullClientAppInputsResponse\"\x00\x12k\n\x14PushClientAppOutputs\x12\'.flwr.proto.PushClientAppOutputsRequest\x1a(.flwr.proto.PushClientAppOutputsResponse\"\x00\x62\x06proto3')
|
|
21
21
|
|
|
22
22
|
_globals = globals()
|
|
23
23
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# source: flwr/proto/control.proto
|
|
4
|
+
# Protobuf Python Version: 4.25.0
|
|
5
|
+
"""Generated protocol buffer code."""
|
|
6
|
+
from google.protobuf import descriptor as _descriptor
|
|
7
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
8
|
+
from google.protobuf import symbol_database as _symbol_database
|
|
9
|
+
from google.protobuf.internal import builder as _builder
|
|
10
|
+
# @@protoc_insertion_point(imports)
|
|
11
|
+
|
|
12
|
+
_sym_db = _symbol_database.Default()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2\x88\x02\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x62\x06proto3')
|
|
19
|
+
|
|
20
|
+
_globals = globals()
|
|
21
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
22
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.control_pb2', _globals)
|
|
23
|
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
|
24
|
+
DESCRIPTOR._options = None
|
|
25
|
+
_globals['_CONTROL']._serialized_start=63
|
|
26
|
+
_globals['_CONTROL']._serialized_end=327
|
|
27
|
+
# @@protoc_insertion_point(module_scope)
|