flwr 1.24.0__py3-none-any.whl → 1.26.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- flwr/__init__.py +1 -1
- flwr/app/__init__.py +4 -1
- flwr/app/message_type.py +29 -0
- flwr/app/metadata.py +5 -2
- flwr/app/user_config.py +19 -0
- flwr/cli/app.py +37 -19
- flwr/cli/app_cmd/publish.py +25 -75
- flwr/cli/app_cmd/review.py +25 -66
- flwr/cli/auth_plugin/auth_plugin.py +5 -10
- flwr/cli/auth_plugin/noop_auth_plugin.py +1 -2
- flwr/cli/auth_plugin/oidc_cli_plugin.py +38 -38
- flwr/cli/build.py +15 -28
- flwr/cli/config/__init__.py +21 -0
- flwr/cli/config/ls.py +71 -0
- flwr/cli/config_migration.py +297 -0
- flwr/cli/config_utils.py +63 -156
- flwr/cli/constant.py +71 -0
- flwr/cli/federation/__init__.py +0 -2
- flwr/cli/federation/ls.py +256 -64
- flwr/cli/flower_config.py +429 -0
- flwr/cli/install.py +23 -62
- flwr/cli/log.py +23 -37
- flwr/cli/login/login.py +29 -63
- flwr/cli/ls.py +72 -61
- flwr/cli/new/new.py +98 -309
- flwr/cli/pull.py +19 -37
- flwr/cli/run/run.py +87 -100
- flwr/cli/run_utils.py +23 -5
- flwr/cli/stop.py +33 -74
- flwr/cli/supernode/ls.py +35 -62
- flwr/cli/supernode/register.py +31 -80
- flwr/cli/supernode/unregister.py +24 -70
- flwr/cli/typing.py +200 -0
- flwr/cli/utils.py +160 -412
- flwr/client/grpc_adapter_client/connection.py +2 -2
- flwr/client/grpc_rere_client/connection.py +9 -6
- flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
- flwr/client/message_handler/message_handler.py +2 -1
- flwr/client/mod/centraldp_mods.py +1 -1
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/rest_client/connection.py +6 -4
- flwr/client/run_info_store.py +2 -1
- flwr/clientapp/client_app.py +2 -1
- flwr/common/__init__.py +3 -2
- flwr/common/args.py +5 -5
- flwr/common/config.py +12 -17
- flwr/common/constant.py +3 -16
- flwr/common/context.py +2 -1
- flwr/common/exit/exit.py +4 -4
- flwr/common/exit/exit_code.py +6 -0
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +1 -1
- flwr/common/message.py +1 -1
- flwr/common/retry_invoker.py +13 -5
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -2
- flwr/common/serde.py +13 -5
- flwr/common/telemetry.py +1 -1
- flwr/common/typing.py +10 -3
- flwr/compat/client/app.py +6 -9
- flwr/compat/client/grpc_client/connection.py +2 -1
- flwr/compat/common/constant.py +29 -0
- flwr/compat/server/app.py +1 -1
- flwr/proto/clientappio_pb2.py +2 -2
- flwr/proto/clientappio_pb2_grpc.py +104 -88
- flwr/proto/clientappio_pb2_grpc.pyi +140 -80
- flwr/proto/federation_pb2.py +5 -3
- flwr/proto/federation_pb2.pyi +32 -2
- flwr/proto/fleet_pb2.py +10 -10
- flwr/proto/fleet_pb2.pyi +5 -1
- flwr/proto/run_pb2.py +18 -26
- flwr/proto/run_pb2.pyi +10 -58
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +138 -207
- flwr/proto/serverappio_pb2_grpc.pyi +189 -155
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +62 -90
- flwr/proto/simulationio_pb2_grpc.pyi +95 -55
- flwr/server/app.py +7 -13
- flwr/server/compat/grid_client_proxy.py +2 -1
- flwr/server/grid/grpc_grid.py +5 -5
- flwr/server/serverapp/app.py +11 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +13 -12
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/linkstate/__init__.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +36 -10
- flwr/server/superlink/linkstate/linkstate.py +34 -21
- flwr/server/superlink/linkstate/linkstate_factory.py +16 -8
- flwr/server/superlink/linkstate/{sqlite_linkstate.py → sql_linkstate.py} +471 -516
- flwr/server/superlink/linkstate/utils.py +49 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +1 -33
- flwr/server/superlink/simulation/simulationio_servicer.py +0 -19
- flwr/server/utils/validator.py +1 -1
- flwr/server/workflow/default_workflows.py +2 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
- flwr/serverapp/strategy/bulyan.py +7 -1
- flwr/serverapp/strategy/dp_fixed_clipping.py +9 -1
- flwr/serverapp/strategy/fedavg.py +1 -1
- flwr/serverapp/strategy/fedxgb_cyclic.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -6
- flwr/simulation/run_simulation.py +3 -12
- flwr/simulation/simulationio_connection.py +3 -3
- flwr/{common → supercore}/address.py +7 -33
- flwr/supercore/app_utils.py +2 -1
- flwr/supercore/constant.py +27 -2
- flwr/supercore/corestate/{sqlite_corestate.py → sql_corestate.py} +19 -23
- flwr/supercore/credential_store/__init__.py +33 -0
- flwr/supercore/credential_store/credential_store.py +34 -0
- flwr/supercore/credential_store/file_credential_store.py +76 -0
- flwr/{common → supercore}/date.py +0 -11
- flwr/supercore/ffs/disk_ffs.py +1 -1
- flwr/supercore/object_store/object_store_factory.py +14 -6
- flwr/supercore/object_store/{sqlite_object_store.py → sql_object_store.py} +115 -117
- flwr/supercore/sql_mixin.py +315 -0
- flwr/{cli/new/templates → supercore/state}/__init__.py +2 -2
- flwr/{cli/new/templates/app/code/flwr_tune → supercore/state/alembic}/__init__.py +2 -2
- flwr/supercore/state/alembic/env.py +103 -0
- flwr/supercore/state/alembic/script.py.mako +43 -0
- flwr/supercore/state/alembic/utils.py +239 -0
- flwr/{cli/new/templates/app → supercore/state/alembic/versions}/__init__.py +2 -2
- flwr/supercore/state/alembic/versions/rev_2026_01_28_initialize_migration_of_state_tables.py +200 -0
- flwr/supercore/state/schema/README.md +121 -0
- flwr/{cli/new/templates/app/code → supercore/state/schema}/__init__.py +2 -2
- flwr/supercore/state/schema/corestate_tables.py +36 -0
- flwr/supercore/state/schema/linkstate_tables.py +152 -0
- flwr/supercore/state/schema/objectstore_tables.py +90 -0
- flwr/supercore/superexec/run_superexec.py +2 -2
- flwr/supercore/utils.py +225 -0
- flwr/superlink/federation/federation_manager.py +2 -2
- flwr/superlink/federation/noop_federation_manager.py +8 -6
- flwr/superlink/servicer/control/control_grpc.py +2 -0
- flwr/superlink/servicer/control/control_servicer.py +106 -21
- flwr/supernode/cli/flower_supernode.py +2 -1
- flwr/supernode/nodestate/in_memory_nodestate.py +62 -1
- flwr/supernode/nodestate/nodestate.py +45 -0
- flwr/supernode/runtime/run_clientapp.py +14 -14
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +13 -5
- flwr/supernode/start_client_internal.py +17 -10
- {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/METADATA +8 -8
- {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/RECORD +144 -184
- flwr/cli/federation/show.py +0 -317
- flwr/cli/new/templates/app/.gitignore.tpl +0 -163
- flwr/cli/new/templates/app/LICENSE.tpl +0 -202
- flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
- flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
- flwr/cli/new/templates/app/README.md.tpl +0 -37
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
- flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
- flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -99
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
- flwr/common/pyproject.py +0 -42
- flwr/supercore/sqlite_mixin.py +0 -159
- /flwr/{common → supercore}/version.py +0 -0
- {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/WHEEL +0 -0
- {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/entry_points.txt +0 -0
@@ -1,56 +0,0 @@
-"""$project_name: A Flower / FlowerTune app."""
-
-import math
-
-import torch
-from omegaconf import DictConfig
-from peft import LoraConfig, get_peft_model
-from peft.utils import prepare_model_for_kbit_training
-from transformers import AutoModelForCausalLM, BitsAndBytesConfig
-
-
-def cosine_annealing(
-    current_round: int,
-    total_round: int,
-    lrate_max: float = 0.001,
-    lrate_min: float = 0.0,
-) -> float:
-    """Implement cosine annealing learning rate schedule."""
-    cos_inner = math.pi * current_round / total_round
-    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
-
-
-def get_model(model_cfg: DictConfig):
-    """Load model with appropriate quantization config and other optimizations.
-    """
-    if model_cfg.quantization == 4:
-        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
-    elif model_cfg.quantization == 8:
-        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-    else:
-        raise ValueError(
-            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
-        )
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_cfg.name,
-        quantization_config=quantization_config,
-        torch_dtype=torch.bfloat16,
-        low_cpu_mem_usage=True,
-    )
-
-    model = prepare_model_for_kbit_training(
-        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
-    )
-
-    peft_config = LoraConfig(
-        r=model_cfg.lora.peft_lora_r,
-        lora_alpha=model_cfg.lora.peft_lora_alpha,
-        lora_dropout=0.075,
-        task_type="CAUSAL_LM",
-    )
-
-    if model_cfg.gradient_checkpointing:
-        model.config.use_cache = False
-
-    return get_peft_model(model, peft_config)
@@ -1,73 +0,0 @@
-"""$project_name: A Flower / FlowerTune app."""
-
-import os
-from datetime import datetime
-
-from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
-from flwr.common.config import unflatten_dict
-from flwr.serverapp import Grid, ServerApp
-from omegaconf import DictConfig
-from peft import get_peft_model_state_dict, set_peft_model_state_dict
-
-from $import_name.dataset import replace_keys
-from $import_name.models import get_model
-from $import_name.strategy import FlowerTuneLlm
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-    # Create output directory given current timestamp
-    current_time = datetime.now()
-    folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
-    save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
-    os.makedirs(save_path, exist_ok=True)
-
-    # Read from config
-    num_rounds = context.run_config["num-server-rounds"]
-    cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
-
-    # Get initial model weights
-    init_model = get_model(cfg.model)
-    arrays = ArrayRecord(get_peft_model_state_dict(init_model))
-
-    # Define strategy
-    strategy = FlowerTuneLlm(
-        fraction_train=cfg.strategy.fraction_train,
-        fraction_evaluate=cfg.strategy.fraction_evaluate,
-    )
-
-    # Start strategy, run FedAvg for `num_rounds`
-    strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        train_config=ConfigRecord({"save_path": save_path}),
-        num_rounds=num_rounds,
-        evaluate_fn=get_evaluate_fn(
-            cfg.model, cfg.train.save_every_round, num_rounds, save_path
-        ),
-    )
-
-
-# Get function that will be executed by the strategy
-# Here we use it to save global model checkpoints
-def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
-    """Return an evaluation function for saving global model."""
-
-    def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
-        # Save model
-        if server_round != 0 and (
-            server_round == total_round or server_round % save_every_round == 0
-        ):
-            # Init model
-            model = get_model(model_cfg)
-            set_peft_model_state_dict(model, arrays.to_torch_state_dict())
-
-            model.save_pretrained(f"{save_path}/peft_{server_round}")
-
-        return MetricRecord()
-
-    return evaluate
@@ -1,78 +0,0 @@
-"""$project_name: A Flower / FlowerTune app."""
-
-from collections.abc import Iterable
-from logging import INFO, WARN
-from typing import Optional
-
-from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
-from flwr.common import log
-from flwr.serverapp import Grid
-from flwr.serverapp.strategy import FedAvg
-
-
-class FlowerTuneLlm(FedAvg):
-    """Customised FedAvg strategy implementation.
-
-    This class behaves just like FedAvg but also tracks the communication
-    costs associated with `train` over FL rounds.
-    """
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.comm_tracker = CommunicationTracker()
-
-    def configure_train(
-        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
-    ) -> Iterable[Message]:
-        """Configure the next round of training."""
-        messages = super().configure_train(server_round, arrays, config, grid)
-
-        # Track communication costs
-        self.comm_tracker.track(messages)
-
-        return messages
-
-    def aggregate_train(
-        self,
-        server_round: int,
-        replies: Iterable[Message],
-    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
-        """Aggregate ArrayRecords and MetricRecords in the received Messages."""
-        # Track communication costs
-        self.comm_tracker.track(replies)
-
-        arrays, metrics = super().aggregate_train(server_round, replies)
-
-        return arrays, metrics
-
-
-class CommunicationTracker:
-    """Communication costs tracker over FL rounds."""
-    def __init__(self):
-        self.curr_comm_cost = 0.0
-
-    def track(self, messages: Iterable[Message]):
-        comm_cost = (
-            sum(
-                record.count_bytes()
-                for msg in messages
-                if msg.has_content()
-                for record in msg.content.array_records.values()
-            )
-            / 1024**2
-        )
-
-        self.curr_comm_cost += comm_cost
-        log(
-            INFO,
-            "Communication budget: used %.2f MB (+%.2f MB this round) / 200,000 MB",
-            self.curr_comm_cost,
-            comm_cost,
-        )
-
-        if self.curr_comm_cost > 2e5:
-            log(
-                WARN,
-                "The accumulated communication cost has exceeded 200,000 MB. "
-                "Please consider reducing it if you plan to participate "
-                "FlowerTune LLM Leaderboard.",
-            )
@@ -1,66 +0,0 @@
-"""$project_name: A Flower Baseline."""
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-class Net(nn.Module):
-    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')."""
-
-    def __init__(self):
-        super().__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
-
-    def forward(self, x):
-        """Do forward."""
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = x.view(-1, 16 * 5 * 5)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        return self.fc3(x)
-
-
-def train(net, trainloader, epochs, device):
-    """Train the model on the training set."""
-    net.to(device)  # move model to GPU if available
-    criterion = torch.nn.CrossEntropyLoss()
-    criterion.to(device)
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
-    net.train()
-    running_loss = 0.0
-    for _ in range(epochs):
-        for batch in trainloader:
-            images = batch["img"]
-            labels = batch["label"]
-            optimizer.zero_grad()
-            loss = criterion(net(images.to(device)), labels.to(device))
-            loss.backward()
-            optimizer.step()
-            running_loss += loss.item()
-
-    avg_trainloss = running_loss / len(trainloader)
-    return avg_trainloss
-
-
-def test(net, testloader, device):
-    """Validate the model on the test set."""
-    net.to(device)
-    criterion = torch.nn.CrossEntropyLoss()
-    correct, loss = 0, 0.0
-    with torch.no_grad():
-        for batch in testloader:
-            images = batch["img"].to(device)
-            labels = batch["label"].to(device)
-            outputs = net(images)
-            loss += criterion(outputs, labels).item()
-            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
-    accuracy = correct / len(testloader.dataset)
-    loss = loss / len(testloader)
-    return loss, accuracy
@@ -1,43 +0,0 @@
-"""$project_name: A Flower Baseline."""
-
-import torch
-from flwr.app import ArrayRecord, Context
-from flwr.serverapp import Grid, ServerApp
-from flwr.serverapp.strategy import FedAvg
-
-from $import_name.model import Net
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-
-    # Read from config
-    num_rounds = context.run_config["num-server-rounds"]
-    fraction_train = context.run_config["fraction-train"]
-
-    # Load global model
-    global_model = Net()
-    arrays = ArrayRecord(global_model.state_dict())
-
-    # Initialize FedAvg strategy
-    strategy = FedAvg(
-        fraction_train=fraction_train,
-        fraction_evaluate=1.0,
-        min_available_nodes=2,
-    )
-
-    # Start strategy, run FedAvg for `num_rounds`
-    result = strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        num_rounds=num_rounds,
-    )
-
-    # Save final model to disk
-    print("\nSaving final model to disk...")
-    state_dict = result.arrays.to_torch_state_dict()
-    torch.save(state_dict, "final_model.pt")
@@ -1,42 +0,0 @@
-"""$project_name: A Flower / $framework_str app."""
-
-import torch
-from flwr.app import ArrayRecord, Context
-from flwr.serverapp import Grid, ServerApp
-from flwr.serverapp.strategy import FedAvg
-from transformers import AutoModelForSequenceClassification
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-
-    # Read from config
-    num_rounds = context.run_config["num-server-rounds"]
-    fraction_train = context.run_config["fraction-train"]
-
-    # Initialize global model
-    model_name = context.run_config["model-name"]
-    num_labels = context.run_config["num-labels"]
-    net = AutoModelForSequenceClassification.from_pretrained(
-        model_name, num_labels=num_labels
-    )
-    arrays = ArrayRecord(net.state_dict())
-
-    # Initialize FedAvg strategy
-    strategy = FedAvg(fraction_train=fraction_train)
-
-    # Start strategy, run FedAvg for `num_rounds`
-    result = strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        num_rounds=num_rounds,
-    )
-
-    # Save final model to disk
-    print("\nSaving final model to disk...")
-    state_dict = result.arrays.to_torch_state_dict()
-    torch.save(state_dict, "final_model.pt")
@@ -1,39 +0,0 @@
-"""$project_name: A Flower / $framework_str app."""
-
-import numpy as np
-from flwr.app import ArrayRecord, Context
-from flwr.serverapp import Grid, ServerApp
-from flwr.serverapp.strategy import FedAvg
-
-from $import_name.task import get_params, load_model
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-
-    # Read from config
-    num_rounds = context.run_config["num-server-rounds"]
-    input_dim = context.run_config["input-dim"]
-
-    # Load global model
-    model = load_model((input_dim,))
-    arrays = ArrayRecord(get_params(model))
-
-    # Initialize FedAvg strategy
-    strategy = FedAvg()
-
-    # Start strategy, run FedAvg for `num_rounds`
-    result = strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        num_rounds=num_rounds,
-    )
-
-    # Save final model to disk
-    print("\nSaving final model to disk...")
-    ndarrays = result.arrays.to_numpy_ndarrays()
-    np.savez("final_model.npz", *ndarrays)
@@ -1,41 +0,0 @@
-"""$project_name: A Flower / $framework_str app."""
-
-from flwr.app import ArrayRecord, Context
-from flwr.serverapp import Grid, ServerApp
-from flwr.serverapp.strategy import FedAvg
-
-from $import_name.task import MLP, get_params, set_params
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-    # Read from config
-    num_rounds = context.run_config["num-server-rounds"]
-    num_layers = context.run_config["num-layers"]
-    input_dim = context.run_config["input-dim"]
-    hidden_dim = context.run_config["hidden-dim"]
-
-    # Initialize global model
-    model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
-    params = get_params(model)
-    arrays = ArrayRecord(params)
-
-    # Initialize FedAvg strategy
-    strategy = FedAvg()
-
-    # Start strategy, run FedAvg for `num_rounds`
-    result = strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        num_rounds=num_rounds,
-    )
-
-    # Save final model to disk
-    print("\nSaving final model to disk...")
-    ndarrays = result.arrays.to_numpy_ndarrays()
-    set_params(model, ndarrays)
-    model.save_weights("final_model.npz")
@@ -1,38 +0,0 @@
-"""$project_name: A Flower / $framework_str app."""
-
-import numpy as np
-from flwr.app import ArrayRecord, Context
-from flwr.serverapp import Grid, ServerApp
-from flwr.serverapp.strategy import FedAvg
-
-from $import_name.task import get_dummy_model
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-
-    # Read run config
-    num_rounds: int = context.run_config["num-server-rounds"]
-
-    # Load global model
-    model = get_dummy_model()
-    arrays = ArrayRecord(model)
-
-    # Initialize FedAvg strategy
-    strategy = FedAvg()
-
-    # Start strategy, run FedAvg for `num_rounds`
-    result = strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        num_rounds=num_rounds,
-    )
-
-    # Save final model to disk
-    print("\nSaving final model to disk...")
-    ndarrays = result.arrays.to_numpy_ndarrays()
-    np.savez("final_model", *ndarrays)
@@ -1,41 +0,0 @@
-"""$project_name: A Flower / $framework_str app."""
-
-import torch
-from flwr.app import ArrayRecord, ConfigRecord, Context
-from flwr.serverapp import Grid, ServerApp
-from flwr.serverapp.strategy import FedAvg
-
-from $import_name.task import Net
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-
-    # Read run config
-    fraction_train: float = context.run_config["fraction-train"]
-    num_rounds: int = context.run_config["num-server-rounds"]
-    lr: float = context.run_config["lr"]
-
-    # Load global model
-    global_model = Net()
-    arrays = ArrayRecord(global_model.state_dict())
-
-    # Initialize FedAvg strategy
-    strategy = FedAvg(fraction_train=fraction_train)
-
-    # Start strategy, run FedAvg for `num_rounds`
-    result = strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        train_config=ConfigRecord({"lr": lr}),
-        num_rounds=num_rounds,
-    )
-
-    # Save final model to disk
-    print("\nSaving final model to disk...")
-    state_dict = result.arrays.to_torch_state_dict()
-    torch.save(state_dict, "final_model.pt")
@@ -1,31 +0,0 @@
-"""$project_name: A Flower / $framework_str app."""
-
-from flwr.common import Context, ndarrays_to_parameters
-from flwr.server import ServerApp, ServerAppComponents, ServerConfig
-from flwr.server.strategy import FedAvg
-from $import_name.task import Net, get_weights
-
-
-def server_fn(context: Context):
-    # Read from config
-    num_rounds = context.run_config["num-server-rounds"]
-    fraction_fit = context.run_config["fraction-fit"]
-
-    # Initialize model parameters
-    ndarrays = get_weights(Net())
-    parameters = ndarrays_to_parameters(ndarrays)
-
-    # Define strategy
-    strategy = FedAvg(
-        fraction_fit=fraction_fit,
-        fraction_evaluate=1.0,
-        min_available_clients=2,
-        initial_parameters=parameters,
-    )
-    config = ServerConfig(num_rounds=num_rounds)
-
-    return ServerAppComponents(strategy=strategy, config=config)
-
-
-# Create ServerApp
-app = ServerApp(server_fn=server_fn)
@@ -1,44 +0,0 @@
-"""$project_name: A Flower / $framework_str app."""
-
-import joblib
-from flwr.app import ArrayRecord, Context
-from flwr.serverapp import Grid, ServerApp
-from flwr.serverapp.strategy import FedAvg
-
-from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-
-    # Read run config
-    num_rounds: int = context.run_config["num-server-rounds"]
-
-    # Create LogisticRegression Model
-    penalty = context.run_config["penalty"]
-    local_epochs = context.run_config["local-epochs"]
-    model = get_model(penalty, local_epochs)
-    # Setting initial parameters, akin to model.compile for keras models
-    set_initial_params(model)
-    # Construct ArrayRecord representation
-    arrays = ArrayRecord(get_model_params(model))
-
-    # Initialize FedAvg strategy
-    strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
-
-    # Start strategy, run FedAvg for `num_rounds`
-    result = strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        num_rounds=num_rounds,
-    )
-
-    # Save final model parameters
-    print("\nSaving final model to disk...")
-    ndarrays = result.arrays.to_numpy_ndarrays()
-    set_model_params(model, ndarrays)
-    joblib.dump(model, "logreg_model.pkl")
@@ -1,38 +0,0 @@
-"""$project_name: A Flower / $framework_str app."""
-
-from flwr.app import ArrayRecord, Context
-from flwr.serverapp import Grid, ServerApp
-from flwr.serverapp.strategy import FedAvg
-
-from $import_name.task import load_model
-
-# Create ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(grid: Grid, context: Context) -> None:
-    """Main entry point for the ServerApp."""
-
-    # Read run config
-    num_rounds: int = context.run_config["num-server-rounds"]
-
-    # Load global model
-    model = load_model()
-    arrays = ArrayRecord(model.get_weights())
-
-    # Initialize FedAvg strategy
-    strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
-
-    # Start strategy, run FedAvg for `num_rounds`
-    result = strategy.start(
-        grid=grid,
-        initial_arrays=arrays,
-        num_rounds=num_rounds,
-    )
-
-    # Save final model to disk
-    print("\nSaving final model to disk...")
-    ndarrays = result.arrays.to_numpy_ndarrays()
-    model.set_weights(ndarrays)
-    model.save("final_model.keras")