flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/app.py +2 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +15 -2
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +9 -13
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +36 -8
- flwr/client/grpc_rere_client/connection.py +1 -12
- flwr/client/rest_client/connection.py +3 -0
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +29 -0
- flwr/clientapp/mod/centraldp_mods.py +248 -0
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -4
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +30 -7
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +0 -11
- flwr/common/inflatable_utils.py +1 -1
- flwr/common/logger.py +1 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/telemetry.py +4 -0
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +25 -17
- flwr/proto/appio_pb2.pyi +46 -2
- flwr/proto/clientappio_pb2.py +3 -11
- flwr/proto/clientappio_pb2.pyi +0 -47
- flwr/proto/clientappio_pb2_grpc.py +19 -20
- flwr/proto/clientappio_pb2_grpc.pyi +10 -11
- flwr/proto/control_pb2.py +66 -0
- flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
- flwr/proto/control_pb2_grpc.pyi +106 -0
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +68 -0
- flwr/proto/serverappio_pb2_grpc.pyi +26 -0
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +142 -152
- flwr/server/grid/grpc_grid.py +3 -0
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +157 -146
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
- flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +64 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
- flwr/serverapp/strategy/fedadagrad.py +159 -0
- flwr/serverapp/strategy/fedadam.py +178 -0
- flwr/serverapp/strategy/fedavg.py +320 -0
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +170 -0
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +299 -0
- flwr/simulation/app.py +161 -164
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/app_utils.py +58 -0
- flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
- flwr/supercore/cli/flower_superexec.py +166 -0
- flwr/supercore/constant.py +19 -0
- flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/supercore/grpc_health/__init__.py +3 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +199 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
- flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
- flwr/supernode/cli/flower_supernode.py +3 -0
- flwr/supernode/cli/flwr_clientapp.py +18 -21
- flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
- flwr/supernode/nodestate/nodestate.py +3 -59
- flwr/supernode/runtime/run_clientapp.py +39 -102
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
- flwr/supernode/start_client_internal.py +35 -76
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
- flwr/proto/exec_pb2.py +0 -62
- flwr/proto/exec_pb2_grpc.pyi +0 -93
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -191
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -129
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
|
@@ -1,57 +1,82 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.
|
|
4
|
-
from flwr.
|
|
3
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
|
4
|
+
from flwr.clientapp import ClientApp
|
|
5
5
|
|
|
6
6
|
from $import_name.task import load_data, load_model
|
|
7
7
|
|
|
8
|
+
# Flower ClientApp
|
|
9
|
+
app = ClientApp()
|
|
8
10
|
|
|
9
|
-
# Define Flower Client and client_fn
|
|
10
|
-
class FlowerClient(NumPyClient):
|
|
11
|
-
def __init__(
|
|
12
|
-
self, model, data, epochs, batch_size, verbose
|
|
13
|
-
):
|
|
14
|
-
self.model = model
|
|
15
|
-
self.x_train, self.y_train, self.x_test, self.y_test = data
|
|
16
|
-
self.epochs = epochs
|
|
17
|
-
self.batch_size = batch_size
|
|
18
|
-
self.verbose = verbose
|
|
19
|
-
|
|
20
|
-
def fit(self, parameters, config):
|
|
21
|
-
self.model.set_weights(parameters)
|
|
22
|
-
self.model.fit(
|
|
23
|
-
self.x_train,
|
|
24
|
-
self.y_train,
|
|
25
|
-
epochs=self.epochs,
|
|
26
|
-
batch_size=self.batch_size,
|
|
27
|
-
verbose=self.verbose,
|
|
28
|
-
)
|
|
29
|
-
return self.model.get_weights(), len(self.x_train), {}
|
|
30
|
-
|
|
31
|
-
def evaluate(self, parameters, config):
|
|
32
|
-
self.model.set_weights(parameters)
|
|
33
|
-
loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
|
|
34
|
-
return loss, len(self.x_test), {"accuracy": accuracy}
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def client_fn(context: Context):
|
|
38
|
-
# Load model and data
|
|
39
|
-
net = load_model()
|
|
40
11
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
12
|
+
@app.train()
|
|
13
|
+
def train(msg: Message, context: Context):
|
|
14
|
+
"""Train the model on local data."""
|
|
15
|
+
|
|
16
|
+
# Load the model and initialize it with the received weights
|
|
17
|
+
model = load_model()
|
|
18
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
|
19
|
+
model.set_weights(ndarrays)
|
|
20
|
+
|
|
21
|
+
# Read from config
|
|
44
22
|
epochs = context.run_config["local-epochs"]
|
|
45
23
|
batch_size = context.run_config["batch-size"]
|
|
46
24
|
verbose = context.run_config.get("verbose")
|
|
47
25
|
|
|
48
|
-
#
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
26
|
+
# Load the data
|
|
27
|
+
partition_id = context.node_config["partition-id"]
|
|
28
|
+
num_partitions = context.node_config["num-partitions"]
|
|
29
|
+
x_train, y_train, _, _ = load_data(partition_id, num_partitions)
|
|
52
30
|
|
|
31
|
+
# Train the model on local data
|
|
32
|
+
history = model.fit(
|
|
33
|
+
x_train,
|
|
34
|
+
y_train,
|
|
35
|
+
epochs=epochs,
|
|
36
|
+
batch_size=batch_size,
|
|
37
|
+
verbose=verbose,
|
|
38
|
+
)
|
|
53
39
|
|
|
54
|
-
#
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
40
|
+
# Get final training loss and accuracy
|
|
41
|
+
train_loss = history.history["loss"][-1] if "loss" in history.history else None
|
|
42
|
+
train_acc = history.history.get("accuracy")
|
|
43
|
+
train_acc = train_acc[-1] if train_acc is not None else None
|
|
44
|
+
|
|
45
|
+
# Construct and return reply Message
|
|
46
|
+
model_record = ArrayRecord(model.get_weights())
|
|
47
|
+
metrics = {"num-examples": len(x_train)}
|
|
48
|
+
if train_loss is not None:
|
|
49
|
+
metrics["train_loss"] = train_loss
|
|
50
|
+
if train_acc is not None:
|
|
51
|
+
metrics["train_acc"] = train_acc
|
|
52
|
+
metric_record = MetricRecord(metrics)
|
|
53
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
|
54
|
+
return Message(content=content, reply_to=msg)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@app.evaluate()
|
|
58
|
+
def evaluate(msg: Message, context: Context):
|
|
59
|
+
"""Evaluate the model on local data."""
|
|
60
|
+
|
|
61
|
+
# Load the model and initialize it with the received weights
|
|
62
|
+
model = load_model()
|
|
63
|
+
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
|
|
64
|
+
model.set_weights(ndarrays)
|
|
65
|
+
|
|
66
|
+
# Load the data
|
|
67
|
+
partition_id = context.node_config["partition-id"]
|
|
68
|
+
num_partitions = context.node_config["num-partitions"]
|
|
69
|
+
_, _, x_test, y_test = load_data(partition_id, num_partitions)
|
|
70
|
+
|
|
71
|
+
# Evaluate the model on local data
|
|
72
|
+
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
|
|
73
|
+
|
|
74
|
+
# Construct and return reply Message
|
|
75
|
+
metrics = {
|
|
76
|
+
"eval_loss": loss,
|
|
77
|
+
"eval_acc": accuracy,
|
|
78
|
+
"num-examples": len(x_test),
|
|
79
|
+
}
|
|
80
|
+
metric_record = MetricRecord(metrics)
|
|
81
|
+
content = RecordDict({"metrics": metric_record})
|
|
82
|
+
return Message(content=content, reply_to=msg)
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import xgboost as xgb
|
|
7
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
|
8
|
+
from flwr.clientapp import ClientApp
|
|
9
|
+
from flwr.common.config import unflatten_dict
|
|
10
|
+
|
|
11
|
+
from $import_name.task import load_data, replace_keys
|
|
12
|
+
|
|
13
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Flower ClientApp
|
|
17
|
+
app = ClientApp()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _local_boost(bst_input, num_local_round, train_dmatrix):
|
|
21
|
+
# Update trees based on local training data.
|
|
22
|
+
for i in range(num_local_round):
|
|
23
|
+
bst_input.update(train_dmatrix, bst_input.num_boosted_rounds())
|
|
24
|
+
|
|
25
|
+
# Bagging: extract the last N=num_local_round trees for server aggregation
|
|
26
|
+
bst = bst_input[
|
|
27
|
+
bst_input.num_boosted_rounds()
|
|
28
|
+
- num_local_round : bst_input.num_boosted_rounds()
|
|
29
|
+
]
|
|
30
|
+
return bst
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@app.train()
|
|
34
|
+
def train(msg: Message, context: Context) -> Message:
|
|
35
|
+
# Load model and data
|
|
36
|
+
partition_id = context.node_config["partition-id"]
|
|
37
|
+
num_partitions = context.node_config["num-partitions"]
|
|
38
|
+
train_dmatrix, _, num_train, _ = load_data(partition_id, num_partitions)
|
|
39
|
+
|
|
40
|
+
# Read from run config
|
|
41
|
+
num_local_round = context.run_config["local-epochs"]
|
|
42
|
+
# Unflatten the config dict and replace "-" with "_" in its keys
|
|
43
|
+
cfg = replace_keys(unflatten_dict(context.run_config))
|
|
44
|
+
params = cfg["params"]
|
|
45
|
+
|
|
46
|
+
global_round = msg.content["config"]["server-round"]
|
|
47
|
+
if global_round == 1:
|
|
48
|
+
# First round local training
|
|
49
|
+
bst = xgb.train(
|
|
50
|
+
params,
|
|
51
|
+
train_dmatrix,
|
|
52
|
+
num_boost_round=num_local_round,
|
|
53
|
+
)
|
|
54
|
+
else:
|
|
55
|
+
bst = xgb.Booster(params=params)
|
|
56
|
+
global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
|
|
57
|
+
|
|
58
|
+
# Load global model into booster
|
|
59
|
+
bst.load_model(global_model)
|
|
60
|
+
|
|
61
|
+
# Local training
|
|
62
|
+
bst = _local_boost(bst, num_local_round, train_dmatrix)
|
|
63
|
+
|
|
64
|
+
# Save model
|
|
65
|
+
local_model = bst.save_raw("json")
|
|
66
|
+
model_np = np.frombuffer(local_model, dtype=np.uint8)
|
|
67
|
+
|
|
68
|
+
# Construct reply message
|
|
69
|
+
# Note: we store the model as the first item in a list into ArrayRecord,
|
|
70
|
+
# which can be accessed using index ["0"].
|
|
71
|
+
model_record = ArrayRecord([model_np])
|
|
72
|
+
metrics = {
|
|
73
|
+
"num-examples": num_train,
|
|
74
|
+
}
|
|
75
|
+
metric_record = MetricRecord(metrics)
|
|
76
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
|
77
|
+
return Message(content=content, reply_to=msg)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@app.evaluate()
|
|
81
|
+
def evaluate(msg: Message, context: Context) -> Message:
|
|
82
|
+
# Load model and data
|
|
83
|
+
partition_id = context.node_config["partition-id"]
|
|
84
|
+
num_partitions = context.node_config["num-partitions"]
|
|
85
|
+
_, valid_dmatrix, _, num_val = load_data(partition_id, num_partitions)
|
|
86
|
+
|
|
87
|
+
# Load config
|
|
88
|
+
cfg = replace_keys(unflatten_dict(context.run_config))
|
|
89
|
+
params = cfg["params"]
|
|
90
|
+
|
|
91
|
+
# Load global model
|
|
92
|
+
bst = xgb.Booster(params=params)
|
|
93
|
+
global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
|
|
94
|
+
bst.load_model(global_model)
|
|
95
|
+
|
|
96
|
+
# Run evaluation
|
|
97
|
+
eval_results = bst.eval_set(
|
|
98
|
+
evals=[(valid_dmatrix, "valid")],
|
|
99
|
+
iteration=bst.num_boosted_rounds() - 1,
|
|
100
|
+
)
|
|
101
|
+
auc = float(eval_results.split("\t")[1].split(":")[1])
|
|
102
|
+
|
|
103
|
+
# Construct and return reply Message
|
|
104
|
+
metrics = {
|
|
105
|
+
"auc": auc,
|
|
106
|
+
"num-examples": num_val,
|
|
107
|
+
}
|
|
108
|
+
metric_record = MetricRecord(metrics)
|
|
109
|
+
content = RecordDict({"metrics": metric_record})
|
|
110
|
+
return Message(content=content, reply_to=msg)
|
|
@@ -2,15 +2,12 @@
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
import warnings
|
|
5
|
-
from typing import Dict, Tuple
|
|
6
5
|
|
|
7
|
-
import
|
|
8
|
-
from flwr.
|
|
9
|
-
from flwr.common import Context
|
|
6
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
|
7
|
+
from flwr.clientapp import ClientApp
|
|
10
8
|
from flwr.common.config import unflatten_dict
|
|
11
|
-
from flwr.common.typing import NDArrays, Scalar
|
|
12
9
|
from omegaconf import DictConfig
|
|
13
|
-
|
|
10
|
+
from peft import get_peft_model_state_dict, set_peft_model_state_dict
|
|
14
11
|
from transformers import TrainingArguments
|
|
15
12
|
from trl import SFTTrainer
|
|
16
13
|
|
|
@@ -19,12 +16,7 @@ from $import_name.dataset import (
|
|
|
19
16
|
load_data,
|
|
20
17
|
replace_keys,
|
|
21
18
|
)
|
|
22
|
-
from $import_name.models import
|
|
23
|
-
cosine_annealing,
|
|
24
|
-
get_model,
|
|
25
|
-
set_parameters,
|
|
26
|
-
get_parameters,
|
|
27
|
-
)
|
|
19
|
+
from $import_name.models import cosine_annealing, get_model
|
|
28
20
|
|
|
29
21
|
# Avoid warnings
|
|
30
22
|
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
|
@@ -32,95 +24,69 @@ os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
|
|
|
32
24
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
33
25
|
|
|
34
26
|
|
|
35
|
-
#
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
27
|
+
# Avoid warnings
|
|
28
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
|
29
|
+
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
|
|
30
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
|
39
31
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
): # pylint: disable=too-many-arguments
|
|
50
|
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
51
|
-
self.train_cfg = train_cfg
|
|
52
|
-
self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
|
|
53
|
-
self.tokenizer = tokenizer
|
|
54
|
-
self.formatting_prompts_func = formatting_prompts_func
|
|
55
|
-
self.data_collator = data_collator
|
|
56
|
-
self.num_rounds = num_rounds
|
|
57
|
-
self.trainset = trainset
|
|
58
|
-
|
|
59
|
-
# instantiate model
|
|
60
|
-
self.model = get_model(model_cfg)
|
|
61
|
-
|
|
62
|
-
def fit(
|
|
63
|
-
self, parameters: NDArrays, config: Dict[str, Scalar]
|
|
64
|
-
) -> Tuple[NDArrays, int, Dict]:
|
|
65
|
-
"""Implement distributed fit function for a given client."""
|
|
66
|
-
set_parameters(self.model, parameters)
|
|
67
|
-
|
|
68
|
-
new_lr = cosine_annealing(
|
|
69
|
-
int(config["current_round"]),
|
|
70
|
-
self.num_rounds,
|
|
71
|
-
self.train_cfg.learning_rate_max,
|
|
72
|
-
self.train_cfg.learning_rate_min,
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
self.training_arguments.learning_rate = new_lr
|
|
76
|
-
self.training_arguments.output_dir = config["save_path"]
|
|
77
|
-
|
|
78
|
-
# Construct trainer
|
|
79
|
-
trainer = SFTTrainer(
|
|
80
|
-
model=self.model,
|
|
81
|
-
tokenizer=self.tokenizer,
|
|
82
|
-
args=self.training_arguments,
|
|
83
|
-
max_seq_length=self.train_cfg.seq_length,
|
|
84
|
-
train_dataset=self.trainset,
|
|
85
|
-
formatting_func=self.formatting_prompts_func,
|
|
86
|
-
data_collator=self.data_collator,
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
# Do local training
|
|
90
|
-
results = trainer.train()
|
|
91
|
-
|
|
92
|
-
return (
|
|
93
|
-
get_parameters(self.model),
|
|
94
|
-
len(self.trainset),
|
|
95
|
-
{"train_loss": results.training_loss},
|
|
96
|
-
)
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def client_fn(context: Context) -> FlowerClient:
|
|
100
|
-
"""Create a Flower client representing a single organization."""
|
|
32
|
+
|
|
33
|
+
# Flower ClientApp
|
|
34
|
+
app = ClientApp()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@app.train()
|
|
38
|
+
def train(msg: Message, context: Context):
|
|
39
|
+
"""Train the model on local data."""
|
|
40
|
+
# Parse config
|
|
101
41
|
partition_id = context.node_config["partition-id"]
|
|
102
42
|
num_partitions = context.node_config["num-partitions"]
|
|
103
43
|
num_rounds = context.run_config["num-server-rounds"]
|
|
104
44
|
cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
|
|
45
|
+
training_arguments = TrainingArguments(**cfg.train.training_arguments)
|
|
105
46
|
|
|
106
47
|
# Let's get the client partition
|
|
107
|
-
|
|
48
|
+
trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
|
|
108
49
|
(
|
|
109
50
|
tokenizer,
|
|
110
51
|
data_collator,
|
|
111
52
|
formatting_prompts_func,
|
|
112
53
|
) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
|
|
113
54
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
client_trainset,
|
|
118
|
-
tokenizer,
|
|
119
|
-
formatting_prompts_func,
|
|
120
|
-
data_collator,
|
|
121
|
-
num_rounds,
|
|
122
|
-
).to_client()
|
|
123
|
-
|
|
55
|
+
# Load the model and initialize it with the received weights
|
|
56
|
+
model = get_model(cfg.model)
|
|
57
|
+
set_peft_model_state_dict(model, msg.content["arrays"].to_torch_state_dict())
|
|
124
58
|
|
|
125
|
-
#
|
|
126
|
-
|
|
59
|
+
# Set learning rate for current round
|
|
60
|
+
new_lr = cosine_annealing(
|
|
61
|
+
msg.content["config"]["server-round"],
|
|
62
|
+
num_rounds,
|
|
63
|
+
cfg.train.learning_rate_max,
|
|
64
|
+
cfg.train.learning_rate_min,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
training_arguments.learning_rate = new_lr
|
|
68
|
+
training_arguments.output_dir = msg.content["config"]["save_path"]
|
|
69
|
+
|
|
70
|
+
# Construct trainer
|
|
71
|
+
trainer = SFTTrainer(
|
|
72
|
+
model=model,
|
|
73
|
+
tokenizer=tokenizer,
|
|
74
|
+
args=training_arguments,
|
|
75
|
+
max_seq_length=cfg.train.seq_length,
|
|
76
|
+
train_dataset=trainset,
|
|
77
|
+
formatting_func=formatting_prompts_func,
|
|
78
|
+
data_collator=data_collator,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
# Do local training
|
|
82
|
+
results = trainer.train()
|
|
83
|
+
|
|
84
|
+
# Construct and return reply Message
|
|
85
|
+
model_record = ArrayRecord(get_peft_model_state_dict(model))
|
|
86
|
+
metrics = {
|
|
87
|
+
"train_loss": results.training_loss,
|
|
88
|
+
"num-examples": len(trainset),
|
|
89
|
+
}
|
|
90
|
+
metric_record = MetricRecord(metrics)
|
|
91
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
|
92
|
+
return Message(content=content, reply_to=msg)
|
|
@@ -4,18 +4,10 @@ import math
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from omegaconf import DictConfig
|
|
7
|
-
from
|
|
8
|
-
from peft import (
|
|
9
|
-
LoraConfig,
|
|
10
|
-
get_peft_model,
|
|
11
|
-
get_peft_model_state_dict,
|
|
12
|
-
set_peft_model_state_dict,
|
|
13
|
-
)
|
|
7
|
+
from peft import LoraConfig, get_peft_model
|
|
14
8
|
from peft.utils import prepare_model_for_kbit_training
|
|
15
9
|
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
|
16
10
|
|
|
17
|
-
from flwr.common.typing import NDArrays
|
|
18
|
-
|
|
19
11
|
|
|
20
12
|
def cosine_annealing(
|
|
21
13
|
current_round: int,
|
|
@@ -62,17 +54,3 @@ def get_model(model_cfg: DictConfig):
|
|
|
62
54
|
model.config.use_cache = False
|
|
63
55
|
|
|
64
56
|
return get_peft_model(model, peft_config)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
def set_parameters(model, parameters: NDArrays) -> None:
|
|
68
|
-
"""Change the parameters of the model using the given ones."""
|
|
69
|
-
peft_state_dict_keys = get_peft_model_state_dict(model).keys()
|
|
70
|
-
params_dict = zip(peft_state_dict_keys, parameters)
|
|
71
|
-
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
|
|
72
|
-
set_peft_model_state_dict(model, state_dict)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def get_parameters(model) -> NDArrays:
|
|
76
|
-
"""Return the parameters of the current net."""
|
|
77
|
-
state_dict = get_peft_model_state_dict(model)
|
|
78
|
-
return [val.cpu().numpy() for _, val in state_dict.items()]
|
|
@@ -3,62 +3,23 @@
|
|
|
3
3
|
import os
|
|
4
4
|
from datetime import datetime
|
|
5
5
|
|
|
6
|
-
from flwr.
|
|
6
|
+
from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
|
|
7
7
|
from flwr.common.config import unflatten_dict
|
|
8
|
-
from flwr.
|
|
8
|
+
from flwr.serverapp import Grid, ServerApp
|
|
9
9
|
from omegaconf import DictConfig
|
|
10
|
+
from peft import get_peft_model_state_dict, set_peft_model_state_dict
|
|
10
11
|
|
|
11
|
-
from $import_name.models import get_model, get_parameters, set_parameters
|
|
12
12
|
from $import_name.dataset import replace_keys
|
|
13
|
+
from $import_name.models import get_model
|
|
13
14
|
from $import_name.strategy import FlowerTuneLlm
|
|
14
15
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
# Here we use it to save global model checkpoints
|
|
18
|
-
def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
|
|
19
|
-
"""Return an evaluation function for saving global model."""
|
|
20
|
-
|
|
21
|
-
def evaluate(server_round: int, parameters, config):
|
|
22
|
-
# Save model
|
|
23
|
-
if server_round != 0 and (
|
|
24
|
-
server_round == total_round or server_round % save_every_round == 0
|
|
25
|
-
):
|
|
26
|
-
# Init model
|
|
27
|
-
model = get_model(model_cfg)
|
|
28
|
-
set_parameters(model, parameters)
|
|
29
|
-
|
|
30
|
-
model.save_pretrained(f"{save_path}/peft_{server_round}")
|
|
31
|
-
|
|
32
|
-
return 0.0, {}
|
|
33
|
-
|
|
34
|
-
return evaluate
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def get_on_fit_config(save_path):
|
|
38
|
-
"""Return a function that will be used to construct the config that the
|
|
39
|
-
client's fit() method will receive."""
|
|
40
|
-
|
|
41
|
-
def fit_config_fn(server_round: int):
|
|
42
|
-
fit_config = {}
|
|
43
|
-
fit_config["current_round"] = server_round
|
|
44
|
-
fit_config["save_path"] = save_path
|
|
45
|
-
return fit_config
|
|
46
|
-
|
|
47
|
-
return fit_config_fn
|
|
16
|
+
# Create ServerApp
|
|
17
|
+
app = ServerApp()
|
|
48
18
|
|
|
49
19
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
|
|
54
|
-
examples = [num_examples for num_examples, _ in metrics]
|
|
55
|
-
|
|
56
|
-
# Aggregate and return custom metric (weighted average)
|
|
57
|
-
return {"train_loss": sum(losses) / sum(examples)}
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def server_fn(context: Context):
|
|
61
|
-
"""Construct components that set the ServerApp behaviour."""
|
|
20
|
+
@app.main()
|
|
21
|
+
def main(grid: Grid, context: Context) -> None:
|
|
22
|
+
"""Main entry point for the ServerApp."""
|
|
62
23
|
# Create output directory given current timestamp
|
|
63
24
|
current_time = datetime.now()
|
|
64
25
|
folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
|
|
@@ -71,24 +32,42 @@ def server_fn(context: Context):
|
|
|
71
32
|
|
|
72
33
|
# Get initial model weights
|
|
73
34
|
init_model = get_model(cfg.model)
|
|
74
|
-
|
|
75
|
-
init_model_parameters = ndarrays_to_parameters(init_model_parameters)
|
|
35
|
+
arrays = ArrayRecord(get_peft_model_state_dict(init_model))
|
|
76
36
|
|
|
77
37
|
# Define strategy
|
|
78
38
|
strategy = FlowerTuneLlm(
|
|
79
|
-
|
|
39
|
+
fraction_train=cfg.strategy.fraction_train,
|
|
80
40
|
fraction_evaluate=cfg.strategy.fraction_evaluate,
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
44
|
+
strategy.start(
|
|
45
|
+
grid=grid,
|
|
46
|
+
initial_arrays=arrays,
|
|
47
|
+
train_config=ConfigRecord({"save_path": save_path}),
|
|
48
|
+
num_rounds=num_rounds,
|
|
84
49
|
evaluate_fn=get_evaluate_fn(
|
|
85
50
|
cfg.model, cfg.train.save_every_round, num_rounds, save_path
|
|
86
51
|
),
|
|
87
52
|
)
|
|
88
|
-
config = ServerConfig(num_rounds=num_rounds)
|
|
89
53
|
|
|
90
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
|
91
54
|
|
|
55
|
+
# Get function that will be executed by the strategy
|
|
56
|
+
# Here we use it to save global model checkpoints
|
|
57
|
+
def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
|
|
58
|
+
"""Return an evaluation function for saving global model."""
|
|
92
59
|
|
|
93
|
-
|
|
94
|
-
|
|
60
|
+
def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
|
|
61
|
+
# Save model
|
|
62
|
+
if server_round != 0 and (
|
|
63
|
+
server_round == total_round or server_round % save_every_round == 0
|
|
64
|
+
):
|
|
65
|
+
# Init model
|
|
66
|
+
model = get_model(model_cfg)
|
|
67
|
+
set_peft_model_state_dict(model, arrays.to_torch_state_dict())
|
|
68
|
+
|
|
69
|
+
model.save_pretrained(f"{save_path}/peft_{server_round}")
|
|
70
|
+
|
|
71
|
+
return MetricRecord()
|
|
72
|
+
|
|
73
|
+
return evaluate
|