flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of flwr-nightly has been flagged as potentially problematic; see the registry page for details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
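
Most of the template diffs excerpted below share one API change introduced between these releases: client_fn now takes a single flwr.common.Context argument, with per-node values read from context.node_config (e.g. "partition-id", "num-partitions") and per-run values from context.run_config (e.g. "local-epochs"). The following is a minimal illustrative sketch of that pattern and is not part of the package diff itself; it assumes the Context/ClientApp API of flwr >= 1.10, and the training logic is a placeholder.

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context


class FlowerClient(NumPyClient):
    def fit(self, parameters, config):
        # Placeholder: a real client would train here and return updated weights.
        return parameters, 1, {}


def client_fn(context: Context):
    # Set by the SuperNode or the simulation engine for this particular node.
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    # Declared by the app (typically in its pyproject.toml) and overridable per run.
    local_epochs = context.run_config["local-epochs"]
    print(f"partition {partition_id}/{num_partitions}, {local_epochs} local epoch(s)")
    return FlowerClient().to_client()


app = ClientApp(client_fn=client_fn)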
flwr/cli/new/templates/app/code/client.mlx.py.tpl
@@ -0,0 +1,93 @@
+"""$project_name: A Flower / $framework_str app."""
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+from flwr.client import NumPyClient, ClientApp
+from flwr.common import Context
+
+from $import_name.task import (
+    batch_iterate,
+    eval_fn,
+    get_params,
+    load_data,
+    loss_fn,
+    set_params,
+    MLP,
+)
+
+
+# Define Flower Client and client_fn
+class FlowerClient(NumPyClient):
+    def __init__(
+        self,
+        data,
+        num_layers,
+        hidden_dim,
+        num_classes,
+        batch_size,
+        learning_rate,
+        num_epochs,
+    ):
+        self.num_layers = num_layers
+        self.hidden_dim = hidden_dim
+        self.num_classes = num_classes
+        self.batch_size = batch_size
+        self.learning_rate = learning_rate
+        self.num_epochs = num_epochs
+
+        self.train_images, self.train_labels, self.test_images, self.test_labels = data
+        self.model = MLP(
+            num_layers, self.train_images.shape[-1], hidden_dim, num_classes
+        )
+        self.optimizer = optim.SGD(learning_rate=learning_rate)
+        self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
+
+    def get_parameters(self, config):
+        return get_params(self.model)
+
+    def set_parameters(self, parameters):
+        set_params(self.model, parameters)
+
+    def fit(self, parameters, config):
+        self.set_parameters(parameters)
+        for _ in range(self.num_epochs):
+            for X, y in batch_iterate(
+                self.batch_size, self.train_images, self.train_labels
+            ):
+                _, grads = self.loss_and_grad_fn(self.model, X, y)
+                self.optimizer.update(self.model, grads)
+                mx.eval(self.model.parameters(), self.optimizer.state)
+        return self.get_parameters(config={}), len(self.train_images), {}
+
+    def evaluate(self, parameters, config):
+        self.set_parameters(parameters)
+        accuracy = eval_fn(self.model, self.test_images, self.test_labels)
+        loss = loss_fn(self.model, self.test_images, self.test_labels)
+        return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
+
+
+def client_fn(context: Context):
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    data = load_data(partition_id, num_partitions)
+
+    num_layers = context.run_config["num-layers"]
+    hidden_dim = context.run_config["hidden-dim"]
+    num_classes = 10
+    batch_size = context.run_config["batch-size"]
+    learning_rate = context.run_config["lr"]
+    num_epochs = context.run_config["local-epochs"]
+
+    # Return Client instance
+    return FlowerClient(
+        data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
+    ).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(
+    client_fn,
+)
flwr/cli/new/templates/app/code/client.numpy.py.tpl
@@ -1,6 +1,7 @@
-"""$project_name: A Flower /
+"""$project_name: A Flower / $framework_str app."""

 from flwr.client import NumPyClient, ClientApp
+from flwr.common import Context
 import numpy as np


@@ -15,7 +16,7 @@ class FlowerClient(NumPyClient):
         return float(0.0), 1, {"accuracy": float(1.0)}


-def client_fn(
+def client_fn(context: Context):
     return FlowerClient().to_client()


flwr/cli/new/templates/app/code/client.pytorch.py.tpl
@@ -1,10 +1,11 @@
-"""$project_name: A Flower /
+"""$project_name: A Flower / $framework_str app."""

+import torch
 from flwr.client import NumPyClient, ClientApp
+from flwr.common import Context

-from $
+from $import_name.task import (
     Net,
-    DEVICE,
     load_data,
     get_weights,
     set_weights,
@@ -15,29 +16,40 @@ from $project_name.task import (

 # Define Flower Client and client_fn
 class FlowerClient(NumPyClient):
-    def __init__(self, net, trainloader, valloader):
+    def __init__(self, net, trainloader, valloader, local_epochs):
         self.net = net
         self.trainloader = trainloader
         self.valloader = valloader
+        self.local_epochs = local_epochs
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.net.to(self.device)

     def fit(self, parameters, config):
         set_weights(self.net, parameters)
-
-
+        train_loss = train(
+            self.net,
+            self.trainloader,
+            self.local_epochs,
+            self.device,
+        )
+        return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}

     def evaluate(self, parameters, config):
         set_weights(self.net, parameters)
-        loss, accuracy = test(self.net, self.valloader)
+        loss, accuracy = test(self.net, self.valloader, self.device)
         return loss, len(self.valloader.dataset), {"accuracy": accuracy}


-def client_fn(
+def client_fn(context: Context):
     # Load model and data
-    net = Net()
-
+    net = Net()
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    trainloader, valloader = load_data(partition_id, num_partitions)
+    local_epochs = context.run_config["local-epochs"]

     # Return Client instance
-    return FlowerClient(net, trainloader, valloader).to_client()
+    return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


 # Flower ClientApp
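
The updated PyTorch client constructs its own torch.device and passes it into train and test, which live in task.pytorch.py.tpl (changed in this release but not excerpted here). The following is a hedged sketch of helpers matching the call sites above, not the template's exact code: the signatures are inferred from the diff, and the batch format assumes the dict-style batches ({"img": ..., "label": ...}) used by flwr-datasets dataloaders.

import torch
import torch.nn.functional as F


def train(net, trainloader, epochs, device):
    # Train for `epochs` local epochs and return the average training loss.
    net.train()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    running_loss = 0.0
    for _ in range(epochs):
        for batch in trainloader:
            images, labels = batch["img"].to(device), batch["label"].to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(net(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
    return running_loss / (epochs * len(trainloader))


def test(net, testloader, device):
    # Return (average loss, accuracy) on the validation set.
    net.eval()
    total_loss, correct = 0.0, 0
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"].to(device), batch["label"].to(device)
            outputs = net(images)
            total_loss += F.cross_entropy(outputs, labels, reduction="sum").item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    num_examples = len(testloader.dataset)
    return total_loss / num_examples, correct / num_examples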
flwr/cli/new/templates/app/code/client.sklearn.py.tpl
@@ -0,0 +1,97 @@
+"""$project_name: A Flower / $framework_str app."""
+
+import warnings
+
+import numpy as np
+from flwr.client import NumPyClient, ClientApp
+from flwr.common import Context
+from flwr_datasets import FederatedDataset
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import log_loss
+
+
+def get_model_parameters(model):
+    if model.fit_intercept:
+        params = [
+            model.coef_,
+            model.intercept_,
+        ]
+    else:
+        params = [model.coef_]
+    return params
+
+
+def set_model_params(model, params):
+    model.coef_ = params[0]
+    if model.fit_intercept:
+        model.intercept_ = params[1]
+    return model
+
+
+def set_initial_params(model):
+    n_classes = 10  # MNIST has 10 classes
+    n_features = 784  # Number of features in dataset
+    model.classes_ = np.array([i for i in range(10)])
+
+    model.coef_ = np.zeros((n_classes, n_features))
+    if model.fit_intercept:
+        model.intercept_ = np.zeros((n_classes,))
+
+
+class FlowerClient(NumPyClient):
+    def __init__(self, model, X_train, X_test, y_train, y_test):
+        self.model = model
+        self.X_train = X_train
+        self.X_test = X_test
+        self.y_train = y_train
+        self.y_test = y_test
+
+    def get_parameters(self, config):
+        return get_model_parameters(self.model)
+
+    def fit(self, parameters, config):
+        set_model_params(self.model, parameters)
+
+        # Ignore convergence failure due to low local epochs
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            self.model.fit(self.X_train, self.y_train)
+
+        return get_model_parameters(self.model), len(self.X_train), {}
+
+    def evaluate(self, parameters, config):
+        set_model_params(self.model, parameters)
+
+        loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
+        accuracy = self.model.score(self.X_test, self.y_test)
+
+        return loss, len(self.X_test), {"accuracy": accuracy}
+
+
+def client_fn(context: Context):
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
+    dataset = fds.load_partition(partition_id, "train").with_format("numpy")
+
+    X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
+
+    # Split the on edge data: 80% train, 20% test
+    X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
+    y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
+
+    # Create LogisticRegression Model
+    model = LogisticRegression(
+        penalty="l2",
+        max_iter=1,  # local epoch
+        warm_start=True,  # prevent refreshing weights when fitting
+    )
+
+    # Setting initial parameters, akin to model.compile for keras models
+    set_initial_params(model)
+
+    return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(client_fn=client_fn)
flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
@@ -1 +1,60 @@
-"""$project_name: A Flower /
+"""$project_name: A Flower / $framework_str app."""
+
+from flwr.client import NumPyClient, ClientApp
+from flwr.common import Context
+
+from $import_name.task import load_data, load_model
+
+
+# Define Flower Client and client_fn
+class FlowerClient(NumPyClient):
+    def __init__(
+        self, model, data, epochs, batch_size, verbose
+    ):
+        self.model = model
+        self.x_train, self.y_train, self.x_test, self.y_test = data
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.verbose = verbose
+
+    def get_parameters(self, config):
+        return self.model.get_weights()
+
+    def fit(self, parameters, config):
+        self.model.set_weights(parameters)
+        self.model.fit(
+            self.x_train,
+            self.y_train,
+            epochs=self.epochs,
+            batch_size=self.batch_size,
+            verbose=self.verbose,
+        )
+        return self.model.get_weights(), len(self.x_train), {}
+
+    def evaluate(self, parameters, config):
+        self.model.set_weights(parameters)
+        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
+        return loss, len(self.x_test), {"accuracy": accuracy}
+
+
+def client_fn(context: Context):
+    # Load model and data
+    net = load_model()
+
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    data = load_data(partition_id, num_partitions)
+    epochs = context.run_config["local-epochs"]
+    batch_size = context.run_config["batch-size"]
+    verbose = context.run_config.get("verbose")
+
+    # Return Client instance
+    return FlowerClient(
+        net, data, epochs, batch_size, verbose
+    ).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(
+    client_fn=client_fn,
+)
flwr/cli/new/templates/app/code/flwr_tune/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower CLI `new` command app / code / flwr_tune templates."""
flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl
@@ -0,0 +1,89 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+import os
+import warnings
+from datetime import datetime
+
+from flwr_datasets import FederatedDataset
+from hydra import compose, initialize
+from hydra.utils import instantiate
+
+from flwr.client import ClientApp
+from flwr.common import Context, ndarrays_to_parameters
+from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+
+from $import_name.client_app import gen_client_fn, get_parameters
+from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
+from $import_name.models import get_model
+from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config
+
+# Avoid warnings
+warnings.filterwarnings("ignore", category=UserWarning)
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
+
+# Initialise regular config
+with initialize(config_path="conf", version_base="1.1"):
+    cfg = compose(config_name="config")
+
+# Initialise static config
+with initialize(config_path="conf", version_base="1.1"):
+    cfg_static = compose(config_name="static_config")
+
+cfg.train.num_rounds = cfg_static.num_rounds
+
+# Create output directory given current timestamp
+current_time = datetime.now()
+folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
+save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
+os.makedirs(save_path, exist_ok=True)
+
+# Partition dataset and get dataloaders
+partitioner = instantiate(cfg_static.partitioner)
+fds = FederatedDataset(
+    dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
+)
+(
+    tokenizer,
+    data_collator,
+    formatting_prompts_func,
+) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
+
+# ClientApp for Flower Next
+client = ClientApp(
+    client_fn=gen_client_fn(
+        fds,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        cfg.model,
+        cfg.train,
+        save_path,
+    ),
+)
+
+# Get initial model weights
+init_model = get_model(cfg.model)
+init_model_parameters = get_parameters(init_model)
+init_model_parameters = ndarrays_to_parameters(init_model_parameters)
+
+def server_fn(context: Context):
+    # Instantiate strategy according to config. Here we pass other arguments
+    # that are only defined at runtime.
+    strategy = instantiate(
+        cfg.strategy,
+        on_fit_config_fn=get_on_fit_config(),
+        fit_metrics_aggregation_fn=fit_weighted_average,
+        initial_parameters=init_model_parameters,
+        evaluate_fn=get_evaluate_fn(
+            cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
+        ),
+    )
+
+    config = ServerConfig(num_rounds=cfg_static.num_rounds)
+
+    return ServerAppComponents(strategy=strategy, config=config)
+
+
+# ServerApp for Flower Next
+server = ServerApp(server_fn=server_fn)
flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl
@@ -0,0 +1,126 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from collections import OrderedDict
+from typing import Callable, Dict, Tuple
+
+import torch
+from omegaconf import DictConfig
+from peft import get_peft_model_state_dict, set_peft_model_state_dict
+from transformers import TrainingArguments
+from trl import SFTTrainer
+
+from flwr.client import NumPyClient
+from flwr.common import Context
+from flwr.common.typing import NDArrays, Scalar
+from $import_name.dataset import reformat
+from $import_name.models import cosine_annealing, get_model
+
+
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+class FlowerClient(NumPyClient):
+    """Standard Flower client for CNN training."""
+
+    def __init__(
+        self,
+        model_cfg: DictConfig,
+        train_cfg: DictConfig,
+        trainset,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        save_path,
+    ):  # pylint: disable=too-many-arguments
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.train_cfg = train_cfg
+        self.training_argumnets = TrainingArguments(**train_cfg.training_arguments)
+        self.tokenizer = tokenizer
+        self.formatting_prompts_func = formatting_prompts_func
+        self.data_collator = data_collator
+        self.save_path = save_path
+
+        # instantiate model
+        self.model = get_model(model_cfg)
+
+        self.trainset = trainset
+
+    def fit(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[NDArrays, int, Dict]:
+        """Implement distributed fit function for a given client."""
+        set_parameters(self.model, parameters)
+
+        new_lr = cosine_annealing(
+            int(config["current_round"]),
+            self.train_cfg.num_rounds,
+            self.train_cfg.learning_rate_max,
+            self.train_cfg.learning_rate_min,
+        )
+
+        self.training_argumnets.learning_rate = new_lr
+        self.training_argumnets.output_dir = self.save_path
+
+        # Construct trainer
+        trainer = SFTTrainer(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            args=self.training_argumnets,
+            max_seq_length=self.train_cfg.seq_length,
+            train_dataset=self.trainset,
+            formatting_func=self.formatting_prompts_func,
+            data_collator=self.data_collator,
+        )
+
+        # Do local training
+        results = trainer.train()
+
+        return (
+            get_parameters(self.model),
+            len(self.trainset),
+            {"train_loss": results.training_loss},
+        )
+
+
+def set_parameters(model, parameters: NDArrays) -> None:
+    """Change the parameters of the model using the given ones."""
+    peft_state_dict_keys = get_peft_model_state_dict(model).keys()
+    params_dict = zip(peft_state_dict_keys, parameters)
+    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+    set_peft_model_state_dict(model, state_dict)
+
+
+def get_parameters(model) -> NDArrays:
+    """Return the parameters of the current net."""
+    state_dict = get_peft_model_state_dict(model)
+    return [val.cpu().numpy() for _, val in state_dict.items()]
+
+
+def gen_client_fn(
+    fds,
+    tokenizer,
+    formatting_prompts_func,
+    data_collator,
+    model_cfg: DictConfig,
+    train_cfg: DictConfig,
+    save_path: str,
+) -> Callable[[Context], FlowerClient]:  # pylint: disable=too-many-arguments
+    """Generate the client function that creates the Flower Clients."""
+
+    def client_fn(context: Context) -> FlowerClient:
+        """Create a Flower client representing a single organization."""
+        # Let's get the partition corresponding to the i-th client
+        partition_id = context.node_config["partition-id"]
+        client_trainset = fds.load_partition(partition_id, "train")
+        client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
+
+        return FlowerClient(
+            model_cfg,
+            train_cfg,
+            client_trainset,
+            tokenizer,
+            formatting_prompts_func,
+            data_collator,
+            save_path,
+        ).to_client()
+
+    return client_fn
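
The per-round learning rate used in fit comes from cosine_annealing, which lives in models.py.tpl and is not excerpted in this diff. As a hedged illustration only, a standard cosine-annealing schedule matching the call signature above (current round, total rounds, max LR, min LR) could look like this; it is not necessarily the template's exact code.

import math


def cosine_annealing(current_round, total_rounds, lrate_max, lrate_min):
    # Decay from lrate_max at round 0 to lrate_min at the final round.
    cos_inner = math.pi * current_round / total_rounds
    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))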
flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl
@@ -0,0 +1,34 @@
+# Federated Instruction Tuning
+---
+model:
+  name: "mistralai/Mistral-7B-v0.3"
+  quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes
+  gradient_checkpointing: True
+  lora:
+    peft_lora_r: 32
+    peft_lora_alpha: 64
+
+train:
+  num_rounds: null
+  save_every_round: 5
+  learning_rate_max: 5e-5
+  learning_rate_min: 1e-6
+  seq_length: 512
+  training_arguments:
+    output_dir: null # to be set by hydra
+    learning_rate: null # to be set by the client
+    per_device_train_batch_size: 16
+    gradient_accumulation_steps: 1
+    logging_steps: 10
+    num_train_epochs: 3
+    max_steps: 10
+    report_to: null
+    save_steps: 1000
+    save_total_limit: 10
+    gradient_checkpointing: True
+    lr_scheduler_type: "constant"
+
+strategy:
+  _target_: flwr.server.strategy.FedAvg
+  fraction_fit: $fraction_fit
+  fraction_evaluate: 0.0 # no client evaluation
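
For context, the strategy block above is consumed by hydra.utils.instantiate in the generated app.py template (see the app.py.tpl hunk earlier), which resolves the _target_ path to flwr.server.strategy.FedAvg. A minimal, self-contained sketch of that mechanism follows; the literal 0.1 stands in for the $fraction_fit placeholder, and extra keyword arguments (e.g. initial_parameters) can be passed at call time just as the template does.

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "strategy": {
            "_target_": "flwr.server.strategy.FedAvg",
            "fraction_fit": 0.1,       # stands in for $fraction_fit
            "fraction_evaluate": 0.0,  # no client evaluation
        }
    }
)

strategy = instantiate(cfg.strategy)
print(type(strategy).__name__)  # FedAvg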
flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl
@@ -0,0 +1,57 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from transformers import AutoTokenizer
+from trl import DataCollatorForCompletionOnlyLM
+
+
+def formatting_prompts_func(example):
+    """Construct prompts."""
+    output_texts = []
+    # Constructing a standard Alpaca
+    # (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt
+    mssg = (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    )
+    for i in range(len(example["instruction"])):
+        text = (
+            f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n"
+            f"### Response: {example['response'][i]}"
+        )
+        output_texts.append(text)
+    return output_texts
+
+
+def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
+    """Get tokenizer, data_collator and prompt formatting."""
+    # From: https://huggingface.co/docs/trl/en/sft_trainer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name, use_fast=True, padding_side="right"
+    )
+    tokenizer.pad_token = tokenizer.eos_token
+    response_template_with_context = "\n### Response:"  # alpaca response tag
+    response_template_ids = tokenizer.encode(
+        response_template_with_context, add_special_tokens=False
+    )[2:]
+    data_collator = DataCollatorForCompletionOnlyLM(
+        response_template_ids, tokenizer=tokenizer
+    )
+
+    return tokenizer, data_collator, formatting_prompts_func
+
+
+def formatting(dataset):
+    """Format dataset."""
+    dataset["instruction"] = dataset["instruction"] + " " + dataset["input"]
+    return dataset
+
+
+def reformat(dataset, llm_task):
+    """Reformat datasets."""
+    dataset = dataset.rename_column("output", "response")
+    if llm_task == "finance" or llm_task == "code":
+        dataset = dataset.map(formatting, remove_columns=["input"])
+    if llm_task == "medical":
+        dataset = dataset.remove_columns(["instruction"])
+        dataset = dataset.rename_column("input", "instruction")
+    return dataset
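
A small usage sketch of the helpers above, not part of the diff: it assumes the generated package is importable as myflowerapp (a hypothetical name standing in for the $import_name placeholder), and shows the Alpaca-style prompt that formatting_prompts_func builds from a toy batch.

# Illustrative only; "myflowerapp" stands in for the $import_name placeholder.
from myflowerapp.dataset import formatting_prompts_func

example_batch = {
    "instruction": ["Name the currency of Japan."],
    "response": ["The currency of Japan is the yen."],
}
print(formatting_prompts_func(example_batch)[0])
# Below is an instruction that describes a task. Write a response that appropriately completes the request.
# ### Instruction:
# Name the currency of Japan.
# ### Response: The currency of Japan is the yen.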