flwr-nightly 1.10.0.dev20240612__py3-none-any.whl → 1.10.0.dev20240624__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (130) hide show
  1. flwr/cli/app.py +3 -0
  2. flwr/cli/build.py +6 -8
  3. flwr/cli/config_utils.py +53 -3
  4. flwr/cli/install.py +35 -20
  5. flwr/cli/new/new.py +104 -28
  6. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  7. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  8. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
  9. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
  10. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  11. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  12. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  14. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
  16. flwr/cli/run/run.py +46 -2
  17. flwr/client/__init__.py +1 -1
  18. flwr/client/app.py +22 -10
  19. flwr/client/client_app.py +1 -1
  20. flwr/client/dpfedavg_numpy_client.py +1 -1
  21. flwr/client/grpc_adapter_client/__init__.py +15 -0
  22. flwr/client/grpc_adapter_client/connection.py +94 -0
  23. flwr/client/grpc_client/connection.py +5 -1
  24. flwr/client/grpc_rere_client/__init__.py +1 -1
  25. flwr/client/grpc_rere_client/connection.py +9 -2
  26. flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
  27. flwr/client/message_handler/__init__.py +1 -1
  28. flwr/client/message_handler/message_handler.py +1 -1
  29. flwr/client/mod/__init__.py +4 -4
  30. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  31. flwr/client/mod/utils.py +1 -1
  32. flwr/client/rest_client/__init__.py +1 -1
  33. flwr/client/rest_client/connection.py +10 -2
  34. flwr/client/supernode/app.py +141 -41
  35. flwr/common/__init__.py +12 -12
  36. flwr/common/address.py +1 -1
  37. flwr/common/config.py +73 -0
  38. flwr/common/constant.py +16 -1
  39. flwr/common/date.py +1 -1
  40. flwr/common/dp.py +1 -1
  41. flwr/common/grpc.py +1 -1
  42. flwr/common/object_ref.py +39 -5
  43. flwr/common/record/__init__.py +1 -1
  44. flwr/common/secure_aggregation/__init__.py +1 -1
  45. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  46. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  47. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  48. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  49. flwr/common/secure_aggregation/quantization.py +1 -1
  50. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  51. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  52. flwr/common/telemetry.py +4 -0
  53. flwr/common/typing.py +9 -0
  54. flwr/common/version.py +14 -0
  55. flwr/proto/exec_pb2.py +34 -0
  56. flwr/proto/exec_pb2.pyi +55 -0
  57. flwr/proto/exec_pb2_grpc.py +101 -0
  58. flwr/proto/exec_pb2_grpc.pyi +41 -0
  59. flwr/proto/fab_pb2.py +30 -0
  60. flwr/proto/fab_pb2.pyi +56 -0
  61. flwr/proto/fab_pb2_grpc.py +4 -0
  62. flwr/proto/fab_pb2_grpc.pyi +4 -0
  63. flwr/server/__init__.py +2 -2
  64. flwr/server/app.py +62 -25
  65. flwr/server/compat/app.py +1 -1
  66. flwr/server/compat/app_utils.py +1 -1
  67. flwr/server/compat/driver_client_proxy.py +1 -1
  68. flwr/server/driver/driver.py +6 -0
  69. flwr/server/driver/grpc_driver.py +85 -63
  70. flwr/server/driver/inmemory_driver.py +28 -26
  71. flwr/server/run_serverapp.py +65 -20
  72. flwr/server/strategy/__init__.py +2 -2
  73. flwr/server/strategy/bulyan.py +1 -1
  74. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  75. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  76. flwr/server/strategy/fedadagrad.py +1 -1
  77. flwr/server/strategy/fedadam.py +1 -1
  78. flwr/server/strategy/fedavg_android.py +1 -1
  79. flwr/server/strategy/fedavgm.py +1 -1
  80. flwr/server/strategy/fedmedian.py +1 -1
  81. flwr/server/strategy/fedopt.py +1 -1
  82. flwr/server/strategy/fedprox.py +1 -1
  83. flwr/server/strategy/fedxgb_bagging.py +1 -1
  84. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  85. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  86. flwr/server/strategy/fedyogi.py +1 -1
  87. flwr/server/strategy/krum.py +1 -1
  88. flwr/server/strategy/qfedavg.py +1 -1
  89. flwr/server/superlink/driver/__init__.py +1 -1
  90. flwr/server/superlink/driver/driver_grpc.py +1 -1
  91. flwr/server/superlink/driver/driver_servicer.py +15 -3
  92. flwr/server/superlink/fleet/__init__.py +1 -1
  93. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  94. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  95. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  96. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  97. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  98. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  99. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -1
  100. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  101. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  102. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  103. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -4
  104. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  106. flwr/server/superlink/fleet/vce/backend/raybackend.py +44 -25
  107. flwr/server/superlink/fleet/vce/vce_api.py +3 -1
  108. flwr/server/superlink/state/__init__.py +1 -1
  109. flwr/server/superlink/state/in_memory_state.py +9 -6
  110. flwr/server/superlink/state/sqlite_state.py +7 -4
  111. flwr/server/superlink/state/state.py +6 -5
  112. flwr/server/superlink/state/state_factory.py +11 -2
  113. flwr/server/utils/__init__.py +1 -1
  114. flwr/server/utils/tensorboard.py +1 -1
  115. flwr/simulation/__init__.py +5 -2
  116. flwr/simulation/app.py +1 -1
  117. flwr/simulation/ray_transport/__init__.py +1 -1
  118. flwr/simulation/ray_transport/ray_actor.py +0 -6
  119. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  120. flwr/simulation/run_simulation.py +63 -22
  121. flwr/superexec/__init__.py +21 -0
  122. flwr/superexec/app.py +178 -0
  123. flwr/superexec/exec_grpc.py +51 -0
  124. flwr/superexec/exec_servicer.py +65 -0
  125. flwr/superexec/executor.py +54 -0
  126. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/METADATA +2 -1
  127. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/RECORD +130 -101
  128. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/entry_points.txt +1 -0
  129. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/LICENSE +0 -0
  130. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/WHEEL +0 -0
flwr/cli/app.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Flower command line interface."""
16
16
 
17
17
  import typer
18
+ from typer.main import get_command
18
19
 
19
20
  from .build import build
20
21
  from .example import example
@@ -37,5 +38,7 @@ app.command()(run)
37
38
  app.command()(build)
38
39
  app.command()(install)
39
40
 
41
+ typer_click_object = get_command(app)
42
+
40
43
  if __name__ == "__main__":
41
44
  app()
flwr/cli/build.py CHANGED
@@ -33,16 +33,12 @@ def build(
33
33
  Optional[Path],
34
34
  typer.Option(help="The Flower project directory to bundle into a FAB"),
35
35
  ] = None,
36
- ) -> None:
36
+ ) -> str:
37
37
  """Build a Flower project into a Flower App Bundle (FAB).
38
38
 
39
- You can run `flwr build` without any argument to bundle the current directory:
40
-
41
- `flwr build`
42
-
43
- You can also build a specific directory:
44
-
45
- `flwr build --directory ./projects/flower-hello-world`
39
+ You can run ``flwr build`` without any arguments to bundle the current directory,
40
+ or you can use ``--directory`` to build a specific directory:
41
+ ``flwr build --directory ./projects/flower-hello-world``.
46
42
  """
47
43
  if directory is None:
48
44
  directory = Path.cwd()
@@ -125,6 +121,8 @@ def build(
125
121
  f"🎊 Successfully built {fab_filename}.", fg=typer.colors.GREEN, bold=True
126
122
  )
127
123
 
124
+ return fab_filename
125
+
128
126
 
129
127
  def _load_gitignore(directory: Path) -> pathspec.PathSpec:
130
128
  """Load and parse .gitignore file, returning a pathspec."""
flwr/cli/config_utils.py CHANGED
@@ -14,14 +14,56 @@
14
14
  # ==============================================================================
15
15
  """Utility to validate the `pyproject.toml` file."""
16
16
 
17
+ import zipfile
18
+ from io import BytesIO
17
19
  from pathlib import Path
18
- from typing import Any, Dict, List, Optional, Tuple
20
+ from typing import IO, Any, Dict, List, Optional, Tuple, Union
19
21
 
20
22
  import tomli
21
23
 
22
24
  from flwr.common import object_ref
23
25
 
24
26
 
27
def get_fab_metadata(fab_file: Union[Path, str, bytes]) -> Tuple[str, str]:
    """Extract the version and the id of a Flower App Bundle (FAB).

    The FAB is opened as a ZIP archive, its top-level ``pyproject.toml`` is
    parsed, validated with ``validate(conf, check_module=False)``, and the
    metadata is read from the parsed configuration.

    Parameters
    ----------
    fab_file : Union[Path, str, bytes]
        The Flower App Bundle to validate and extract the metadata from.
        It can be a path to the file (a ``Path`` or a ``str``) or the content
        of the file itself as ``bytes``.

    Returns
    -------
    Tuple[str, str]
        A ``(fab_version, fab_id)`` tuple -- note the order: the version comes
        first. ``fab_version`` is ``project.version`` and ``fab_id`` is
        ``"<flower.publisher>/<project.name>"``.

    Raises
    ------
    ValueError
        If ``fab_file`` has an unsupported type, if ``pyproject.toml`` is not
        valid TOML, or if the parsed configuration fails validation.
    KeyError
        If the archive has no top-level ``pyproject.toml`` (raised by
        ``zipfile.ZipFile.open``).
    zipfile.BadZipFile
        If ``fab_file`` is not a ZIP archive.
    """
    fab_file_archive: Union[Path, IO[bytes]]
    if isinstance(fab_file, bytes):
        # ZipFile needs a file-like object, so wrap in-memory bytes.
        fab_file_archive = BytesIO(fab_file)
    elif isinstance(fab_file, (str, Path)):
        # Accept plain string paths as well as ``Path`` objects.
        fab_file_archive = Path(fab_file)
    else:
        raise ValueError("fab_file must be either a Path, a str or bytes")

    with zipfile.ZipFile(fab_file_archive, "r") as zipf:
        with zipf.open("pyproject.toml") as file:
            toml_content = file.read().decode("utf-8")

    conf = load_from_string(toml_content)
    if conf is None:
        raise ValueError("Invalid TOML content in pyproject.toml")

    # Module checks are skipped: only the bundle's metadata is needed here.
    is_valid, errors, _ = validate(conf, check_module=False)
    if not is_valid:
        raise ValueError(errors)

    return (
        conf["project"]["version"],
        f"{conf['flower']['publisher']}/{conf['project']['name']}",
    )
65
+
66
+
25
67
  def load_and_validate(
26
68
  path: Optional[Path] = None,
27
69
  check_module: bool = True,
@@ -63,8 +105,7 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
63
105
  return None
64
106
 
65
107
  with toml_path.open(encoding="utf-8") as toml_file:
66
- data = tomli.loads(toml_file.read())
67
- return data
108
+ return load_from_string(toml_file.read())
68
109
 
69
110
 
70
111
  # pylint: disable=too-many-branches
@@ -128,3 +169,12 @@ def validate(
128
169
  return False, [reason], []
129
170
 
130
171
  return True, [], []
172
+
173
+
174
def load_from_string(toml_content: str) -> Optional[Dict[str, Any]]:
    """Parse a TOML document held in a string.

    Returns the parsed document as a dict, or ``None`` when ``tomli`` rejects
    the content as invalid TOML.
    """
    try:
        parsed = tomli.loads(toml_content)
    except tomli.TOMLDecodeError:
        return None
    return parsed
flwr/cli/install.py CHANGED
@@ -15,16 +15,18 @@
15
15
  """Flower command line interface `install` command."""
16
16
 
17
17
 
18
- import os
19
18
  import shutil
20
19
  import tempfile
21
20
  import zipfile
21
+ from io import BytesIO
22
22
  from pathlib import Path
23
- from typing import Optional
23
+ from typing import IO, Optional, Union
24
24
 
25
25
  import typer
26
26
  from typing_extensions import Annotated
27
27
 
28
+ from flwr.common.config import get_flwr_dir
29
+
28
30
  from .config_utils import load_and_validate
29
31
  from .utils import get_sha256_hash
30
32
 
@@ -80,11 +82,24 @@ def install(
80
82
 
81
83
 
82
84
  def install_from_fab(
83
- fab_file: Path, flwr_dir: Optional[Path], skip_prompt: bool = False
84
- ) -> None:
85
+ fab_file: Union[Path, bytes],
86
+ flwr_dir: Optional[Path],
87
+ skip_prompt: bool = False,
88
+ ) -> Path:
85
89
  """Install from a FAB file after extracting and validating."""
90
+ fab_file_archive: Union[Path, IO[bytes]]
91
+ fab_name: Optional[str]
92
+ if isinstance(fab_file, bytes):
93
+ fab_file_archive = BytesIO(fab_file)
94
+ fab_name = None
95
+ elif isinstance(fab_file, Path):
96
+ fab_file_archive = fab_file
97
+ fab_name = fab_file.stem
98
+ else:
99
+ raise ValueError("fab_file must be either a Path or bytes")
100
+
86
101
  with tempfile.TemporaryDirectory() as tmpdir:
87
- with zipfile.ZipFile(fab_file, "r") as zipf:
102
+ with zipfile.ZipFile(fab_file_archive, "r") as zipf:
88
103
  zipf.extractall(tmpdir)
89
104
  tmpdir_path = Path(tmpdir)
90
105
  info_dir = tmpdir_path / ".info"
@@ -110,15 +125,19 @@ def install_from_fab(
110
125
 
111
126
  shutil.rmtree(info_dir)
112
127
 
113
- validate_and_install(tmpdir_path, fab_file.stem, flwr_dir, skip_prompt)
128
+ installed_path = validate_and_install(
129
+ tmpdir_path, fab_name, flwr_dir, skip_prompt
130
+ )
131
+
132
+ return installed_path
114
133
 
115
134
 
116
135
  def validate_and_install(
117
136
  project_dir: Path,
118
- fab_name: str,
137
+ fab_name: Optional[str],
119
138
  flwr_dir: Optional[Path],
120
139
  skip_prompt: bool = False,
121
- ) -> None:
140
+ ) -> Path:
122
141
  """Validate TOML files and install the project to the desired directory."""
123
142
  config, _, _ = load_and_validate(project_dir / "pyproject.toml", check_module=False)
124
143
 
@@ -134,7 +153,10 @@ def validate_and_install(
134
153
  project_name = config["project"]["name"]
135
154
  version = config["project"]["version"]
136
155
 
137
- if fab_name != f"{publisher}.{project_name}.{version.replace('.', '-')}":
156
+ if (
157
+ fab_name
158
+ and fab_name != f"{publisher}.{project_name}.{version.replace('.', '-')}"
159
+ ):
138
160
  typer.secho(
139
161
  "❌ FAB file has incorrect name. The file name must follow the format "
140
162
  "`<publisher>.<project_name>.<version>.fab`.",
@@ -144,16 +166,7 @@ def validate_and_install(
144
166
  raise typer.Exit(code=1)
145
167
 
146
168
  install_dir: Path = (
147
- (
148
- Path(
149
- os.getenv(
150
- "FLWR_HOME",
151
- f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
152
- )
153
- )
154
- if not flwr_dir
155
- else flwr_dir
156
- )
169
+ (get_flwr_dir() if not flwr_dir else flwr_dir)
157
170
  / "apps"
158
171
  / publisher
159
172
  / project_name
@@ -168,7 +181,7 @@ def validate_and_install(
168
181
  bold=True,
169
182
  )
170
183
  ):
171
- return
184
+ return install_dir
172
185
 
173
186
  install_dir.mkdir(parents=True, exist_ok=True)
174
187
 
@@ -185,6 +198,8 @@ def validate_and_install(
185
198
  bold=True,
186
199
  )
187
200
 
201
+ return install_dir
202
+
188
203
 
189
204
  def _verify_hashes(list_content: str, tmpdir: Path) -> bool:
190
205
  """Verify file hashes based on the LIST content."""
flwr/cli/new/new.py CHANGED
@@ -41,6 +41,16 @@ class MlFramework(str, Enum):
41
41
  HUGGINGFACE = "HF"
42
42
  MLX = "MLX"
43
43
  SKLEARN = "sklearn"
44
+ FLOWERTUNE = "FlowerTune"
45
+
46
+
47
class LlmChallengeName(str, Enum):
    """Available LLM challenges.

    Used by ``flwr new`` when the FlowerTune framework is selected. The sorted
    member values are offered to the user as options. The lowercased member
    name of the chosen value then selects the challenge-specific template
    context (challenge name, number of clients, dataset name).
    """

    # Mixing in ``str`` makes each member compare equal to its plain string
    # value, which ``flwr new`` relies on when mapping the chosen value back
    # to a member name.
    GENERALNLP = "GeneralNLP"
    FINANCE = "Finance"
    MEDICAL = "Medical"
    CODE = "Code"
44
54
 
45
55
 
46
56
  class TemplateNotFound(Exception):
@@ -81,6 +91,7 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
81
91
  create_file(file_path, content)
82
92
 
83
93
 
94
+ # pylint: disable=too-many-locals,too-many-branches,too-many-statements
84
95
  def new(
85
96
  project_name: Annotated[
86
97
  Optional[str],
@@ -125,6 +136,19 @@ def new(
125
136
 
126
137
  framework_str = framework_str.lower()
127
138
 
139
+ if framework_str == "flowertune":
140
+ llm_challenge_value = prompt_options(
141
+ "Please select LLM challenge by typing in the number",
142
+ sorted([challenge.value for challenge in LlmChallengeName]),
143
+ )
144
+ selected_value = [
145
+ name
146
+ for name, value in vars(LlmChallengeName).items()
147
+ if value == llm_challenge_value
148
+ ]
149
+ llm_challenge_str = selected_value[0]
150
+ llm_challenge_str = llm_challenge_str.lower()
151
+
128
152
  print(
129
153
  typer.style(
130
154
  f"\n🔨 Creating Flower project {project_name}...",
@@ -139,33 +163,6 @@ def new(
139
163
  import_name = package_name.replace("-", "_")
140
164
  project_dir = os.path.join(cwd, package_name)
141
165
 
142
- # List of files to render
143
- files = {
144
- ".gitignore": {"template": "app/.gitignore.tpl"},
145
- "README.md": {"template": "app/README.md.tpl"},
146
- "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
147
- f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
148
- f"{import_name}/server.py": {
149
- "template": f"app/code/server.{framework_str}.py.tpl"
150
- },
151
- f"{import_name}/client.py": {
152
- "template": f"app/code/client.{framework_str}.py.tpl"
153
- },
154
- }
155
-
156
- # Depending on the framework, generate task.py file
157
- frameworks_with_tasks = [
158
- MlFramework.PYTORCH.value.lower(),
159
- MlFramework.JAX.value.lower(),
160
- MlFramework.HUGGINGFACE.value.lower(),
161
- MlFramework.MLX.value.lower(),
162
- MlFramework.TENSORFLOW.value.lower(),
163
- ]
164
- if framework_str in frameworks_with_tasks:
165
- files[f"{import_name}/task.py"] = {
166
- "template": f"app/code/task.{framework_str}.py.tpl"
167
- }
168
-
169
166
  context = {
170
167
  "project_name": project_name,
171
168
  "package_name": package_name,
@@ -173,6 +170,85 @@ def new(
173
170
  "username": username,
174
171
  }
175
172
 
173
+ # List of files to render
174
+ if framework_str == "flowertune":
175
+ files = {
176
+ ".gitignore": {"template": "app/.gitignore.tpl"},
177
+ "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
178
+ "README.md": {"template": f"app/README.{framework_str}.md.tpl"},
179
+ f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
180
+ f"{import_name}/server.py": {
181
+ "template": "app/code/flwr_tune/server.py.tpl"
182
+ },
183
+ f"{import_name}/client.py": {
184
+ "template": "app/code/flwr_tune/client.py.tpl"
185
+ },
186
+ f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"},
187
+ f"{import_name}/models.py": {
188
+ "template": "app/code/flwr_tune/models.py.tpl"
189
+ },
190
+ f"{import_name}/dataset.py": {
191
+ "template": "app/code/flwr_tune/dataset.py.tpl"
192
+ },
193
+ f"{import_name}/conf/config.yaml": {
194
+ "template": "app/code/flwr_tune/config.yaml.tpl"
195
+ },
196
+ f"{import_name}/conf/static_config.yaml": {
197
+ "template": "app/code/flwr_tune/static_config.yaml.tpl"
198
+ },
199
+ }
200
+
201
+ # Challenge specific context
202
+ fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
203
+ if llm_challenge_str == "generalnlp":
204
+ challenge_name = "General NLP"
205
+ num_clients = "20"
206
+ dataset_name = "vicgalle/alpaca-gpt4"
207
+ elif llm_challenge_str == "finance":
208
+ challenge_name = "Finance"
209
+ num_clients = "50"
210
+ dataset_name = "FinGPT/fingpt-sentiment-train"
211
+ elif llm_challenge_str == "medical":
212
+ challenge_name = "Medical"
213
+ num_clients = "20"
214
+ dataset_name = "medalpaca/medical_meadow_medical_flashcards"
215
+ else:
216
+ challenge_name = "Code"
217
+ num_clients = "10"
218
+ dataset_name = "lucasmccabe-lmi/CodeAlpaca-20k"
219
+
220
+ context["llm_challenge_str"] = llm_challenge_str
221
+ context["fraction_fit"] = fraction_fit
222
+ context["challenge_name"] = challenge_name
223
+ context["num_clients"] = num_clients
224
+ context["dataset_name"] = dataset_name
225
+ else:
226
+ files = {
227
+ ".gitignore": {"template": "app/.gitignore.tpl"},
228
+ "README.md": {"template": "app/README.md.tpl"},
229
+ "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
230
+ f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
231
+ f"{import_name}/server.py": {
232
+ "template": f"app/code/server.{framework_str}.py.tpl"
233
+ },
234
+ f"{import_name}/client.py": {
235
+ "template": f"app/code/client.{framework_str}.py.tpl"
236
+ },
237
+ }
238
+
239
+ # Depending on the framework, generate task.py file
240
+ frameworks_with_tasks = [
241
+ MlFramework.PYTORCH.value.lower(),
242
+ MlFramework.JAX.value.lower(),
243
+ MlFramework.HUGGINGFACE.value.lower(),
244
+ MlFramework.MLX.value.lower(),
245
+ MlFramework.TENSORFLOW.value.lower(),
246
+ ]
247
+ if framework_str in frameworks_with_tasks:
248
+ files[f"{import_name}/task.py"] = {
249
+ "template": f"app/code/task.{framework_str}.py.tpl"
250
+ }
251
+
176
252
  for file_path, value in files.items():
177
253
  render_and_create(
178
254
  file_path=os.path.join(project_dir, file_path),
@@ -190,7 +266,7 @@ def new(
190
266
  )
191
267
  print(
192
268
  typer.style(
193
- f" cd {project_name}\n" + " pip install -e .\n flwr run\n",
269
+ f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
194
270
  fg=typer.colors.BRIGHT_CYAN,
195
271
  bold=True,
196
272
  )
@@ -0,0 +1,56 @@
1
+ # FlowerTune LLM on $challenge_name Dataset
2
+
3
+ This directory conducts federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on a [$challenge_name dataset](https://huggingface.co/datasets/$dataset_name).
4
+ We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the dataset.
5
+ Flower's Simulation Engine is used to simulate the LLM fine-tuning process in a federated way,
6
+ which allows users to perform the training on a single GPU.
7
+
8
+
9
+ ## Methodology
10
+
11
+ This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
12
+ The clients' models are aggregated with the FedAvg strategy.
13
+ This provides a baseline performance for the leaderboard of the $challenge_name challenge.
14
+
15
+
16
+ ## Environment setup
17
+
18
+ Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:
19
+
20
+ ```shell
21
+ pip install -e .
22
+ ```
23
+
24
+ ## Experimental setup
25
+
26
+ The dataset is partitioned IID into $num_clients shards, each serving as one client.
27
+ We randomly sample a fraction ($fraction_fit) of the clients to be available for each round,
28
+ and the federated fine-tuning lasts for `200` rounds.
29
+ All settings are defined in `$import_name/conf/static_config.yaml`, which must not be modified if you plan to participate in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard), so that the competition stays fair.
30
+
31
+
32
+ ## Running the challenge
33
+
34
+ First, make sure that you have access to the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model with your Hugging Face account. You can request access directly from the Hugging Face website.
35
+ Then, follow the instructions [here](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) to log in to your account. Note that you only need to complete this step once on your development machine:
36
+
37
+ ```bash
38
+ huggingface-cli login
39
+ ```
40
+
41
+ Run the challenge with default config values.
42
+ The configs are in `$import_name/conf/config.yaml` and `$import_name/conf/static_config.yaml`, and are loaded automatically.
43
+
44
+ ```bash
45
+ flwr run
46
+ ```
47
+
48
+ ## VRAM consumption
49
+
50
+ We use the Mistral-7B model with 4-bit quantization by default. The estimated VRAM consumption per client for each challenge is shown below:
51
+
52
+ | Challenges | GeneralNLP | Finance | Medical | Code |
53
+ | :--------: | :--------: | :--------: | :--------: | :--------: |
54
+ | VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |
55
+
56
+ You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which is specified with `flower.engine.simulation` in `pyproject.toml`.
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower CLI `new` command app / code / flwr_tune templates."""
@@ -0,0 +1,86 @@
"""$project_name: A Flower / FlowerTune app."""

# NOTE(review): this file is a template rendered by `flwr new`; its
# placeholders (project and import names) are filled from the render context.
# Literal dollar signs added here would presumably be treated as placeholders
# too -- confirm before adding any.
#
# Everything below runs at module level, i.e. when this module is imported:
# the configs are loaded, an output directory is created, the dataset and the
# initial model are set up, and the `client` / `server` apps are defined.

import os
import warnings
from datetime import datetime

from flwr_datasets import FederatedDataset
from hydra import compose, initialize
from hydra.utils import instantiate

from flwr.client import ClientApp
from flwr.common import ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig

from $import_name.client import gen_client_fn, get_parameters
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
from $import_name.models import get_model
from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config

# Avoid warnings
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"

# Initialise regular config (conf/config.yaml, user-tunable settings)
with initialize(config_path="conf", version_base="1.1"):
    cfg = compose(config_name="config")

# Initialise static config (conf/static_config.yaml, fixed challenge settings)
with initialize(config_path="conf", version_base="1.1"):
    cfg_static = compose(config_name="static_config")

# The static config wins for the number of rounds: copy it into cfg.train
# (handed to the clients below) so it matches the ServerConfig further down.
cfg.train.num_rounds = cfg_static.num_rounds

# Create output directory given current timestamp
# (one folder per run under ./results, relative to the current working dir;
# exist_ok=True means two runs started in the same second share a folder).
current_time = datetime.now()
folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
os.makedirs(save_path, exist_ok=True)

# Partition dataset and get dataloaders
# The partitioner is built by Hydra from the static config; only the "train"
# split is partitioned.
partitioner = instantiate(cfg_static.partitioner)
fds = FederatedDataset(
    dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
)
# The helper name (including its spelling) must match the function defined
# in the generated dataset module.
(
    tokenizer,
    data_collator,
    formatting_prompts_func,
) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)

# ClientApp for Flower Next
# gen_client_fn presumably returns a client_fn closing over these objects --
# confirm against the generated client module.
client = ClientApp(
    client_fn=gen_client_fn(
        fds,
        tokenizer,
        formatting_prompts_func,
        data_collator,
        cfg.model,
        cfg.train,
        save_path,
    ),
)

# Get initial model weights
# and convert them to Flower Parameters so the strategy starts from them.
init_model = get_model(cfg.model)
init_model_parameters = get_parameters(init_model)
init_model_parameters = ndarrays_to_parameters(init_model_parameters)

# Instantiate strategy according to config. Here we pass other arguments
# that are only defined at runtime.
strategy = instantiate(
    cfg.strategy,
    on_fit_config_fn=get_on_fit_config(),
    fit_metrics_aggregation_fn=fit_weighted_average,
    initial_parameters=init_model_parameters,
    evaluate_fn=get_evaluate_fn(
        cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
    ),
)

# ServerApp for Flower Next (round count taken from the static config)
server = ServerApp(
    config=ServerConfig(num_rounds=cfg_static.num_rounds),
    strategy=strategy,
)