flwr 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. flwr/cli/app_cmd/review.py +13 -3
  2. flwr/cli/federation/show.py +4 -3
  3. flwr/cli/ls.py +44 -3
  4. flwr/cli/new/new.py +106 -297
  5. flwr/cli/run/run.py +12 -17
  6. flwr/cli/run_utils.py +23 -5
  7. flwr/cli/stop.py +1 -1
  8. flwr/cli/supernode/ls.py +10 -5
  9. flwr/cli/utils.py +0 -137
  10. flwr/client/grpc_adapter_client/connection.py +2 -2
  11. flwr/client/grpc_rere_client/connection.py +6 -3
  12. flwr/client/rest_client/connection.py +6 -4
  13. flwr/common/serde.py +6 -0
  14. flwr/common/typing.py +6 -0
  15. flwr/proto/fleet_pb2.py +10 -10
  16. flwr/proto/fleet_pb2.pyi +5 -1
  17. flwr/proto/run_pb2.py +24 -24
  18. flwr/proto/run_pb2.pyi +10 -1
  19. flwr/server/app.py +1 -0
  20. flwr/server/superlink/fleet/message_handler/message_handler.py +41 -2
  21. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  22. flwr/server/superlink/linkstate/linkstate.py +32 -0
  23. flwr/server/superlink/linkstate/sqlite_linkstate.py +60 -3
  24. flwr/supercore/constant.py +3 -0
  25. flwr/supercore/utils.py +190 -0
  26. flwr/superlink/servicer/control/control_grpc.py +2 -0
  27. flwr/superlink/servicer/control/control_servicer.py +88 -5
  28. flwr/supernode/nodestate/in_memory_nodestate.py +62 -1
  29. flwr/supernode/nodestate/nodestate.py +45 -0
  30. flwr/supernode/servicer/clientappio/clientappio_servicer.py +7 -1
  31. flwr/supernode/start_client_internal.py +7 -4
  32. {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/METADATA +2 -4
  33. {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/RECORD +35 -96
  34. flwr/cli/new/templates/__init__.py +0 -15
  35. flwr/cli/new/templates/app/.gitignore.tpl +0 -163
  36. flwr/cli/new/templates/app/LICENSE.tpl +0 -202
  37. flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
  38. flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
  39. flwr/cli/new/templates/app/README.md.tpl +0 -37
  40. flwr/cli/new/templates/app/__init__.py +0 -15
  41. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
  42. flwr/cli/new/templates/app/code/__init__.py +0 -15
  43. flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
  44. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
  45. flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
  46. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
  47. flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
  48. flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
  49. flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
  50. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
  51. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
  52. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
  53. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
  54. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
  55. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
  56. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +0 -15
  57. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
  58. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
  59. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
  60. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
  61. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
  62. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
  63. flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
  64. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
  65. flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
  66. flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
  67. flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
  68. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
  69. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
  70. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
  71. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
  72. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
  73. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
  74. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
  75. flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
  76. flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
  77. flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
  78. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -99
  79. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
  80. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
  81. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
  82. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
  83. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
  84. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
  85. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
  86. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
  87. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
  88. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
  89. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
  90. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
  91. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
  92. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
  93. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
  94. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
  95. {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
  96. {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
@@ -1,82 +0,0 @@
- """$project_name: A Flower / $framework_str app."""
-
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
- from flwr.clientapp import ClientApp
-
- from $import_name.task import load_data, load_model
-
- # Flower ClientApp
- app = ClientApp()
-
-
- @app.train()
- def train(msg: Message, context: Context):
-     """Train the model on local data."""
-
-     # Load the model and initialize it with the received weights
-     model = load_model()
-     ndarrays = msg.content["arrays"].to_numpy_ndarrays()
-     model.set_weights(ndarrays)
-
-     # Read from config
-     epochs = context.run_config["local-epochs"]
-     batch_size = context.run_config["batch-size"]
-     verbose = context.run_config.get("verbose")
-
-     # Load the data
-     partition_id = context.node_config["partition-id"]
-     num_partitions = context.node_config["num-partitions"]
-     x_train, y_train, _, _ = load_data(partition_id, num_partitions)
-
-     # Train the model on local data
-     history = model.fit(
-         x_train,
-         y_train,
-         epochs=epochs,
-         batch_size=batch_size,
-         verbose=verbose,
-     )
-
-     # Get final training loss and accuracy
-     train_loss = history.history["loss"][-1] if "loss" in history.history else None
-     train_acc = history.history.get("accuracy")
-     train_acc = train_acc[-1] if train_acc is not None else None
-
-     # Construct and return reply Message
-     model_record = ArrayRecord(model.get_weights())
-     metrics = {"num-examples": len(x_train)}
-     if train_loss is not None:
-         metrics["train_loss"] = train_loss
-     if train_acc is not None:
-         metrics["train_acc"] = train_acc
-     metric_record = MetricRecord(metrics)
-     content = RecordDict({"arrays": model_record, "metrics": metric_record})
-     return Message(content=content, reply_to=msg)
-
-
- @app.evaluate()
- def evaluate(msg: Message, context: Context):
-     """Evaluate the model on local data."""
-
-     # Load the model and initialize it with the received weights
-     model = load_model()
-     ndarrays = msg.content["arrays"].to_numpy_ndarrays()
-     model.set_weights(ndarrays)
-
-     # Load the data
-     partition_id = context.node_config["partition-id"]
-     num_partitions = context.node_config["num-partitions"]
-     _, _, x_test, y_test = load_data(partition_id, num_partitions)
-
-     # Evaluate the model on local data
-     loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
-
-     # Construct and return reply Message
-     metrics = {
-         "eval_loss": loss,
-         "eval_acc": accuracy,
-         "num-examples": len(x_test),
-     }
-     metric_record = MetricRecord(metrics)
-     content = RecordDict({"metrics": metric_record})
-     return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/client.xgboost.py.tpl
@@ -1,110 +0,0 @@
- """$project_name: A Flower / $framework_str app."""
-
- import warnings
-
- import numpy as np
- import xgboost as xgb
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
- from flwr.clientapp import ClientApp
- from flwr.common.config import unflatten_dict
-
- from $import_name.task import load_data, replace_keys
-
- warnings.filterwarnings("ignore", category=UserWarning)
-
-
- # Flower ClientApp
- app = ClientApp()
-
-
- def _local_boost(bst_input, num_local_round, train_dmatrix):
-     # Update trees based on local training data.
-     for i in range(num_local_round):
-         bst_input.update(train_dmatrix, bst_input.num_boosted_rounds())
-
-     # Bagging: extract the last N=num_local_round trees for sever aggregation
-     bst = bst_input[
-         bst_input.num_boosted_rounds()
-         - num_local_round : bst_input.num_boosted_rounds()
-     ]
-     return bst
-
-
- @app.train()
- def train(msg: Message, context: Context) -> Message:
-     # Load model and data
-     partition_id = context.node_config["partition-id"]
-     num_partitions = context.node_config["num-partitions"]
-     train_dmatrix, _, num_train, _ = load_data(partition_id, num_partitions)
-
-     # Read from run config
-     num_local_round = context.run_config["local-epochs"]
-     # Flatted config dict and replace "-" with "_"
-     cfg = replace_keys(unflatten_dict(context.run_config))
-     params = cfg["params"]
-
-     global_round = msg.content["config"]["server-round"]
-     if global_round == 1:
-         # First round local training
-         bst = xgb.train(
-             params,
-             train_dmatrix,
-             num_boost_round=num_local_round,
-         )
-     else:
-         bst = xgb.Booster(params=params)
-         global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
-
-         # Load global model into booster
-         bst.load_model(global_model)
-
-         # Local training
-         bst = _local_boost(bst, num_local_round, train_dmatrix)
-
-     # Save model
-     local_model = bst.save_raw("json")
-     model_np = np.frombuffer(local_model, dtype=np.uint8)
-
-     # Construct reply message
-     # Note: we store the model as the first item in a list into ArrayRecord,
-     # which can be accessed using index ["0"].
-     model_record = ArrayRecord([model_np])
-     metrics = {
-         "num-examples": num_train,
-     }
-     metric_record = MetricRecord(metrics)
-     content = RecordDict({"arrays": model_record, "metrics": metric_record})
-     return Message(content=content, reply_to=msg)
-
-
- @app.evaluate()
- def evaluate(msg: Message, context: Context) -> Message:
-     # Load model and data
-     partition_id = context.node_config["partition-id"]
-     num_partitions = context.node_config["num-partitions"]
-     _, valid_dmatrix, _, num_val = load_data(partition_id, num_partitions)
-
-     # Load config
-     cfg = replace_keys(unflatten_dict(context.run_config))
-     params = cfg["params"]
-
-     # Load global model
-     bst = xgb.Booster(params=params)
-     global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
-     bst.load_model(global_model)
-
-     # Run evaluation
-     eval_results = bst.eval_set(
-         evals=[(valid_dmatrix, "valid")],
-         iteration=bst.num_boosted_rounds() - 1,
-     )
-     auc = float(eval_results.split("\t")[1].split(":")[1])
-
-     # Construct and return reply Message
-     metrics = {
-         "auc": auc,
-         "num-examples": num_val,
-     }
-     metric_record = MetricRecord(metrics)
-     content = RecordDict({"metrics": metric_record})
-     return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/dataset.baseline.py.tpl
@@ -1,36 +0,0 @@
- """$project_name: A Flower Baseline."""
-
- from flwr_datasets import FederatedDataset
- from flwr_datasets.partitioner import IidPartitioner
- from torch.utils.data import DataLoader
- from torchvision.transforms import Compose, Normalize, ToTensor
-
- FDS = None  # Cache FederatedDataset
-
-
- def load_data(partition_id: int, num_partitions: int):
-     """Load partition CIFAR10 data."""
-     # Only initialize `FederatedDataset` once
-     global FDS  # pylint: disable=global-statement
-     if FDS is None:
-         partitioner = IidPartitioner(num_partitions=num_partitions)
-         FDS = FederatedDataset(
-             dataset="uoft-cs/cifar10",
-             partitioners={"train": partitioner},
-         )
-     partition = FDS.load_partition(partition_id)
-     # Divide data on each node: 80% train, 20% test
-     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
-     pytorch_transforms = Compose(
-         [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
-     )
-
-     def apply_transforms(batch):
-         """Apply transforms to the partition from FederatedDataset."""
-         batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
-         return batch
-
-     partition_train_test = partition_train_test.with_transform(apply_transforms)
-     trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
-     testloader = DataLoader(partition_train_test["test"], batch_size=32)
-     return trainloader, testloader
flwr/cli/new/templates/app/code/flwr_tune/__init__.py
@@ -1,15 +0,0 @@
- # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Flower CLI `new` command app / code / flwr_tune templates."""
flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl
@@ -1,92 +0,0 @@
- """$project_name: A Flower / FlowerTune app."""
-
- import os
- import warnings
-
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
- from flwr.clientapp import ClientApp
- from flwr.common.config import unflatten_dict
- from omegaconf import DictConfig
- from peft import get_peft_model_state_dict, set_peft_model_state_dict
- from transformers import TrainingArguments
- from trl import SFTTrainer
-
- from $import_name.dataset import (
-     get_tokenizer_and_data_collator_and_propt_formatting,
-     load_data,
-     replace_keys,
- )
- from $import_name.models import cosine_annealing, get_model
-
- # Avoid warnings
- os.environ["TOKENIZERS_PARALLELISM"] = "true"
- os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
- warnings.filterwarnings("ignore", category=UserWarning)
-
-
- # Avoid warnings
- os.environ["TOKENIZERS_PARALLELISM"] = "true"
- os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
- warnings.filterwarnings("ignore", category=UserWarning)
-
-
- # Flower ClientApp
- app = ClientApp()
-
-
- @app.train()
- def train(msg: Message, context: Context):
-     """Train the model on local data."""
-     # Parse config
-     partition_id = context.node_config["partition-id"]
-     num_partitions = context.node_config["num-partitions"]
-     num_rounds = context.run_config["num-server-rounds"]
-     cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
-     training_arguments = TrainingArguments(**cfg.train.training_arguments)
-
-     # Let's get the client partition
-     trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
-     (
-         tokenizer,
-         data_collator,
-         formatting_prompts_func,
-     ) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
-
-     # Load the model and initialize it with the received weights
-     model = get_model(cfg.model)
-     set_peft_model_state_dict(model, msg.content["arrays"].to_torch_state_dict())
-
-     # Set learning rate for current round
-     new_lr = cosine_annealing(
-         msg.content["config"]["server-round"],
-         num_rounds,
-         cfg.train.learning_rate_max,
-         cfg.train.learning_rate_min,
-     )
-
-     training_arguments.learning_rate = new_lr
-     training_arguments.output_dir = msg.content["config"]["save_path"]
-
-     # Construct trainer
-     trainer = SFTTrainer(
-         model=model,
-         tokenizer=tokenizer,
-         args=training_arguments,
-         max_seq_length=cfg.train.seq_length,
-         train_dataset=trainset,
-         formatting_func=formatting_prompts_func,
-         data_collator=data_collator,
-     )
-
-     # Do local training
-     results = trainer.train()
-
-     # Construct and return reply Message
-     model_record = ArrayRecord(get_peft_model_state_dict(model))
-     metrics = {
-         "train_loss": results.training_loss,
-         "num-examples": len(trainset),
-     }
-     metric_record = MetricRecord(metrics)
-     content = RecordDict({"arrays": model_record, "metrics": metric_record})
-     return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl
@@ -1,87 +0,0 @@
- """$project_name: A Flower / FlowerTune app."""
-
- from flwr_datasets import FederatedDataset
- from flwr_datasets.partitioner import IidPartitioner
- from transformers import AutoTokenizer
- from trl import DataCollatorForCompletionOnlyLM
-
- FDS = None  # Cache FederatedDataset
-
-
- def formatting_prompts_func(example):
-     """Construct prompts."""
-     output_texts = []
-     # Constructing a standard Alpaca
-     # (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt
-     mssg = (
-         "Below is an instruction that describes a task. "
-         "Write a response that appropriately completes the request."
-     )
-     for i in range(len(example["instruction"])):
-         text = (
-             f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n"
-             f"### Response: {example['response'][i]}"
-         )
-         output_texts.append(text)
-     return output_texts
-
-
- def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
-     """Get tokenizer, data_collator and prompt formatting."""
-     tokenizer = AutoTokenizer.from_pretrained(
-         model_name, use_fast=True, padding_side="right"
-     )
-     tokenizer.pad_token = tokenizer.eos_token
-     response_template_with_context = "\n### Response:"  # alpaca response tag
-     response_template_ids = tokenizer.encode(
-         response_template_with_context, add_special_tokens=False
-     )[2:]
-     data_collator = DataCollatorForCompletionOnlyLM(
-         response_template_ids, tokenizer=tokenizer
-     )
-
-     return tokenizer, data_collator, formatting_prompts_func
-
-
- def formatting(dataset):
-     """Format dataset."""
-     dataset["instruction"] = dataset["instruction"] + " " + dataset["input"]
-     return dataset
-
-
- def reformat(dataset, llm_task):
-     """Reformat datasets."""
-     dataset = dataset.rename_column("output", "response")
-     if llm_task in ["finance", "code"]:
-         dataset = dataset.map(formatting, remove_columns=["input"])
-     if llm_task == "medical":
-         dataset = dataset.remove_columns(["instruction"])
-         dataset = dataset.rename_column("input", "instruction")
-     return dataset
-
-
- def load_data(partition_id: int, num_partitions: int, dataset_name: str):
-     """Load partition data."""
-     # Only initialize `FederatedDataset` once
-     global FDS
-     if FDS is None:
-         partitioner = IidPartitioner(num_partitions=num_partitions)
-         FDS = FederatedDataset(
-             dataset=dataset_name,
-             partitioners={"train": partitioner},
-         )
-     client_trainset = FDS.load_partition(partition_id, "train")
-     client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
-     return client_trainset
-
-
- def replace_keys(input_dict, match="-", target="_"):
-     """Recursively replace match string with target string in dictionary keys."""
-     new_dict = {}
-     for key, value in input_dict.items():
-         new_key = key.replace(match, target)
-         if isinstance(value, dict):
-             new_dict[new_key] = replace_keys(value, match, target)
-         else:
-             new_dict[new_key] = value
-     return new_dict
flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl
@@ -1,56 +0,0 @@
- """$project_name: A Flower / FlowerTune app."""
-
- import math
-
- import torch
- from omegaconf import DictConfig
- from peft import LoraConfig, get_peft_model
- from peft.utils import prepare_model_for_kbit_training
- from transformers import AutoModelForCausalLM, BitsAndBytesConfig
-
-
- def cosine_annealing(
-     current_round: int,
-     total_round: int,
-     lrate_max: float = 0.001,
-     lrate_min: float = 0.0,
- ) -> float:
-     """Implement cosine annealing learning rate schedule."""
-     cos_inner = math.pi * current_round / total_round
-     return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
-
-
- def get_model(model_cfg: DictConfig):
-     """Load model with appropriate quantization config and other optimizations.
-     """
-     if model_cfg.quantization == 4:
-         quantization_config = BitsAndBytesConfig(load_in_4bit=True)
-     elif model_cfg.quantization == 8:
-         quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-     else:
-         raise ValueError(
-             f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
-         )
-
-     model = AutoModelForCausalLM.from_pretrained(
-         model_cfg.name,
-         quantization_config=quantization_config,
-         torch_dtype=torch.bfloat16,
-         low_cpu_mem_usage=True,
-     )
-
-     model = prepare_model_for_kbit_training(
-         model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
-     )
-
-     peft_config = LoraConfig(
-         r=model_cfg.lora.peft_lora_r,
-         lora_alpha=model_cfg.lora.peft_lora_alpha,
-         lora_dropout=0.075,
-         task_type="CAUSAL_LM",
-     )
-
-     if model_cfg.gradient_checkpointing:
-         model.config.use_cache = False
-
-     return get_peft_model(model, peft_config)
flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl
@@ -1,73 +0,0 @@
- """$project_name: A Flower / FlowerTune app."""
-
- import os
- from datetime import datetime
-
- from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
- from flwr.common.config import unflatten_dict
- from flwr.serverapp import Grid, ServerApp
- from omegaconf import DictConfig
- from peft import get_peft_model_state_dict, set_peft_model_state_dict
-
- from $import_name.dataset import replace_keys
- from $import_name.models import get_model
- from $import_name.strategy import FlowerTuneLlm
-
- # Create ServerApp
- app = ServerApp()
-
-
- @app.main()
- def main(grid: Grid, context: Context) -> None:
-     """Main entry point for the ServerApp."""
-     # Create output directory given current timestamp
-     current_time = datetime.now()
-     folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
-     save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
-     os.makedirs(save_path, exist_ok=True)
-
-     # Read from config
-     num_rounds = context.run_config["num-server-rounds"]
-     cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
-
-     # Get initial model weights
-     init_model = get_model(cfg.model)
-     arrays = ArrayRecord(get_peft_model_state_dict(init_model))
-
-     # Define strategy
-     strategy = FlowerTuneLlm(
-         fraction_train=cfg.strategy.fraction_train,
-         fraction_evaluate=cfg.strategy.fraction_evaluate,
-     )
-
-     # Start strategy, run FedAvg for `num_rounds`
-     strategy.start(
-         grid=grid,
-         initial_arrays=arrays,
-         train_config=ConfigRecord({"save_path": save_path}),
-         num_rounds=num_rounds,
-         evaluate_fn=get_evaluate_fn(
-             cfg.model, cfg.train.save_every_round, num_rounds, save_path
-         ),
-     )
-
-
- # Get function that will be executed by the strategy
- # Here we use it to save global model checkpoints
- def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
-     """Return an evaluation function for saving global model."""
-
-     def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
-         # Save model
-         if server_round != 0 and (
-             server_round == total_round or server_round % save_every_round == 0
-         ):
-             # Init model
-             model = get_model(model_cfg)
-             set_peft_model_state_dict(model, arrays.to_torch_state_dict())
-
-             model.save_pretrained(f"{save_path}/peft_{server_round}")
-
-         return MetricRecord()
-
-     return evaluate
flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl
@@ -1,78 +0,0 @@
- """$project_name: A Flower / FlowerTune app."""
-
- from collections.abc import Iterable
- from logging import INFO, WARN
- from typing import Optional
-
- from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
- from flwr.common import log
- from flwr.serverapp import Grid
- from flwr.serverapp.strategy import FedAvg
-
-
- class FlowerTuneLlm(FedAvg):
-     """Customised FedAvg strategy implementation.
-
-     This class behaves just like FedAvg but also tracks the communication
-     costs associated with `train` over FL rounds.
-     """
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-         self.comm_tracker = CommunicationTracker()
-
-     def configure_train(
-         self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
-     ) -> Iterable[Message]:
-         """Configure the next round of training."""
-         messages = super().configure_train(server_round, arrays, config, grid)
-
-         # Track communication costs
-         self.comm_tracker.track(messages)
-
-         return messages
-
-     def aggregate_train(
-         self,
-         server_round: int,
-         replies: Iterable[Message],
-     ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
-         """Aggregate ArrayRecords and MetricRecords in the received Messages."""
-         # Track communication costs
-         self.comm_tracker.track(replies)
-
-         arrays, metrics = super().aggregate_train(server_round, replies)
-
-         return arrays, metrics
-
-
- class CommunicationTracker:
-     """Communication costs tracker over FL rounds."""
-     def __init__(self):
-         self.curr_comm_cost = 0.0
-
-     def track(self, messages: Iterable[Message]):
-         comm_cost = (
-             sum(
-                 record.count_bytes()
-                 for msg in messages
-                 if msg.has_content()
-                 for record in msg.content.array_records.values()
-             )
-             / 1024**2
-         )
-
-         self.curr_comm_cost += comm_cost
-         log(
-             INFO,
-             "Communication budget: used %.2f MB (+%.2f MB this round) / 200,000 MB",
-             self.curr_comm_cost,
-             comm_cost,
-         )
-
-         if self.curr_comm_cost > 2e5:
-             log(
-                 WARN,
-                 "The accumulated communication cost has exceeded 200,000 MB. "
-                 "Please consider reducing it if you plan to participate "
-                 "FlowerTune LLM Leaderboard.",
-             )