flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +17 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +196 -42
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +109 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/clientapp/mod/__init__.py +4 -1
- flwr/clientapp/mod/centraldp_mods.py +156 -40
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +56 -13
- flwr/common/exit/exit_code.py +24 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -31
- flwr/proto/control_pb2.pyi +95 -5
- flwr/proto/control_pb2_grpc.py +136 -0
- flwr/proto/control_pb2_grpc.pyi +52 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +152 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/serverapp/strategy/__init__.py +26 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
- flwr/serverapp/strategy/fedadagrad.py +0 -3
- flwr/serverapp/strategy/fedadam.py +0 -3
- flwr/serverapp/strategy/fedavg.py +89 -64
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +0 -3
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +28 -32
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +41 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +16 -11
- flwr/superlink/servicer/control/control_servicer.py +207 -58
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/serverapp/dp_fixed_clipping.py +0 -352
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
- /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
flwr/cli/login/login.py
CHANGED
|
@@ -20,6 +20,7 @@ from typing import Annotated, Optional
|
|
|
20
20
|
|
|
21
21
|
import typer
|
|
22
22
|
|
|
23
|
+
from flwr.cli.auth_plugin import LoginError, NoOpCliAuthPlugin
|
|
23
24
|
from flwr.cli.config_utils import (
|
|
24
25
|
exit_if_no_address,
|
|
25
26
|
get_insecure_flag,
|
|
@@ -28,14 +29,19 @@ from flwr.cli.config_utils import (
|
|
|
28
29
|
validate_federation_in_project_config,
|
|
29
30
|
)
|
|
30
31
|
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
|
|
31
|
-
from flwr.common.typing import
|
|
32
|
+
from flwr.common.typing import AccountAuthLoginDetails
|
|
32
33
|
from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
33
34
|
GetLoginDetailsRequest,
|
|
34
35
|
GetLoginDetailsResponse,
|
|
35
36
|
)
|
|
36
37
|
from flwr.proto.control_pb2_grpc import ControlStub
|
|
37
38
|
|
|
38
|
-
from ..utils import
|
|
39
|
+
from ..utils import (
|
|
40
|
+
account_auth_enabled,
|
|
41
|
+
flwr_cli_grpc_exc_handler,
|
|
42
|
+
init_channel,
|
|
43
|
+
load_cli_auth_plugin,
|
|
44
|
+
)
|
|
39
45
|
|
|
40
46
|
|
|
41
47
|
def login( # pylint: disable=R0914
|
|
@@ -67,12 +73,13 @@ def login( # pylint: disable=R0914
|
|
|
67
73
|
)
|
|
68
74
|
exit_if_no_address(federation_config, "login")
|
|
69
75
|
|
|
70
|
-
# Check if `enable-
|
|
71
|
-
|
|
76
|
+
# Check if `enable-account-auth` is set to `true`
|
|
77
|
+
|
|
78
|
+
if not account_auth_enabled(federation_config):
|
|
72
79
|
typer.secho(
|
|
73
|
-
|
|
74
|
-
"To enable it, set `enable-
|
|
75
|
-
"configuration.",
|
|
80
|
+
"❌ Account authentication is not enabled for the federation "
|
|
81
|
+
f"'{federation}'. To enable it, set `enable-account-auth = true` "
|
|
82
|
+
"in the federation configuration.",
|
|
76
83
|
fg=typer.colors.RED,
|
|
77
84
|
bold=True,
|
|
78
85
|
)
|
|
@@ -88,7 +95,7 @@ def login( # pylint: disable=R0914
|
|
|
88
95
|
)
|
|
89
96
|
raise typer.Exit(code=1)
|
|
90
97
|
|
|
91
|
-
channel = init_channel(app, federation_config,
|
|
98
|
+
channel = init_channel(app, federation_config, NoOpCliAuthPlugin(Path()))
|
|
92
99
|
stub = ControlStub(channel)
|
|
93
100
|
|
|
94
101
|
login_request = GetLoginDetailsRequest()
|
|
@@ -96,28 +103,32 @@ def login( # pylint: disable=R0914
|
|
|
96
103
|
login_response: GetLoginDetailsResponse = stub.GetLoginDetails(login_request)
|
|
97
104
|
|
|
98
105
|
# Get the auth plugin
|
|
99
|
-
|
|
100
|
-
auth_plugin =
|
|
101
|
-
app, federation, federation_config, auth_type
|
|
102
|
-
)
|
|
103
|
-
if auth_plugin is None:
|
|
104
|
-
typer.secho(
|
|
105
|
-
f'❌ Authentication type "{auth_type}" not found',
|
|
106
|
-
fg=typer.colors.RED,
|
|
107
|
-
bold=True,
|
|
108
|
-
)
|
|
109
|
-
raise typer.Exit(code=1)
|
|
106
|
+
authn_type = login_response.authn_type
|
|
107
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config, authn_type)
|
|
110
108
|
|
|
111
109
|
# Login
|
|
112
|
-
details =
|
|
113
|
-
|
|
110
|
+
details = AccountAuthLoginDetails(
|
|
111
|
+
authn_type=login_response.authn_type,
|
|
114
112
|
device_code=login_response.device_code,
|
|
115
113
|
verification_uri_complete=login_response.verification_uri_complete,
|
|
116
114
|
expires_in=login_response.expires_in,
|
|
117
115
|
interval=login_response.interval,
|
|
118
116
|
)
|
|
119
|
-
|
|
120
|
-
|
|
117
|
+
try:
|
|
118
|
+
with flwr_cli_grpc_exc_handler():
|
|
119
|
+
credentials = auth_plugin.login(details, stub)
|
|
120
|
+
typer.secho(
|
|
121
|
+
"✅ Login successful.",
|
|
122
|
+
fg=typer.colors.GREEN,
|
|
123
|
+
bold=False,
|
|
124
|
+
)
|
|
125
|
+
except LoginError as e:
|
|
126
|
+
typer.secho(
|
|
127
|
+
f"❌ Login failed: {e.message}",
|
|
128
|
+
fg=typer.colors.RED,
|
|
129
|
+
bold=True,
|
|
130
|
+
)
|
|
131
|
+
raise typer.Exit(code=1) from None
|
|
121
132
|
|
|
122
133
|
# Store the tokens
|
|
123
134
|
auth_plugin.store_tokens(credentials)
|
flwr/cli/ls.py
CHANGED
|
@@ -19,7 +19,7 @@ import io
|
|
|
19
19
|
import json
|
|
20
20
|
from datetime import datetime, timedelta
|
|
21
21
|
from pathlib import Path
|
|
22
|
-
from typing import Annotated, Optional
|
|
22
|
+
from typing import Annotated, Optional, cast
|
|
23
23
|
|
|
24
24
|
import typer
|
|
25
25
|
from rich.console import Console
|
|
@@ -44,12 +44,13 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
|
44
44
|
)
|
|
45
45
|
from flwr.proto.control_pb2_grpc import ControlStub
|
|
46
46
|
|
|
47
|
-
from .utils import flwr_cli_grpc_exc_handler, init_channel,
|
|
47
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
|
|
48
48
|
|
|
49
49
|
_RunListType = tuple[int, str, str, str, str, str, str, str, str]
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
|
|
53
|
+
ctx: typer.Context,
|
|
53
54
|
app: Annotated[
|
|
54
55
|
Path,
|
|
55
56
|
typer.Argument(help="Path of the Flower project"),
|
|
@@ -102,6 +103,9 @@ def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
|
|
|
102
103
|
|
|
103
104
|
All timestamps follow ISO 8601, UTC and are formatted as ``YYYY-MM-DD HH:MM:SSZ``.
|
|
104
105
|
"""
|
|
106
|
+
# Resolve command used (list or ls)
|
|
107
|
+
command_name = cast(str, ctx.command.name) if ctx.command else "list"
|
|
108
|
+
|
|
105
109
|
suppress_output = output_format == CliOutputFormat.JSON
|
|
106
110
|
captured_output = io.StringIO()
|
|
107
111
|
try:
|
|
@@ -116,14 +120,14 @@ def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
|
|
|
116
120
|
federation, federation_config = validate_federation_in_project_config(
|
|
117
121
|
federation, config, federation_config_overrides
|
|
118
122
|
)
|
|
119
|
-
exit_if_no_address(federation_config,
|
|
123
|
+
exit_if_no_address(federation_config, command_name)
|
|
120
124
|
channel = None
|
|
121
125
|
try:
|
|
122
126
|
if runs and run_id is not None:
|
|
123
127
|
raise ValueError(
|
|
124
128
|
"The options '--runs' and '--run-id' are mutually exclusive."
|
|
125
129
|
)
|
|
126
|
-
auth_plugin =
|
|
130
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
|
|
127
131
|
channel = init_channel(app, federation_config, auth_plugin)
|
|
128
132
|
stub = ControlStub(channel)
|
|
129
133
|
|
|
@@ -216,14 +220,14 @@ def _to_table(run_list: list[_RunListType]) -> Table:
|
|
|
216
220
|
|
|
217
221
|
# Add columns
|
|
218
222
|
table.add_column(
|
|
219
|
-
Text("Run ID", justify="center"), style="
|
|
223
|
+
Text("Run ID", justify="center"), style="bright_black", no_wrap=True
|
|
220
224
|
)
|
|
221
|
-
table.add_column(Text("FAB", justify="center"), style="
|
|
225
|
+
table.add_column(Text("FAB", justify="center"), style="bright_black")
|
|
222
226
|
table.add_column(Text("Status", justify="center"))
|
|
223
227
|
table.add_column(Text("Elapsed", justify="center"), style="blue")
|
|
224
|
-
table.add_column(Text("Created At", justify="center"), style="
|
|
225
|
-
table.add_column(Text("Running At", justify="center"), style="
|
|
226
|
-
table.add_column(Text("Finished At", justify="center"), style="
|
|
228
|
+
table.add_column(Text("Created At", justify="center"), style="bright_black")
|
|
229
|
+
table.add_column(Text("Running At", justify="center"), style="bright_black")
|
|
230
|
+
table.add_column(Text("Finished At", justify="center"), style="bright_black")
|
|
227
231
|
|
|
228
232
|
for row in run_list:
|
|
229
233
|
(
|
flwr/cli/new/new.py
CHANGED
|
@@ -15,14 +15,20 @@
|
|
|
15
15
|
"""Flower command line interface `new` command."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import io
|
|
19
|
+
import json
|
|
18
20
|
import re
|
|
21
|
+
import zipfile
|
|
19
22
|
from enum import Enum
|
|
20
23
|
from pathlib import Path
|
|
21
24
|
from string import Template
|
|
22
25
|
from typing import Annotated, Optional
|
|
23
26
|
|
|
27
|
+
import requests
|
|
24
28
|
import typer
|
|
25
29
|
|
|
30
|
+
from flwr.supercore.constant import APP_ID_PATTERN, PLATFORM_API_URL
|
|
31
|
+
|
|
26
32
|
from ..utils import (
|
|
27
33
|
is_valid_project_name,
|
|
28
34
|
prompt_options,
|
|
@@ -35,15 +41,16 @@ class MlFramework(str, Enum):
|
|
|
35
41
|
"""Available frameworks."""
|
|
36
42
|
|
|
37
43
|
PYTORCH = "PyTorch"
|
|
38
|
-
PYTORCH_MSG_API = "PyTorch (Message API)"
|
|
39
44
|
TENSORFLOW = "TensorFlow"
|
|
40
45
|
SKLEARN = "sklearn"
|
|
41
46
|
HUGGINGFACE = "HuggingFace"
|
|
42
47
|
JAX = "JAX"
|
|
43
48
|
MLX = "MLX"
|
|
44
49
|
NUMPY = "NumPy"
|
|
50
|
+
XGBOOST = "XGBoost"
|
|
45
51
|
FLOWERTUNE = "FlowerTune"
|
|
46
52
|
BASELINE = "Flower Baseline"
|
|
53
|
+
PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"
|
|
47
54
|
|
|
48
55
|
|
|
49
56
|
class LlmChallengeName(str, Enum):
|
|
@@ -92,6 +99,180 @@ def render_and_create(file_path: Path, template: str, context: dict[str, str]) -
|
|
|
92
99
|
create_file(file_path, content)
|
|
93
100
|
|
|
94
101
|
|
|
102
|
+
def print_success_prompt(
|
|
103
|
+
package_name: str, llm_challenge_str: Optional[str] = None
|
|
104
|
+
) -> None:
|
|
105
|
+
"""Print styled setup instructions for running a new Flower App after creation."""
|
|
106
|
+
prompt = typer.style(
|
|
107
|
+
"🎊 Flower App creation successful.\n\n"
|
|
108
|
+
"To run your Flower App, first install its dependencies:\n\n",
|
|
109
|
+
fg=typer.colors.GREEN,
|
|
110
|
+
bold=True,
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
_add = " huggingface-cli login\n" if llm_challenge_str else ""
|
|
114
|
+
|
|
115
|
+
prompt += typer.style(
|
|
116
|
+
f" cd {package_name} && pip install -e .\n" + _add + "\n",
|
|
117
|
+
fg=typer.colors.BRIGHT_CYAN,
|
|
118
|
+
bold=True,
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
prompt += typer.style(
|
|
122
|
+
"then, run the app:\n\n ",
|
|
123
|
+
fg=typer.colors.GREEN,
|
|
124
|
+
bold=True,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
prompt += typer.style(
|
|
128
|
+
"\tflwr run .\n\n",
|
|
129
|
+
fg=typer.colors.BRIGHT_CYAN,
|
|
130
|
+
bold=True,
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
prompt += typer.style(
|
|
134
|
+
"💡 Check the README in your app directory to learn how to\n"
|
|
135
|
+
"customize it and how to run it using the Deployment Runtime.\n",
|
|
136
|
+
fg=typer.colors.GREEN,
|
|
137
|
+
bold=True,
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
print(prompt)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Security: prevent zip-slip
|
|
144
|
+
def _safe_extract_zip(zf: zipfile.ZipFile, dest_dir: Path) -> None:
|
|
145
|
+
"""Extract ZIP file into destination directory."""
|
|
146
|
+
dest_dir = dest_dir.resolve()
|
|
147
|
+
|
|
148
|
+
def _is_within_directory(base: Path, target: Path) -> bool:
|
|
149
|
+
try:
|
|
150
|
+
target.relative_to(base)
|
|
151
|
+
return True
|
|
152
|
+
except ValueError:
|
|
153
|
+
return False
|
|
154
|
+
|
|
155
|
+
for member in zf.infolist():
|
|
156
|
+
# Skip directory placeholders;
|
|
157
|
+
# ZipInfo can represent them as names ending with '/'.
|
|
158
|
+
if member.is_dir():
|
|
159
|
+
target_path = (dest_dir / member.filename).resolve()
|
|
160
|
+
if not _is_within_directory(dest_dir, target_path):
|
|
161
|
+
raise ValueError(f"Unsafe path in zip: {member.filename}")
|
|
162
|
+
target_path.mkdir(parents=True, exist_ok=True)
|
|
163
|
+
continue
|
|
164
|
+
|
|
165
|
+
# Files
|
|
166
|
+
target_path = (dest_dir / member.filename).resolve()
|
|
167
|
+
if not _is_within_directory(dest_dir, target_path):
|
|
168
|
+
raise ValueError(f"Unsafe path in zip: {member.filename}")
|
|
169
|
+
|
|
170
|
+
# Ensure parent exists
|
|
171
|
+
target_path.parent.mkdir(parents=True, exist_ok=True)
|
|
172
|
+
|
|
173
|
+
# Extract
|
|
174
|
+
with zf.open(member, "r") as src, open(target_path, "wb") as dst:
|
|
175
|
+
dst.write(src.read())
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _download_zip_to_memory(presigned_url: str) -> io.BytesIO:
|
|
179
|
+
"""Download ZIP file from Platform API to memory."""
|
|
180
|
+
try:
|
|
181
|
+
r = requests.get(presigned_url, timeout=60)
|
|
182
|
+
r.raise_for_status()
|
|
183
|
+
except requests.RequestException as e:
|
|
184
|
+
raise typer.BadParameter(f"ZIP download failed: {e}") from e
|
|
185
|
+
|
|
186
|
+
buf = io.BytesIO(r.content)
|
|
187
|
+
# Validate it's a zip
|
|
188
|
+
if not zipfile.is_zipfile(buf):
|
|
189
|
+
raise typer.BadParameter("Downloaded file is not a valid ZIP")
|
|
190
|
+
buf.seek(0)
|
|
191
|
+
return buf
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _request_download_link(identifier: str) -> str:
|
|
195
|
+
"""Request download link from Flower platform API."""
|
|
196
|
+
url = f"{PLATFORM_API_URL}/hub/fetch-zip"
|
|
197
|
+
headers = {
|
|
198
|
+
"Content-Type": "application/json",
|
|
199
|
+
"Accept": "application/json",
|
|
200
|
+
}
|
|
201
|
+
body = {
|
|
202
|
+
"identifier": identifier, # send raw string of identifier
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
try:
|
|
206
|
+
resp = requests.post(url, headers=headers, data=json.dumps(body), timeout=20)
|
|
207
|
+
except requests.RequestException as e:
|
|
208
|
+
raise typer.BadParameter(f"Unable to connect to Platform API: {e}") from e
|
|
209
|
+
|
|
210
|
+
if resp.status_code == 404:
|
|
211
|
+
raise typer.BadParameter(f"'{identifier}' not found in Platform API")
|
|
212
|
+
if not resp.ok:
|
|
213
|
+
raise typer.BadParameter(
|
|
214
|
+
f"Platform API request failed with "
|
|
215
|
+
f"status {resp.status_code}. Details: {resp.text}"
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
data = resp.json()
|
|
219
|
+
if "zip_url" not in data:
|
|
220
|
+
raise typer.BadParameter("Invalid response from Platform API")
|
|
221
|
+
return str(data["zip_url"])
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def download_remote_app_via_api(identifier: str) -> None:
|
|
225
|
+
"""Download App from Platform API."""
|
|
226
|
+
# Parse @user/app just to derive local dir name
|
|
227
|
+
m = re.match(APP_ID_PATTERN, identifier)
|
|
228
|
+
if not m:
|
|
229
|
+
raise typer.BadParameter(
|
|
230
|
+
"Invalid remote app ID. Expected format: '@user_name/app_name'."
|
|
231
|
+
)
|
|
232
|
+
app_name = m.group("app")
|
|
233
|
+
|
|
234
|
+
project_dir = Path.cwd() / app_name
|
|
235
|
+
if project_dir.exists():
|
|
236
|
+
if not typer.confirm(
|
|
237
|
+
typer.style(
|
|
238
|
+
f"\n💬 {app_name} already exists, do you want to override it?",
|
|
239
|
+
fg=typer.colors.MAGENTA,
|
|
240
|
+
bold=True,
|
|
241
|
+
)
|
|
242
|
+
):
|
|
243
|
+
return
|
|
244
|
+
|
|
245
|
+
print(
|
|
246
|
+
typer.style(
|
|
247
|
+
f"\n🔗 Requesting download link for {identifier}...",
|
|
248
|
+
fg=typer.colors.GREEN,
|
|
249
|
+
bold=True,
|
|
250
|
+
)
|
|
251
|
+
)
|
|
252
|
+
presigned_url = _request_download_link(identifier)
|
|
253
|
+
|
|
254
|
+
print(
|
|
255
|
+
typer.style(
|
|
256
|
+
"⬇️ Downloading ZIP into memory...",
|
|
257
|
+
fg=typer.colors.GREEN,
|
|
258
|
+
bold=True,
|
|
259
|
+
)
|
|
260
|
+
)
|
|
261
|
+
zip_buf = _download_zip_to_memory(presigned_url)
|
|
262
|
+
|
|
263
|
+
print(
|
|
264
|
+
typer.style(
|
|
265
|
+
f"📦 Unpacking into {project_dir}...",
|
|
266
|
+
fg=typer.colors.GREEN,
|
|
267
|
+
bold=True,
|
|
268
|
+
)
|
|
269
|
+
)
|
|
270
|
+
with zipfile.ZipFile(zip_buf) as zf:
|
|
271
|
+
_safe_extract_zip(zf, Path.cwd())
|
|
272
|
+
|
|
273
|
+
print_success_prompt(app_name)
|
|
274
|
+
|
|
275
|
+
|
|
95
276
|
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
|
|
96
277
|
def new(
|
|
97
278
|
app_name: Annotated[
|
|
@@ -110,6 +291,12 @@ def new(
|
|
|
110
291
|
"""Create new Flower App."""
|
|
111
292
|
if app_name is None:
|
|
112
293
|
app_name = prompt_text("Please provide the app name")
|
|
294
|
+
|
|
295
|
+
# Download remote app
|
|
296
|
+
if app_name and app_name.startswith("@"):
|
|
297
|
+
download_remote_app_via_api(app_name)
|
|
298
|
+
return
|
|
299
|
+
|
|
113
300
|
if not is_valid_project_name(app_name):
|
|
114
301
|
app_name = prompt_text(
|
|
115
302
|
"Please provide a name that only contains "
|
|
@@ -155,8 +342,8 @@ def new(
|
|
|
155
342
|
if framework_str == MlFramework.BASELINE:
|
|
156
343
|
framework_str = "baseline"
|
|
157
344
|
|
|
158
|
-
if framework_str == MlFramework.
|
|
159
|
-
framework_str = "
|
|
345
|
+
if framework_str == MlFramework.PYTORCH_LEGACY_API:
|
|
346
|
+
framework_str = "pytorch_legacy_api"
|
|
160
347
|
|
|
161
348
|
print(
|
|
162
349
|
typer.style(
|
|
@@ -201,7 +388,7 @@ def new(
|
|
|
201
388
|
}
|
|
202
389
|
|
|
203
390
|
# Challenge specific context
|
|
204
|
-
|
|
391
|
+
fraction_train = "0.2" if llm_challenge_str == "code" else "0.1"
|
|
205
392
|
if llm_challenge_str == "generalnlp":
|
|
206
393
|
challenge_name = "General NLP"
|
|
207
394
|
num_clients = "20"
|
|
@@ -220,7 +407,7 @@ def new(
|
|
|
220
407
|
dataset_name = "flwrlabs/code-alpaca-20k"
|
|
221
408
|
|
|
222
409
|
context["llm_challenge_str"] = llm_challenge_str
|
|
223
|
-
context["
|
|
410
|
+
context["fraction_train"] = fraction_train
|
|
224
411
|
context["challenge_name"] = challenge_name
|
|
225
412
|
context["num_clients"] = num_clients
|
|
226
413
|
context["dataset_name"] = dataset_name
|
|
@@ -247,14 +434,15 @@ def new(
|
|
|
247
434
|
MlFramework.TENSORFLOW.value,
|
|
248
435
|
MlFramework.SKLEARN.value,
|
|
249
436
|
MlFramework.NUMPY.value,
|
|
250
|
-
|
|
437
|
+
MlFramework.XGBOOST.value,
|
|
438
|
+
"pytorch_legacy_api",
|
|
251
439
|
]
|
|
252
440
|
if framework_str in frameworks_with_tasks:
|
|
253
441
|
files[f"{import_name}/task.py"] = {
|
|
254
442
|
"template": f"app/code/task.{template_name}.py.tpl"
|
|
255
443
|
}
|
|
256
444
|
|
|
257
|
-
if framework_str == "
|
|
445
|
+
if framework_str == "pytorch_legacy_api":
|
|
258
446
|
# Use custom __init__ that better captures name of framework
|
|
259
447
|
files[f"{import_name}/__init__.py"] = {
|
|
260
448
|
"template": f"app/code/__init__.{framework_str}.py.tpl"
|
|
@@ -280,38 +468,4 @@ def new(
|
|
|
280
468
|
context=context,
|
|
281
469
|
)
|
|
282
470
|
|
|
283
|
-
|
|
284
|
-
"🎊 Flower App creation successful.\n\n"
|
|
285
|
-
"To run your Flower App, first install its dependencies:\n\n",
|
|
286
|
-
fg=typer.colors.GREEN,
|
|
287
|
-
bold=True,
|
|
288
|
-
)
|
|
289
|
-
|
|
290
|
-
_add = " huggingface-cli login\n" if llm_challenge_str else ""
|
|
291
|
-
|
|
292
|
-
prompt += typer.style(
|
|
293
|
-
f" cd {package_name} && pip install -e .\n" + _add + "\n",
|
|
294
|
-
fg=typer.colors.BRIGHT_CYAN,
|
|
295
|
-
bold=True,
|
|
296
|
-
)
|
|
297
|
-
|
|
298
|
-
prompt += typer.style(
|
|
299
|
-
"then, run the app:\n\n ",
|
|
300
|
-
fg=typer.colors.GREEN,
|
|
301
|
-
bold=True,
|
|
302
|
-
)
|
|
303
|
-
|
|
304
|
-
prompt += typer.style(
|
|
305
|
-
"\tflwr run .\n\n",
|
|
306
|
-
fg=typer.colors.BRIGHT_CYAN,
|
|
307
|
-
bold=True,
|
|
308
|
-
)
|
|
309
|
-
|
|
310
|
-
prompt += typer.style(
|
|
311
|
-
"💡 Check the README in your app directory to learn how to\n"
|
|
312
|
-
"customize it and how to run it using the Deployment Runtime.\n",
|
|
313
|
-
fg=typer.colors.GREEN,
|
|
314
|
-
bold=True,
|
|
315
|
-
)
|
|
316
|
-
|
|
317
|
-
print(prompt)
|
|
471
|
+
print_success_prompt(package_name, llm_challenge_str)
|
|
@@ -26,7 +26,7 @@ pip install -e .
|
|
|
26
26
|
## Experimental setup
|
|
27
27
|
|
|
28
28
|
The dataset is divided into $num_clients partitions in an IID fashion, a partition is assigned to each ClientApp.
|
|
29
|
-
We randomly sample a fraction ($
|
|
29
|
+
We randomly sample a fraction ($fraction_train) of the total nodes to participate in each round, for a total of `200` rounds.
|
|
30
30
|
All settings are defined in `pyproject.toml`.
|
|
31
31
|
|
|
32
32
|
> [!IMPORTANT]
|
|
@@ -1,58 +1,75 @@
|
|
|
1
1
|
"""$project_name: A Flower Baseline."""
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
4
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
|
5
|
+
from flwr.clientapp import ClientApp
|
|
6
6
|
|
|
7
7
|
from $import_name.dataset import load_data
|
|
8
|
-
from $import_name.model import Net
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
self.net,
|
|
27
|
-
self.trainloader,
|
|
28
|
-
self.local_epochs,
|
|
29
|
-
self.device,
|
|
30
|
-
)
|
|
31
|
-
return (
|
|
32
|
-
get_weights(self.net),
|
|
33
|
-
len(self.trainloader.dataset),
|
|
34
|
-
{"train_loss": train_loss},
|
|
35
|
-
)
|
|
36
|
-
|
|
37
|
-
def evaluate(self, parameters, config):
|
|
38
|
-
"""Evaluate model using this client's data."""
|
|
39
|
-
set_weights(self.net, parameters)
|
|
40
|
-
loss, accuracy = test(self.net, self.valloader, self.device)
|
|
41
|
-
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def client_fn(context: Context):
|
|
45
|
-
"""Construct a Client that will be run in a ClientApp."""
|
|
46
|
-
# Load model and data
|
|
47
|
-
net = Net()
|
|
8
|
+
from $import_name.model import Net
|
|
9
|
+
from $import_name.model import test as test_fn
|
|
10
|
+
from $import_name.model import train as train_fn
|
|
11
|
+
|
|
12
|
+
# Flower ClientApp
|
|
13
|
+
app = ClientApp()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@app.train()
|
|
17
|
+
def train(msg: Message, context: Context):
|
|
18
|
+
"""Train the model on local data."""
|
|
19
|
+
|
|
20
|
+
# Load the model and initialize it with the received weights
|
|
21
|
+
model = Net()
|
|
22
|
+
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
|
23
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
24
|
+
|
|
25
|
+
# Load the data
|
|
48
26
|
partition_id = int(context.node_config["partition-id"])
|
|
49
27
|
num_partitions = int(context.node_config["num-partitions"])
|
|
50
|
-
trainloader,
|
|
28
|
+
trainloader, _ = load_data(partition_id, num_partitions)
|
|
51
29
|
local_epochs = context.run_config["local-epochs"]
|
|
52
30
|
|
|
53
|
-
#
|
|
54
|
-
|
|
31
|
+
# Call the training function
|
|
32
|
+
train_loss = train_fn(
|
|
33
|
+
model,
|
|
34
|
+
trainloader,
|
|
35
|
+
local_epochs,
|
|
36
|
+
device,
|
|
37
|
+
)
|
|
55
38
|
|
|
39
|
+
# Construct and return reply Message
|
|
40
|
+
model_record = ArrayRecord(model.state_dict())
|
|
41
|
+
metrics = {
|
|
42
|
+
"train_loss": train_loss,
|
|
43
|
+
"num-examples": len(trainloader.dataset),
|
|
44
|
+
}
|
|
45
|
+
metric_record = MetricRecord(metrics)
|
|
46
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
|
47
|
+
return Message(content=content, reply_to=msg)
|
|
56
48
|
|
|
57
|
-
|
|
58
|
-
app
|
|
49
|
+
|
|
50
|
+
@app.evaluate()
|
|
51
|
+
def evaluate(msg: Message, context: Context):
|
|
52
|
+
"""Evaluate the model on local data."""
|
|
53
|
+
|
|
54
|
+
# Load the model and initialize it with the received weights
|
|
55
|
+
model = Net()
|
|
56
|
+
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
|
57
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
58
|
+
|
|
59
|
+
# Load the data
|
|
60
|
+
partition_id = int(context.node_config["partition-id"])
|
|
61
|
+
num_partitions = int(context.node_config["num-partitions"])
|
|
62
|
+
_, valloader = load_data(partition_id, num_partitions)
|
|
63
|
+
|
|
64
|
+
# Call the evaluation function
|
|
65
|
+
eval_loss, eval_acc = test_fn(model, valloader, device)
|
|
66
|
+
|
|
67
|
+
# Construct and return reply Message
|
|
68
|
+
metrics = {
|
|
69
|
+
"eval_loss": eval_loss,
|
|
70
|
+
"eval_acc": eval_acc,
|
|
71
|
+
"num-examples": len(valloader.dataset),
|
|
72
|
+
}
|
|
73
|
+
metric_record = MetricRecord(metrics)
|
|
74
|
+
content = RecordDict({"metrics": metric_record})
|
|
75
|
+
return Message(content=content, reply_to=msg)
|