flwr-nightly 1.10.0.dev20240612__py3-none-any.whl → 1.10.0.dev20240624__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (130) hide show
  1. flwr/cli/app.py +3 -0
  2. flwr/cli/build.py +6 -8
  3. flwr/cli/config_utils.py +53 -3
  4. flwr/cli/install.py +35 -20
  5. flwr/cli/new/new.py +104 -28
  6. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  7. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  8. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
  9. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
  10. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  11. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  12. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  14. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
  16. flwr/cli/run/run.py +46 -2
  17. flwr/client/__init__.py +1 -1
  18. flwr/client/app.py +22 -10
  19. flwr/client/client_app.py +1 -1
  20. flwr/client/dpfedavg_numpy_client.py +1 -1
  21. flwr/client/grpc_adapter_client/__init__.py +15 -0
  22. flwr/client/grpc_adapter_client/connection.py +94 -0
  23. flwr/client/grpc_client/connection.py +5 -1
  24. flwr/client/grpc_rere_client/__init__.py +1 -1
  25. flwr/client/grpc_rere_client/connection.py +9 -2
  26. flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
  27. flwr/client/message_handler/__init__.py +1 -1
  28. flwr/client/message_handler/message_handler.py +1 -1
  29. flwr/client/mod/__init__.py +4 -4
  30. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  31. flwr/client/mod/utils.py +1 -1
  32. flwr/client/rest_client/__init__.py +1 -1
  33. flwr/client/rest_client/connection.py +10 -2
  34. flwr/client/supernode/app.py +141 -41
  35. flwr/common/__init__.py +12 -12
  36. flwr/common/address.py +1 -1
  37. flwr/common/config.py +73 -0
  38. flwr/common/constant.py +16 -1
  39. flwr/common/date.py +1 -1
  40. flwr/common/dp.py +1 -1
  41. flwr/common/grpc.py +1 -1
  42. flwr/common/object_ref.py +39 -5
  43. flwr/common/record/__init__.py +1 -1
  44. flwr/common/secure_aggregation/__init__.py +1 -1
  45. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  46. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  47. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  48. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  49. flwr/common/secure_aggregation/quantization.py +1 -1
  50. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  51. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  52. flwr/common/telemetry.py +4 -0
  53. flwr/common/typing.py +9 -0
  54. flwr/common/version.py +14 -0
  55. flwr/proto/exec_pb2.py +34 -0
  56. flwr/proto/exec_pb2.pyi +55 -0
  57. flwr/proto/exec_pb2_grpc.py +101 -0
  58. flwr/proto/exec_pb2_grpc.pyi +41 -0
  59. flwr/proto/fab_pb2.py +30 -0
  60. flwr/proto/fab_pb2.pyi +56 -0
  61. flwr/proto/fab_pb2_grpc.py +4 -0
  62. flwr/proto/fab_pb2_grpc.pyi +4 -0
  63. flwr/server/__init__.py +2 -2
  64. flwr/server/app.py +62 -25
  65. flwr/server/compat/app.py +1 -1
  66. flwr/server/compat/app_utils.py +1 -1
  67. flwr/server/compat/driver_client_proxy.py +1 -1
  68. flwr/server/driver/driver.py +6 -0
  69. flwr/server/driver/grpc_driver.py +85 -63
  70. flwr/server/driver/inmemory_driver.py +28 -26
  71. flwr/server/run_serverapp.py +65 -20
  72. flwr/server/strategy/__init__.py +2 -2
  73. flwr/server/strategy/bulyan.py +1 -1
  74. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  75. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  76. flwr/server/strategy/fedadagrad.py +1 -1
  77. flwr/server/strategy/fedadam.py +1 -1
  78. flwr/server/strategy/fedavg_android.py +1 -1
  79. flwr/server/strategy/fedavgm.py +1 -1
  80. flwr/server/strategy/fedmedian.py +1 -1
  81. flwr/server/strategy/fedopt.py +1 -1
  82. flwr/server/strategy/fedprox.py +1 -1
  83. flwr/server/strategy/fedxgb_bagging.py +1 -1
  84. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  85. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  86. flwr/server/strategy/fedyogi.py +1 -1
  87. flwr/server/strategy/krum.py +1 -1
  88. flwr/server/strategy/qfedavg.py +1 -1
  89. flwr/server/superlink/driver/__init__.py +1 -1
  90. flwr/server/superlink/driver/driver_grpc.py +1 -1
  91. flwr/server/superlink/driver/driver_servicer.py +15 -3
  92. flwr/server/superlink/fleet/__init__.py +1 -1
  93. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  94. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  95. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  96. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  97. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  98. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  99. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -1
  100. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  101. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  102. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  103. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -4
  104. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  106. flwr/server/superlink/fleet/vce/backend/raybackend.py +44 -25
  107. flwr/server/superlink/fleet/vce/vce_api.py +3 -1
  108. flwr/server/superlink/state/__init__.py +1 -1
  109. flwr/server/superlink/state/in_memory_state.py +9 -6
  110. flwr/server/superlink/state/sqlite_state.py +7 -4
  111. flwr/server/superlink/state/state.py +6 -5
  112. flwr/server/superlink/state/state_factory.py +11 -2
  113. flwr/server/utils/__init__.py +1 -1
  114. flwr/server/utils/tensorboard.py +1 -1
  115. flwr/simulation/__init__.py +5 -2
  116. flwr/simulation/app.py +1 -1
  117. flwr/simulation/ray_transport/__init__.py +1 -1
  118. flwr/simulation/ray_transport/ray_actor.py +0 -6
  119. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  120. flwr/simulation/run_simulation.py +63 -22
  121. flwr/superexec/__init__.py +21 -0
  122. flwr/superexec/app.py +178 -0
  123. flwr/superexec/exec_grpc.py +51 -0
  124. flwr/superexec/exec_servicer.py +65 -0
  125. flwr/superexec/executor.py +54 -0
  126. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/METADATA +2 -1
  127. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/RECORD +130 -101
  128. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/entry_points.txt +1 -0
  129. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/LICENSE +0 -0
  130. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/WHEEL +0 -0
@@ -0,0 +1,124 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ from collections import OrderedDict
4
+ from typing import Callable, Dict, Tuple
5
+
6
+ import torch
7
+ from omegaconf import DictConfig
8
+ from peft import get_peft_model_state_dict, set_peft_model_state_dict
9
+ from transformers import TrainingArguments
10
+ from trl import SFTTrainer
11
+
12
+ from flwr.client import NumPyClient
13
+ from flwr.common.typing import NDArrays, Scalar
14
+ from $import_name.dataset import reformat
15
+ from $import_name.models import cosine_annealing, get_model
16
+
17
+
18
+ # pylint: disable=too-many-arguments
19
+ # pylint: disable=too-many-instance-attributes
20
class FlowerClient(NumPyClient):
    """Flower client that fine-tunes a language model with TRL's ``SFTTrainer``.

    The model is built once per client by ``get_model(model_cfg)``; only its
    PEFT state dict is exchanged with the server (see ``set_parameters`` and
    ``get_parameters``).
    """

    def __init__(
        self,
        model_cfg: DictConfig,
        train_cfg: DictConfig,
        trainset,
        tokenizer,
        formatting_prompts_func,
        data_collator,
        save_path,
    ):  # pylint: disable=too-many-arguments
        """Store the training resources and instantiate the model.

        Parameters
        ----------
        model_cfg : DictConfig
            Model configuration forwarded to ``get_model``.
        train_cfg : DictConfig
            Training configuration. This class reads ``training_arguments``,
            ``num_rounds``, ``learning_rate_max``, ``learning_rate_min`` and
            ``seq_length`` from it.
        trainset
            Local training dataset; its ``len()`` is reported to the server.
        tokenizer, formatting_prompts_func, data_collator
            Passed unchanged to ``SFTTrainer``.
        save_path
            Used as the trainer's ``output_dir``.
        """
        # Stored for reference only; no method of this class reads it.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.train_cfg = train_cfg
        self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
        # Backward-compatible alias for the original (misspelled) attribute name.
        # Both names refer to the same ``TrainingArguments`` object.
        self.training_argumnets = self.training_arguments
        self.tokenizer = tokenizer
        self.formatting_prompts_func = formatting_prompts_func
        self.data_collator = data_collator
        self.save_path = save_path

        # instantiate model
        self.model = get_model(model_cfg)

        self.trainset = trainset

    def fit(
        self, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Tuple[NDArrays, int, Dict]:
        """Run one round of local training starting from ``parameters``.

        Parameters
        ----------
        parameters : NDArrays
            PEFT weights received from the server; loaded into the local model.
        config : Dict[str, Scalar]
            Must contain ``"current_round"``, which selects the learning rate
            on the cosine-annealing schedule.

        Returns
        -------
        Tuple[NDArrays, int, Dict]
            The updated PEFT weights, the number of local training examples
            and ``{"train_loss": <training loss reported by the trainer>}``.
        """
        set_parameters(self.model, parameters)

        # The learning rate is recomputed every round from the server's round
        # counter, so it decays across rounds rather than within a round.
        new_lr = cosine_annealing(
            int(config["current_round"]),
            self.train_cfg.num_rounds,
            self.train_cfg.learning_rate_max,
            self.train_cfg.learning_rate_min,
        )

        self.training_arguments.learning_rate = new_lr
        self.training_arguments.output_dir = self.save_path

        # Construct trainer
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            args=self.training_arguments,
            max_seq_length=self.train_cfg.seq_length,
            train_dataset=self.trainset,
            formatting_func=self.formatting_prompts_func,
            data_collator=self.data_collator,
        )

        # Do local training
        results = trainer.train()

        return (
            get_parameters(self.model),
            len(self.trainset),
            {"train_loss": results.training_loss},
        )
81
+
82
+
83
def set_parameters(model, parameters: NDArrays) -> None:
    """Load ``parameters`` into the PEFT weights of ``model``.

    The arrays are paired positionally with the keys of the model's PEFT state
    dict, in that dict's iteration order, and written back as torch tensors.
    """
    keys = get_peft_model_state_dict(model).keys()
    new_state = OrderedDict()
    for key, array in zip(keys, parameters):
        new_state[key] = torch.Tensor(array)
    set_peft_model_state_dict(model, new_state)
89
+
90
+
91
def get_parameters(model) -> NDArrays:
    """Return the PEFT weights of ``model`` as a list of NumPy arrays.

    The list follows the iteration order of the model's PEFT state dict; each
    tensor is moved to the CPU before conversion.
    """
    peft_weights = get_peft_model_state_dict(model)
    return [tensor.cpu().numpy() for tensor in peft_weights.values()]
95
+
96
+
97
def gen_client_fn(
    fds,
    tokenizer,
    formatting_prompts_func,
    data_collator,
    model_cfg: DictConfig,
    train_cfg: DictConfig,
    save_path: str,
) -> Callable[[str], FlowerClient]:  # pylint: disable=too-many-arguments
    """Return a ``client_fn`` bound to the shared dataset, tokenizer and configs.

    NOTE(review): the annotations name ``FlowerClient``, but the value handed
    back by ``client_fn`` is the result of ``FlowerClient(...).to_client()``.
    """

    def client_fn(cid: str) -> FlowerClient:
        """Build the client whose data is the "train" partition ``int(cid)``."""
        # Load this client's partition and bring its columns into the
        # instruction/response layout expected for training.
        partition = fds.load_partition(int(cid), "train")
        partition = reformat(partition, llm_task="$llm_challenge_str")

        numpy_client = FlowerClient(
            model_cfg=model_cfg,
            train_cfg=train_cfg,
            trainset=partition,
            tokenizer=tokenizer,
            formatting_prompts_func=formatting_prompts_func,
            data_collator=data_collator,
            save_path=save_path,
        )
        return numpy_client.to_client()

    return client_fn
@@ -0,0 +1,34 @@
1
+ # Federated Instruction Tuning
2
+ ---
3
+ model:
4
+ name: "mistralai/Mistral-7B-v0.3"
5
+ quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes
6
+ gradient_checkpointing: True
7
+ lora:
8
+ peft_lora_r: 32
9
+ peft_lora_alpha: 64
10
+
11
+ train:
12
+ num_rounds: null
13
+ save_every_round: 5
14
+ learning_rate_max: 5e-5
15
+ learning_rate_min: 1e-6
16
+ seq_length: 512
17
+ training_arguments:
18
+ output_dir: null # to be set by hydra
19
+ learning_rate: null # to be set by the client
20
+ per_device_train_batch_size: 16
21
+ gradient_accumulation_steps: 1
22
+ logging_steps: 10
23
+ num_train_epochs: 3
24
+ max_steps: 10
25
+ report_to: null
26
+ save_steps: 1000
27
+ save_total_limit: 10
28
+ gradient_checkpointing: True
29
+ lr_scheduler_type: "constant"
30
+
31
+ strategy:
32
+ _target_: flwr.server.strategy.FedAvg
33
+ fraction_fit: $fraction_fit
34
+ fraction_evaluate: 0.0 # no client evaluation
@@ -0,0 +1,57 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ from transformers import AutoTokenizer
4
+ from trl import DataCollatorForCompletionOnlyLM
5
+
6
+
7
def formatting_prompts_func(example):
    """Build one Alpaca-style prompt per instruction/response pair.

    Parameters
    ----------
    example
        Mapping with ``"instruction"`` and ``"response"`` entries, each an
        indexable sequence; items at the same index belong together.

    Returns
    -------
    list of str
        One prompt per instruction, in input order.
    """
    # Standard Alpaca preamble
    # (https://github.com/tatsu-lab/stanford_alpaca#data-release)
    mssg = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )
    return [
        f"{mssg}\n### Instruction:\n{instruction}\n"
        f"### Response: {example['response'][i]}"
        for i, instruction in enumerate(example["instruction"])
    ]
23
+
24
+
25
def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
    """Create the tokenizer and completion-only data collator for ``model_name``.

    Returns the tuple ``(tokenizer, data_collator, formatting_prompts_func)``.
    """
    # Setup follows https://huggingface.co/docs/trl/en/sft_trainer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=True,
        padding_side="right",
    )
    # The end-of-sequence token doubles as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    response_tag = "\n### Response:"  # alpaca response tag
    tag_ids = tokenizer.encode(response_tag, add_special_tokens=False)
    # Only the ids after the first two are used as the response template.
    # NOTE(review): presumably this skips the tokens produced by the leading
    # newline context, as in the TRL docs example; confirm for the tokenizer
    # actually in use.
    collator = DataCollatorForCompletionOnlyLM(tag_ids[2:], tokenizer=tokenizer)

    return tokenizer, collator, formatting_prompts_func
41
+
42
+
43
def formatting(dataset):
    """Append the ``"input"`` field to ``"instruction"``, separated by a space.

    The mapping is modified in place and the same object is returned.
    """
    instruction = dataset["instruction"]
    extra_input = dataset["input"]
    dataset["instruction"] = instruction + " " + extra_input
    return dataset
47
+
48
+
49
def reformat(dataset, llm_task):
    """Rename and merge dataset columns into an instruction/response layout.

    The ``"output"`` column is always renamed to ``"response"``. For the
    ``"finance"`` and ``"code"`` tasks, ``"input"`` is merged into
    ``"instruction"`` via ``formatting`` and then dropped. For ``"medical"``,
    the existing ``"instruction"`` column is removed and ``"input"`` takes its
    place. Any other task only gets the rename.
    """
    renamed = dataset.rename_column("output", "response")
    if llm_task in ("finance", "code"):
        return renamed.map(formatting, remove_columns=["input"])
    if llm_task == "medical":
        without_instruction = renamed.remove_columns(["instruction"])
        return without_instruction.rename_column("input", "instruction")
    return renamed
@@ -0,0 +1,59 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ import math
4
+
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from peft import LoraConfig, get_peft_model
8
+ from peft.utils import prepare_model_for_kbit_training
9
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
10
+
11
+
12
def cosine_annealing(
    current_round: int,
    total_round: int,
    lrate_max: float = 0.001,
    lrate_min: float = 0.0,
) -> float:
    """Return the cosine-annealed learning rate for ``current_round``.

    The rate equals ``lrate_max`` at round 0 and reaches ``lrate_min`` when
    ``current_round == total_round``.
    """
    phase = math.pi * current_round / total_round
    half_range = 0.5 * (lrate_max - lrate_min)
    return lrate_min + half_range * (1 + math.cos(phase))
21
+
22
+
23
def get_model(model_cfg: DictConfig):
    """Load a quantized causal LM and wrap it with LoRA adapters.

    Please refer to this example for `peft + BitsAndBytes`:
    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py

    Parameters
    ----------
    model_cfg : DictConfig
        Reads ``name``, ``quantization`` (4 or 8), ``gradient_checkpointing``,
        ``lora.peft_lora_r`` and ``lora.peft_lora_alpha``.

    Returns
    -------
    The model returned by ``get_peft_model``.

    Raises
    ------
    ValueError
        If ``model_cfg.quantization`` is neither 4 nor 8.
    """
    if model_cfg.quantization == 4:
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    elif model_cfg.quantization == 8:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        raise ValueError(
            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}"
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_cfg.name,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )

    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
    )

    peft_config = LoraConfig(
        r=model_cfg.lora.peft_lora_r,
        lora_alpha=model_cfg.lora.peft_lora_alpha,
        lora_dropout=0.075,
        task_type="CAUSAL_LM",
    )

    # The generation cache is switched off whenever gradient checkpointing is
    # requested.
    if model_cfg.gradient_checkpointing:
        model.config.use_cache = False

    return get_peft_model(model, peft_config)
@@ -0,0 +1,48 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ from $import_name.client import set_parameters
4
+ from $import_name.models import get_model
5
+
6
+
7
+ # Get function that will be executed by the strategy's evaluate() method
8
+ # Here we use it to save global model checkpoints
9
def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
    """Build the server-side evaluate callback used to checkpoint the model.

    The returned function performs no evaluation: it always returns
    ``(0.0, {})``. Its only effect is to save the global PEFT weights under
    ``{save_path}/peft_{server_round}`` on the last round and on every
    ``save_every_round``-th round (never on round 0).
    """

    def evaluate(server_round: int, parameters, config):
        """Save a checkpoint when this round is due, then return ``(0.0, {})``."""
        if server_round == 0:
            return 0.0, {}

        if server_round == total_round or server_round % save_every_round == 0:
            # Rebuild a fresh model and load the aggregated weights into it.
            checkpoint_model = get_model(model_cfg)
            set_parameters(checkpoint_model, parameters)
            checkpoint_model.save_pretrained(f"{save_path}/peft_{server_round}")

        return 0.0, {}

    return evaluate
26
+
27
+
28
def get_on_fit_config():
    """Return the function that builds the config sent to each client's fit().

    The generated config carries a single entry, ``"current_round"``, set to
    the server round number.
    """

    def fit_config_fn(server_round: int):
        """Return the fit config for ``server_round``."""
        return {"current_round": server_round}

    return fit_config_fn
39
+
40
+
41
def fit_weighted_average(metrics):
    """Aggregate the clients' training losses into an example-weighted mean.

    Parameters
    ----------
    metrics
        Iterable of ``(num_examples, metrics_dict)`` pairs, where each
        ``metrics_dict`` contains a ``"train_loss"`` entry.

    Returns
    -------
    dict
        ``{"train_loss": weighted_mean}``, or an empty dict when the total
        number of examples is zero (nothing to average).
    """
    # Weight each client's training loss by the number of examples it used
    losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    total_examples = sum(examples)
    if total_examples == 0:
        # No results, or only empty partitions: avoid a ZeroDivisionError.
        return {}

    # Aggregate and return custom metric (weighted average)
    return {"train_loss": sum(losses) / total_examples}
@@ -0,0 +1,11 @@
1
+ # Federated Instruction Tuning (static)
2
+ ---
3
+ dataset:
4
+ name: $dataset_name
5
+
6
+ # FL experimental settings
7
+ num_clients: $num_clients # total number of clients
8
+ num_rounds: 200
9
+ partitioner:
10
+ _target_: flwr_datasets.partitioner.IidPartitioner
11
+ num_partitions: $num_clients
@@ -0,0 +1,42 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "$package_name"
7
+ version = "1.0.0"
8
+ description = ""
9
+ authors = [
10
+ { name = "The Flower Authors", email = "hello@flower.ai" },
11
+ ]
12
+ license = { text = "Apache License (2.0)" }
13
+ dependencies = [
14
+ "flwr[simulation]>=1.9.0,<2.0",
15
+ "flwr-datasets>=0.1.0,<1.0.0",
16
+ "hydra-core==1.3.2",
17
+ "trl==0.8.1",
18
+ "bitsandbytes==0.43.0",
19
+ "scipy==1.13.0",
20
+ "peft==0.6.2",
21
+ "transformers==4.39.3",
22
+ "sentencepiece==0.2.0",
23
+ ]
24
+
25
+ [tool.hatch.build.targets.wheel]
26
+ packages = ["."]
27
+
28
+ [flower]
29
+ publisher = "$username"
30
+
31
+ [flower.components]
32
+ serverapp = "$import_name.app:server"
33
+ clientapp = "$import_name.app:client"
34
+
35
+ [flower.engine]
36
+ name = "simulation"
37
+
38
+ [flower.engine.simulation.supernode]
39
+ num = $num_clients
40
+
41
+ [flower.engine.simulation]
42
+ backend_config = { client_resources = { num_cpus = 8, num_gpus = 1.0 } }
flwr/cli/run/run.py CHANGED
@@ -16,12 +16,18 @@
16
16
 
17
17
  import sys
18
18
  from enum import Enum
19
+ from logging import DEBUG
19
20
  from typing import Optional
20
21
 
21
22
  import typer
22
23
  from typing_extensions import Annotated
23
24
 
24
25
  from flwr.cli import config_utils
26
+ from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
27
+ from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
28
+ from flwr.common.logger import log
29
+ from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
30
+ from flwr.proto.exec_pb2_grpc import ExecStub
25
31
  from flwr.simulation.run_simulation import _run_simulation
26
32
 
27
33
 
@@ -31,20 +37,35 @@ class Engine(str, Enum):
31
37
  SIMULATION = "simulation"
32
38
 
33
39
 
40
+ # pylint: disable-next=too-many-locals
34
41
  def run(
35
42
  engine: Annotated[
36
43
  Optional[Engine],
37
- typer.Option(case_sensitive=False, help="The ML framework to use"),
44
+ typer.Option(
45
+ case_sensitive=False,
46
+ help="The engine to run FL with (currently only simulation is supported).",
47
+ ),
38
48
  ] = None,
49
+ use_superexec: Annotated[
50
+ bool,
51
+ typer.Option(
52
+ case_sensitive=False, help="Use this flag to use the new SuperExec API"
53
+ ),
54
+ ] = False,
39
55
  ) -> None:
40
56
  """Run Flower project."""
57
+ if use_superexec:
58
+ _start_superexec_run()
59
+ return
60
+
41
61
  typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
42
62
 
43
63
  config, errors, warnings = config_utils.load_and_validate()
44
64
 
45
65
  if config is None:
46
66
  typer.secho(
47
- "Project configuration could not be loaded.\npyproject.toml is invalid:\n"
67
+ "Project configuration could not be loaded.\n"
68
+ "pyproject.toml is invalid:\n"
48
69
  + "\n".join([f"- {line}" for line in errors]),
49
70
  fg=typer.colors.RED,
50
71
  bold=True,
@@ -69,12 +90,16 @@ def run(
69
90
 
70
91
  if engine == Engine.SIMULATION:
71
92
  num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
93
+ backend_config = config["flower"]["engine"]["simulation"].get(
94
+ "backend_config", None
95
+ )
72
96
 
73
97
  typer.secho("Starting run... ", fg=typer.colors.BLUE)
74
98
  _run_simulation(
75
99
  server_app_attr=server_app_ref,
76
100
  client_app_attr=client_app_ref,
77
101
  num_supernodes=num_supernodes,
102
+ backend_config=backend_config,
78
103
  )
79
104
  else:
80
105
  typer.secho(
@@ -82,3 +107,22 @@ def run(
82
107
  fg=typer.colors.RED,
83
108
  bold=True,
84
109
  )
110
+
111
+
112
def _start_superexec_run() -> None:
    """Send an empty ``StartRunRequest`` to the SuperExec at its default address.

    The channel is opened without TLS (``insecure=True``) and the RPC response
    is discarded. The channel is closed before returning, including when the
    RPC raises.
    """

    def on_channel_state_change(channel_connectivity: str) -> None:
        """Log channel connectivity."""
        log(DEBUG, channel_connectivity)

    channel = create_channel(
        server_address=SUPEREXEC_DEFAULT_ADDRESS,
        insecure=True,
        root_certificates=None,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        interceptors=None,
    )
    try:
        channel.subscribe(on_channel_state_change)
        stub = ExecStub(channel)

        req = StartRunRequest()
        stub.StartRun(req)
    finally:
        # Release the gRPC channel's resources even if the call fails.
        channel.close()
flwr/client/__init__.py CHANGED
@@ -28,8 +28,8 @@ __all__ = [
28
28
  "Client",
29
29
  "ClientApp",
30
30
  "ClientFn",
31
- "mod",
32
31
  "NumPyClient",
32
+ "mod",
33
33
  "run_client_app",
34
34
  "run_supernode",
35
35
  "start_client",
flwr/client/app.py CHANGED
@@ -19,7 +19,7 @@ import sys
19
19
  import time
20
20
  from dataclasses import dataclass
21
21
  from logging import DEBUG, ERROR, INFO, WARN
22
- from typing import Callable, ContextManager, Optional, Tuple, Type, Union
22
+ from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
23
23
 
24
24
  from cryptography.hazmat.primitives.asymmetric import ec
25
25
  from grpc import RpcError
@@ -31,6 +31,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
31
31
  from flwr.common.address import parse_address
32
32
  from flwr.common.constant import (
33
33
  MISSING_EXTRA_REST,
34
+ TRANSPORT_TYPE_GRPC_ADAPTER,
34
35
  TRANSPORT_TYPE_GRPC_BIDI,
35
36
  TRANSPORT_TYPE_GRPC_RERE,
36
37
  TRANSPORT_TYPE_REST,
@@ -41,6 +42,7 @@ from flwr.common.logger import log, warn_deprecated_feature
41
42
  from flwr.common.message import Error
42
43
  from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
43
44
 
45
+ from .grpc_adapter_client.connection import grpc_adapter
44
46
  from .grpc_client.connection import grpc_connection
45
47
  from .grpc_rere_client.connection import grpc_request_response
46
48
  from .message_handler.message_handler import handle_control_message
@@ -177,7 +179,7 @@ def start_client(
177
179
  def _start_client_internal(
178
180
  *,
179
181
  server_address: str,
180
- load_client_app_fn: Optional[Callable[[], ClientApp]] = None,
182
+ load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
181
183
  client_fn: Optional[ClientFn] = None,
182
184
  client: Optional[Client] = None,
183
185
  grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
@@ -252,7 +254,7 @@ def _start_client_internal(
252
254
 
253
255
  client_fn = single_client_factory
254
256
 
255
- def _load_client_app() -> ClientApp:
257
+ def _load_client_app(_1: str, _2: str) -> ClientApp:
256
258
  return ClientApp(client_fn=client_fn)
257
259
 
258
260
  load_client_app_fn = _load_client_app
@@ -308,6 +310,8 @@ def _start_client_internal(
308
310
  )
309
311
 
310
312
  node_state = NodeState()
313
+ # run_id -> (fab_id, fab_version)
314
+ run_info: Dict[int, Tuple[str, str]] = {}
311
315
 
312
316
  while not app_state_tracker.interrupt:
313
317
  sleep_duration: int = 0
@@ -319,7 +323,6 @@ def _start_client_internal(
319
323
  root_certificates,
320
324
  authentication_keys,
321
325
  ) as conn:
322
- # pylint: disable-next=W0612
323
326
  receive, send, create_node, delete_node, get_run = conn
324
327
 
325
328
  # Register node
@@ -356,13 +359,20 @@ def _start_client_internal(
356
359
  send(out_message)
357
360
  break
358
361
 
362
+ # Get run info
363
+ run_id = message.metadata.run_id
364
+ if run_id not in run_info:
365
+ if get_run is not None:
366
+ run_info[run_id] = get_run(run_id)
367
+ # If get_run is None, i.e., in grpc-bidi mode
368
+ else:
369
+ run_info[run_id] = ("", "")
370
+
359
371
  # Register context for this run
360
- node_state.register_context(run_id=message.metadata.run_id)
372
+ node_state.register_context(run_id=run_id)
361
373
 
362
374
  # Retrieve context for this run
363
- context = node_state.retrieve_context(
364
- run_id=message.metadata.run_id
365
- )
375
+ context = node_state.retrieve_context(run_id=run_id)
366
376
 
367
377
  # Create an error reply message that will never be used to prevent
368
378
  # the used-before-assignment linting error
@@ -373,7 +383,7 @@ def _start_client_internal(
373
383
  # Handle app loading and task message
374
384
  try:
375
385
  # Load ClientApp instance
376
- client_app: ClientApp = load_client_app_fn()
386
+ client_app: ClientApp = load_client_app_fn(*run_info[run_id])
377
387
 
378
388
  # Execute ClientApp
379
389
  reply_message = client_app(message=message, context=context)
@@ -411,7 +421,7 @@ def _start_client_internal(
411
421
  else:
412
422
  # No exception, update node state
413
423
  node_state.update_context(
414
- run_id=message.metadata.run_id,
424
+ run_id=run_id,
415
425
  context=context,
416
426
  )
417
427
 
@@ -592,6 +602,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
592
602
  connection, error_type = http_request_response, RequestsConnectionError
593
603
  elif transport == TRANSPORT_TYPE_GRPC_RERE:
594
604
  connection, error_type = grpc_request_response, RpcError
605
+ elif transport == TRANSPORT_TYPE_GRPC_ADAPTER:
606
+ connection, error_type = grpc_adapter, RpcError
595
607
  elif transport == TRANSPORT_TYPE_GRPC_BIDI:
596
608
  connection, error_type = grpc_connection, RpcError
597
609
  else:
flwr/client/client_app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Client-side part of the GrpcAdapter transport layer."""