flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from omegaconf import DictConfig
|
|
7
|
+
from peft import LoraConfig, get_peft_model
|
|
8
|
+
from peft.utils import prepare_model_for_kbit_training
|
|
9
|
+
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def cosine_annealing(
    current_round: int,
    total_round: int,
    lrate_max: float = 0.001,
    lrate_min: float = 0.0,
) -> float:
    """Return the cosine-annealed learning rate for ``current_round``.

    The rate equals ``lrate_max`` at round 0 and ``lrate_min`` at round
    ``total_round``, following half a cosine period in between.
    A ``total_round`` of 0 raises ``ZeroDivisionError``.
    """
    # Position on the half cosine wave, in radians
    phase = math.pi * current_round / total_round
    half_span = 0.5 * (lrate_max - lrate_min)
    return lrate_min + half_span * (1 + math.cos(phase))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_model(model_cfg: DictConfig):
    """Load a quantized causal LM and wrap it with LoRA adapters.

    The model named ``model_cfg.name`` is loaded with a ``BitsAndBytesConfig``
    chosen by ``model_cfg.quantization`` (4 or 8), passed through
    ``prepare_model_for_kbit_training`` and finally wrapped by
    ``get_peft_model`` with a ``LoraConfig`` built from ``model_cfg.lora``.

    Please refer to this example for `peft + BitsAndBytes`:
    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py

    Raises:
        ValueError: If ``model_cfg.quantization`` is neither 4 nor 8.
    """
    if model_cfg.quantization == 4:
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    elif model_cfg.quantization == 8:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        raise ValueError(
            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}."
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_cfg.name,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )

    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
    )

    peft_config = LoraConfig(
        r=model_cfg.lora.peft_lora_r,
        lora_alpha=model_cfg.lora.peft_lora_alpha,
        lora_dropout=0.075,
        task_type="CAUSAL_LM",
    )

    # The cache is switched off whenever gradient checkpointing is enabled
    if model_cfg.gradient_checkpointing:
        model.config.use_cache = False

    return get_peft_model(model, peft_config)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
+
|
|
3
|
+
from $import_name.client_app import set_parameters
|
|
4
|
+
from $import_name.models import get_model
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Get function that will be executed by the strategy's evaluate() method
|
|
8
|
+
# Here we use it to save global model checkpoints
|
|
9
|
+
def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
    """Build the evaluate callback that checkpoints the global model.

    The returned function saves a checkpoint to
    ``{save_path}/peft_{server_round}`` on the final round and on every
    ``save_every_round``-th round (never on round 0). It always returns
    ``(0.0, {})``.
    """

    def evaluate(server_round: int, parameters, config):
        def _should_save() -> bool:
            # Checks run in this order so the modulo is only evaluated
            # when the earlier conditions did not decide the outcome.
            if server_round == 0:
                return False
            if server_round == total_round:
                return True
            return server_round % save_every_round == 0

        if _should_save():
            # Rebuild the model, load the global parameters, then persist it
            checkpoint = get_model(model_cfg)
            set_parameters(checkpoint, parameters)
            checkpoint.save_pretrained(f"{save_path}/peft_{server_round}")

        return 0.0, {}

    return evaluate
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_on_fit_config():
    """Return the callback that builds the config passed to client ``fit()``."""

    def fit_config_fn(server_round: int):
        # The round number is the only value placed in the config
        return {"current_round": server_round}

    return fit_config_fn
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def fit_weighted_average(metrics):
    """Aggregate client ``train_loss`` values, weighted by example count.

    ``metrics`` is iterated as ``(num_examples, metrics_dict)`` pairs. The
    result is ``{"train_loss": weighted_mean}``.
    """
    # Weight each client's training loss by the number of examples it used
    weighted_loss = sum(count * client["train_loss"] for count, client in metrics)
    total_examples = sum(count for count, _ in metrics)

    return {"train_loss": weighted_loss / total_examples}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Federated Instruction Tuning (static)
|
|
2
|
+
---
|
|
3
|
+
dataset:
|
|
4
|
+
name: $dataset_name
|
|
5
|
+
|
|
6
|
+
# FL experimental settings
|
|
7
|
+
num_clients: $num_clients # total number of clients
|
|
8
|
+
num_rounds: 200
|
|
9
|
+
partitioner:
|
|
10
|
+
_target_: flwr_datasets.partitioner.IidPartitioner
|
|
11
|
+
num_partitions: $num_clients
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server.strategy import FedAvg
|
|
5
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def server_fn(context: Context):
    """Build the components of the ServerApp for one run.

    Reads ``num-server-rounds`` from ``context.run_config`` and returns
    ``ServerAppComponents`` holding a ``FedAvg`` strategy (fit and evaluate
    fractions of 1.0) and the matching ``ServerConfig``.
    """
    rounds = context.run_config["num-server-rounds"]

    fedavg = FedAvg(fraction_fit=1.0, fraction_evaluate=1.0)
    server_config = ServerConfig(num_rounds=rounds)
    return ServerAppComponents(strategy=fedavg, config=server_config)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Create ServerApp
|
|
23
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server.strategy import FedAvg
|
|
5
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def server_fn(context: Context):
    """Build the components of the ServerApp for one run.

    Reads ``num-server-rounds`` from ``context.run_config`` and returns
    ``ServerAppComponents`` holding a default ``FedAvg`` strategy and the
    matching ``ServerConfig``.
    """
    rounds = context.run_config["num-server-rounds"]

    fedavg = FedAvg()
    server_config = ServerConfig(num_rounds=rounds)
    return ServerAppComponents(strategy=fedavg, config=server_config)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Create ServerApp
|
|
20
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
from flwr.server.strategy import FedAvg
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def server_fn(context: Context):
    """Assemble the strategy and server config used by the ServerApp.

    The number of rounds is taken from the ``num-server-rounds`` entry of
    ``context.run_config``; the strategy is a default ``FedAvg``.
    """
    total_rounds = context.run_config["num-server-rounds"]

    default_strategy = FedAvg()
    run_settings = ServerConfig(num_rounds=total_rounds)
    return ServerAppComponents(strategy=default_strategy, config=run_settings)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Create ServerApp
|
|
20
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,12 +1,20 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
from flwr.server.strategy import FedAvg
|
|
4
6
|
|
|
5
|
-
# Configure the strategy
|
|
6
|
-
strategy = fl.server.strategy.FedAvg()
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
8
|
+
def server_fn(context: Context):
    """Create the ``ServerAppComponents`` for a run.

    ``num-server-rounds`` is read from ``context.run_config`` and passed to
    ``ServerConfig``; the strategy is ``FedAvg`` with its default arguments.
    """
    n_rounds = context.run_config["num-server-rounds"]

    averaging = FedAvg()
    settings = ServerConfig(num_rounds=n_rounds)
    return ServerAppComponents(strategy=averaging, config=settings)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Create ServerApp
|
|
20
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,28 +1,31 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.common import ndarrays_to_parameters
|
|
4
|
-
from flwr.server import ServerApp, ServerConfig
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
5
|
from flwr.server.strategy import FedAvg
|
|
6
6
|
|
|
7
|
-
from $
|
|
7
|
+
from $import_name.task import Net, get_weights
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
10
|
+
def server_fn(context: Context):
|
|
11
|
+
# Read from config
|
|
12
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
13
|
+
fraction_fit = context.run_config["fraction-fit"]
|
|
13
14
|
|
|
15
|
+
# Initialize model parameters
|
|
16
|
+
ndarrays = get_weights(Net())
|
|
17
|
+
parameters = ndarrays_to_parameters(ndarrays)
|
|
14
18
|
|
|
15
|
-
# Define strategy
|
|
16
|
-
strategy = FedAvg(
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
)
|
|
19
|
+
# Define strategy
|
|
20
|
+
strategy = FedAvg(
|
|
21
|
+
fraction_fit=fraction_fit,
|
|
22
|
+
fraction_evaluate=1.0,
|
|
23
|
+
min_available_clients=2,
|
|
24
|
+
initial_parameters=parameters,
|
|
25
|
+
)
|
|
26
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
22
27
|
|
|
28
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
23
29
|
|
|
24
30
|
# Create ServerApp
|
|
25
|
-
app = ServerApp(
|
|
26
|
-
config=ServerConfig(num_rounds=3),
|
|
27
|
-
strategy=strategy,
|
|
28
|
-
)
|
|
31
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
from flwr.server.strategy import FedAvg
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def server_fn(context: Context):
    """Build the components of the ServerApp for one run.

    Reads ``num-server-rounds`` from ``context.run_config`` and returns
    ``ServerAppComponents`` with a ``FedAvg`` strategy (fit and evaluate
    fractions of 1.0, ``min_available_clients=2``) and a ``ServerConfig``.
    """
    rounds = context.run_config["num-server-rounds"]

    strategy_kwargs = {
        "fraction_fit": 1.0,
        "fraction_evaluate": 1.0,
        "min_available_clients": 2,
    }
    fedavg = FedAvg(**strategy_kwargs)
    server_config = ServerConfig(num_rounds=rounds)
    return ServerAppComponents(strategy=fedavg, config=server_config)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Create ServerApp
|
|
24
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1 +1,29 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
from flwr.server.strategy import FedAvg
|
|
6
|
+
|
|
7
|
+
from $import_name.task import load_model
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def server_fn(context: Context):
    """Build the components of the ServerApp for one run.

    Reads ``num-server-rounds`` from ``context.run_config``, uses the weights
    of ``load_model()`` as the initial global parameters and returns
    ``ServerAppComponents`` with a ``FedAvg`` strategy and a ``ServerConfig``.
    """
    # Read from config
    num_rounds = context.run_config["num-server-rounds"]

    # Get parameters to initialize global model
    parameters = ndarrays_to_parameters(load_model().get_weights())

    # Define strategy
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=parameters,
    )
    config = ServerConfig(num_rounds=num_rounds)

    return ServerAppComponents(strategy=strategy, config=config)
|
|
27
|
+
|
|
28
|
+
# Create ServerApp
|
|
29
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from collections import OrderedDict
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from evaluate import load as load_metric
|
|
8
|
+
from torch.optim import AdamW
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
from transformers import AutoTokenizer, DataCollatorWithPadding
|
|
11
|
+
|
|
12
|
+
from flwr_datasets import FederatedDataset
|
|
13
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
|
17
|
+
DEVICE = torch.device("cpu")
|
|
18
|
+
CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
fds = None # Cache FederatedDataset
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load_data(partition_id: int, num_partitions: int):
    """Load one IMDB partition and return its train and test loaders.

    The module-level ``fds`` cache is created on the first call only, so
    ``num_partitions`` takes effect only on that call. The partition is
    split 80/20 with a fixed seed, tokenized with the ``CHECKPOINT``
    tokenizer, and wrapped in two ``DataLoader`` objects with batch size 32.
    """
    global fds
    # Only initialize `FederatedDataset` once
    if fds is None:
        iid = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="stanfordnlp/imdb",
            partitioners={"train": iid},
        )

    # Divide data: 80% train, 20% test
    splits = fds.load_partition(partition_id).train_test_split(
        test_size=0.2, seed=42
    )

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

    def _tokenize(examples):
        return tokenizer(examples["text"], truncation=True)

    # Tokenize, drop the raw text and rename the label column to "labels"
    splits = (
        splits.map(_tokenize, batched=True)
        .remove_columns("text")
        .rename_column("label", "labels")
    )

    collate = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        splits["train"], shuffle=True, batch_size=32, collate_fn=collate
    )
    testloader = DataLoader(splits["test"], batch_size=32, collate_fn=collate)

    return trainloader, testloader
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def train(net, trainloader, epochs):
    """Train ``net`` for ``epochs`` passes over ``trainloader``.

    Uses ``AdamW`` with a learning rate of 5e-5. Each batch is moved to
    ``DEVICE`` and the loss is taken from the model output's ``loss``.
    """
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _epoch in range(epochs):
        for batch in trainloader:
            inputs = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
            batch_loss = net(**inputs).loss
            batch_loss.backward()
            optimizer.step()
            # Clear gradients so they do not accumulate into the next batch
            optimizer.zero_grad()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def test(net, testloader):
    """Evaluate ``net`` on ``testloader`` and return ``(loss, accuracy)``.

    The loss is the sum of the per-batch ``outputs.loss`` values divided by
    ``len(testloader.dataset)``; accuracy comes from the "accuracy" metric.
    """
    accuracy_metric = load_metric("accuracy")
    total_loss = 0
    net.eval()
    for batch in testloader:
        inputs = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
        with torch.no_grad():
            result = net(**inputs)
        total_loss += result.loss.item()
        accuracy_metric.add_batch(
            predictions=torch.argmax(result.logits, dim=-1),
            references=inputs["labels"],
        )
    # NOTE(review): divides summed batch losses by the dataset size, not the
    # number of batches — confirm this normalisation is intended.
    mean_loss = total_loss / len(testloader.dataset)
    accuracy = accuracy_metric.compute()["accuracy"]
    return mean_loss, accuracy
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def get_weights(net):
    """Return every ``state_dict`` tensor of ``net`` as a NumPy array, in order."""
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def set_weights(net, parameters):
    """Load ``parameters`` into ``net``, matched by ``state_dict`` key order.

    Each array is converted with ``torch.tensor`` and loaded with
    ``strict=True``.
    """
    names = net.state_dict().keys()
    new_state = OrderedDict(
        (name, torch.tensor(array)) for name, array in zip(names, parameters)
    )
    net.load_state_dict(new_state, strict=True)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from sklearn.datasets import make_regression
|
|
6
|
+
from sklearn.model_selection import train_test_split
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
key = jax.random.PRNGKey(0)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def load_data():
    """Generate a 3-feature regression dataset and split it into train/test.

    Returns ``(X_train, y_train, X_test, y_test)``. The data comes from
    ``make_regression`` with ``random_state=0``; the split itself is made
    without a fixed random state.
    """
    features, targets = make_regression(n_features=3, random_state=0)
    train_x, test_x, train_y, test_y = train_test_split(features, targets)
    return train_x, train_y, test_x, test_y
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def load_model(model_shape):
    """Create randomly initialised linear-model parameters.

    Returns a dict with a scalar bias ``"b"`` and a weight array ``"w"`` of
    shape ``model_shape``, both drawn with ``jax.random.uniform`` from subkeys
    of the module-level ``key``.
    """
    # Split the key so each draw gets its own subkey; reusing one key for
    # both draws would make the bias and the weights correlated.
    bias_key, weight_key = jax.random.split(key)
    return {
        "b": jax.random.uniform(bias_key),
        "w": jax.random.uniform(weight_key, model_shape),
    }
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def loss_fn(params, X, y):
    """Return the mean squared error of the linear model ``X·w + b`` against ``y``."""
    prediction = jnp.dot(X, params["w"]) + params["b"]
    residual = prediction - y
    return jnp.mean(jnp.square(residual))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def train(params, grad_fn, X, y):
    """Run 50 full-batch gradient-descent steps and report the final loss.

    Each step computes ``grad_fn(params, X, y)`` and updates every leaf of
    ``params`` with a step size of 0.05.

    Returns:
        ``(params, loss, num_examples)`` where ``loss`` is ``loss_fn``
        evaluated on the updated parameters and ``num_examples`` is
        ``X.shape[0]``.
    """
    num_examples = X.shape[0]
    for _ in range(50):
        grads = grad_fn(params, X, y)
        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias is deprecated.
        params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)
    # Only the loss after the last step is returned, so compute it once here
    loss = loss_fn(params, X, y)
    return params, loss, num_examples
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def evaluation(params, grad_fn, X_test, y_test):
    """Evaluate ``params`` on the test data.

    ``grad_fn`` is accepted for signature compatibility but is not used.

    Returns:
        ``(loss_test, num_examples)`` where ``loss_test`` is the value of
        ``loss_fn`` on the test data and ``num_examples`` is
        ``X_test.shape[0]``.
    """
    num_examples = X_test.shape[0]
    # `loss_fn` already returns the mean squared error; squaring and
    # averaging it again would report MSE**2 instead of the MSE.
    loss_test = loss_fn(params, X_test, y_test)
    return loss_test, num_examples
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_params(params):
    """Return the values of ``params`` as a list of NumPy arrays, in dict order."""
    return [np.array(value) for value in params.values()]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def set_params(local_params, global_params):
    """Overwrite ``local_params`` values in place with ``global_params``.

    Values are matched to the existing keys by position; pairing stops at the
    shorter of the two sequences.
    """
    replacements = dict(zip(local_params.keys(), global_params))
    local_params.update(replacements)
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
import numpy as np
|
|
6
|
+
from datasets.utils.logging import disable_progress_bar
|
|
7
|
+
from flwr_datasets import FederatedDataset
|
|
8
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
disable_progress_bar()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MLP(nn.Module):
    """A stack of ``nn.Linear`` layers with ReLU between them."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # Widths: input, `num_layers` hidden widths, then output
        widths = [input_dim, *([hidden_dim] * num_layers), output_dim]
        self.layers = [
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])
        ]

    def __call__(self, x):
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            # ReLU written as an element-wise maximum with zero
            x = mx.maximum(layer(x), 0.0)
        # The last layer is applied without an activation
        return output_layer(x)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def loss_fn(model, X, y):
    """Return the mean cross-entropy of ``model(X)`` against ``y``."""
    logits = model(X)
    per_example = nn.losses.cross_entropy(logits, y)
    return mx.mean(per_example)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def eval_fn(model, X, y):
    """Return the mean of ``argmax(model(X), axis=1) == y``."""
    predicted = mx.argmax(model(X), axis=1)
    return mx.mean(predicted == y)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def batch_iterate(batch_size, X, y):
    """Yield ``(X, y)`` mini-batches of up to ``batch_size`` in random order.

    A fresh permutation of ``y.size`` indices is drawn on each call; the last
    batch may be smaller than ``batch_size``.
    """
    order = mx.array(np.random.permutation(y.size))
    for start in range(0, y.size, batch_size):
        batch_ids = order[start : start + batch_size]
        yield X[batch_ids], y[batch_ids]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
fds = None # Cache FederatedDataset
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def load_data(partition_id: int, num_partitions: int):
    """Load one MNIST partition as MLX arrays.

    The module-level ``fds`` cache is created on the first call only, so
    ``num_partitions`` takes effect only on that call. The partition is split
    80/20 with a fixed seed, and each image is reshaped to 784 ``float32``
    values divided by 255.

    Returns ``(train_images, train_labels, test_images, test_labels)``.
    """
    global fds
    # Only initialize `FederatedDataset` once
    if fds is None:
        iid = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="ylecun/mnist",
            partitioners={"train": iid},
            trust_remote_code=True,
        )

    splits = fds.load_partition(partition_id).train_test_split(
        test_size=0.2, seed=42
    )
    for split_name in ("train", "test"):
        splits[split_name].set_format("numpy")

    def _flatten(img):
        # Reshape to 28 * 28 values, cast to float32 and divide by 255
        return {"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0}

    train_partition = splits["train"].map(_flatten, input_columns="image")
    test_partition = splits["test"].map(_flatten, input_columns="image")

    train_images = mx.array(train_partition["img"])
    train_labels = mx.array(train_partition["label"].astype(np.uint32))
    test_images = mx.array(test_partition["img"])
    test_labels = mx.array(test_partition["label"].astype(np.uint32))
    return train_images, train_labels, test_images, test_labels
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_params(model):
    """Return the model's layer parameters as a flat list of NumPy arrays.

    Reads ``model.parameters()["layers"]`` (a sequence of per-layer mappings)
    and converts every value to ``np.ndarray``, walking the layers in order
    and each mapping in its own iteration order.
    """
    layers = model.parameters()["layers"]
    # Only the values are used, so iterate them directly instead of unpacking
    # and discarding the keys.
    return [np.array(val) for layer in layers for val in layer.values()]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def set_params(model, parameters):
    """Load a flat parameter list into the model's ``layers`` entry.

    ``parameters`` is consumed two items at a time: the item at each even
    index becomes a layer's ``"weight"`` and the following item its ``"bias"``.
    The resulting list of mappings is passed to ``model.update``.
    """
    layers = []
    for start in range(0, len(parameters), 2):
        weight, bias = parameters[start], parameters[start + 1]
        layers.append({"weight": mx.array(weight), "bias": mx.array(bias)})
    model.update({"layers": layers})
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from collections import OrderedDict
|
|
4
4
|
|
|
@@ -6,11 +6,9 @@ import torch
|
|
|
6
6
|
import torch.nn as nn
|
|
7
7
|
import torch.nn.functional as F
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
|
-
from torchvision.datasets import CIFAR10
|
|
10
9
|
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
11
10
|
from flwr_datasets import FederatedDataset
|
|
12
|
-
|
|
13
|
-
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
11
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
14
12
|
|
|
15
13
|
|
|
16
14
|
class Net(nn.Module):
|
|
@@ -34,12 +32,22 @@ class Net(nn.Module):
|
|
|
34
32
|
return self.fc3(x)
|
|
35
33
|
|
|
36
34
|
|
|
37
|
-
|
|
35
|
+
fds = None # Cache FederatedDataset
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def load_data(partition_id: int, num_partitions: int):
|
|
38
39
|
"""Load partition CIFAR10 data."""
|
|
39
|
-
|
|
40
|
+
# Only initialize `FederatedDataset` once
|
|
41
|
+
global fds
|
|
42
|
+
if fds is None:
|
|
43
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
44
|
+
fds = FederatedDataset(
|
|
45
|
+
dataset="uoft-cs/cifar10",
|
|
46
|
+
partitioners={"train": partitioner},
|
|
47
|
+
)
|
|
40
48
|
partition = fds.load_partition(partition_id)
|
|
41
49
|
# Divide data on each node: 80% train, 20% test
|
|
42
|
-
partition_train_test = partition.train_test_split(test_size=0.2)
|
|
50
|
+
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
43
51
|
pytorch_transforms = Compose(
|
|
44
52
|
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
|
45
53
|
)
|
|
@@ -55,44 +63,41 @@ def load_data(partition_id, num_partitions):
|
|
|
55
63
|
return trainloader, testloader
|
|
56
64
|
|
|
57
65
|
|
|
58
|
-
def train(net, trainloader,
|
|
66
|
+
def train(net, trainloader, epochs, device):
    """Train the model on the training set.

    Runs ``epochs`` passes over ``trainloader`` using SGD (lr=0.1,
    momentum=0.9) and cross-entropy loss. Every batch is a mapping with an
    ``"img"`` and a ``"label"`` entry.

    Returns the training loss averaged over every optimisation step that was
    taken (all batches of all epochs), or ``0.0`` when no step was taken.
    """
    net.to(device)  # move model to GPU if available
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
    net.train()
    running_loss = 0.0
    num_steps = 0
    for _ in range(epochs):
        for batch in trainloader:
            images = batch["img"]
            labels = batch["label"]
            optimizer.zero_grad()
            loss = criterion(net(images.to(device)), labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_steps += 1

    # The loss is summed across all epochs, so divide by the number of steps
    # actually taken; dividing by len(trainloader) alone would inflate the
    # average by a factor of `epochs`.
    return running_loss / num_steps if num_steps else 0.0
|
|
82
85
|
|
|
83
86
|
|
|
84
|
-
def test(net, testloader):
|
|
87
|
+
def test(net, testloader, device):
    """Validate the model on the test set.

    Every batch of ``testloader`` is a mapping with an ``"img"`` and a
    ``"label"`` entry. The model is evaluated without gradient tracking and
    in eval mode; the mode it had on entry is restored before returning.

    Returns ``(loss, accuracy)``: the cross-entropy loss averaged over
    batches, and the fraction of ``testloader.dataset`` classified correctly.
    Both are ``0.0`` when there is nothing to evaluate.
    """
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    # Layers such as dropout and batch-norm behave differently in training
    # mode, so switch to eval mode and put the previous mode back afterwards.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch in testloader:
                images = batch["img"].to(device)
                labels = batch["label"].to(device)
                outputs = net(images)
                loss += criterion(outputs, labels).item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
    finally:
        net.train(was_training)

    num_samples = len(testloader.dataset)
    num_batches = len(testloader)
    accuracy = correct / num_samples if num_samples else 0.0
    loss = loss / num_batches if num_batches else 0.0
    return loss, accuracy
|
|
97
102
|
|
|
98
103
|
|