flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237) hide show
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/cli/new/new.py CHANGED
@@ -14,23 +14,43 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `new` command."""
16
16
 
17
- import os
17
+ import re
18
18
  from enum import Enum
19
+ from pathlib import Path
19
20
  from string import Template
20
21
  from typing import Dict, Optional
21
22
 
22
23
  import typer
23
24
  from typing_extensions import Annotated
24
25
 
25
- from ..utils import prompt_options, prompt_text
26
+ from ..utils import (
27
+ is_valid_project_name,
28
+ prompt_options,
29
+ prompt_text,
30
+ sanitize_project_name,
31
+ )
26
32
 
27
33
 
28
34
  class MlFramework(str, Enum):
29
35
  """Available frameworks."""
30
36
 
31
- NUMPY = "NumPy"
32
37
  PYTORCH = "PyTorch"
33
38
  TENSORFLOW = "TensorFlow"
39
+ SKLEARN = "sklearn"
40
+ HUGGINGFACE = "HuggingFace"
41
+ JAX = "JAX"
42
+ MLX = "MLX"
43
+ NUMPY = "NumPy"
44
+ FLOWERTUNE = "FlowerTune"
45
+
46
+
47
+ class LlmChallengeName(str, Enum):
48
+ """Available LLM challenges."""
49
+
50
+ GENERALNLP = "GeneralNLP"
51
+ FINANCE = "Finance"
52
+ MEDICAL = "Medical"
53
+ CODE = "Code"
34
54
 
35
55
 
36
56
  class TemplateNotFound(Exception):
@@ -39,10 +59,10 @@ class TemplateNotFound(Exception):
39
59
 
40
60
  def load_template(name: str) -> str:
41
61
  """Load template from template directory and return as text."""
42
- tpl_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "templates"))
43
- tpl_file_path = os.path.join(tpl_dir, name)
62
+ tpl_dir = (Path(__file__).parent / "templates").absolute()
63
+ tpl_file_path = tpl_dir / name
44
64
 
45
- if not os.path.isfile(tpl_file_path):
65
+ if not tpl_file_path.is_file():
46
66
  raise TemplateNotFound(f"Template '{name}' not found")
47
67
 
48
68
  with open(tpl_file_path, encoding="utf-8") as tpl_file:
@@ -53,47 +73,69 @@ def render_template(template: str, data: Dict[str, str]) -> str:
53
73
  """Render template."""
54
74
  tpl_file = load_template(template)
55
75
  tpl = Template(tpl_file)
56
- result = tpl.substitute(data)
57
- return result
76
+ if ".gitignore" not in template:
77
+ return tpl.substitute(data)
78
+ return tpl.template
58
79
 
59
80
 
60
- def create_file(file_path: str, content: str) -> None:
81
+ def create_file(file_path: Path, content: str) -> None:
61
82
  """Create file including all nessecary directories and write content into file."""
62
- os.makedirs(os.path.dirname(file_path), exist_ok=True)
63
- with open(file_path, "w", encoding="utf-8") as f:
64
- f.write(content)
83
+ file_path.parent.mkdir(exist_ok=True)
84
+ file_path.write_text(content)
65
85
 
66
86
 
67
- def render_and_create(file_path: str, template: str, context: Dict[str, str]) -> None:
87
+ def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None:
68
88
  """Render template and write to file."""
69
89
  content = render_template(template, context)
70
90
  create_file(file_path, content)
71
91
 
72
92
 
93
+ # pylint: disable=too-many-locals,too-many-branches,too-many-statements
73
94
  def new(
74
- project_name: Annotated[
95
+ app_name: Annotated[
75
96
  Optional[str],
76
- typer.Argument(metavar="project_name", help="The name of the project"),
97
+ typer.Argument(help="The name of the Flower App"),
77
98
  ] = None,
78
99
  framework: Annotated[
79
100
  Optional[MlFramework],
80
101
  typer.Option(case_sensitive=False, help="The ML framework to use"),
81
102
  ] = None,
103
+ username: Annotated[
104
+ Optional[str],
105
+ typer.Option(case_sensitive=False, help="The Flower username of the author"),
106
+ ] = None,
82
107
  ) -> None:
83
- """Create new Flower project."""
84
- print(
85
- typer.style(
86
- f"🔨 Creating Flower project {project_name}...",
87
- fg=typer.colors.GREEN,
88
- bold=True,
108
+ """Create new Flower App."""
109
+ if app_name is None:
110
+ app_name = prompt_text("Please provide the app name")
111
+ if not is_valid_project_name(app_name):
112
+ app_name = prompt_text(
113
+ "Please provide a name that only contains "
114
+ "characters in {'-', a-zA-Z', '0-9'}",
115
+ predicate=is_valid_project_name,
116
+ default=sanitize_project_name(app_name),
89
117
  )
90
- )
91
118
 
92
- if project_name is None:
93
- project_name = prompt_text("Please provide project name")
119
+ # Set project directory path
120
+ package_name = re.sub(r"[-_.]+", "-", app_name).lower()
121
+ import_name = package_name.replace("-", "_")
122
+ project_dir = Path.cwd() / package_name
123
+
124
+ if project_dir.exists():
125
+ if not typer.confirm(
126
+ typer.style(
127
+ f"\n💬 {app_name} already exists, do you want to override it?",
128
+ fg=typer.colors.MAGENTA,
129
+ bold=True,
130
+ )
131
+ ):
132
+ return
133
+
134
+ if username is None:
135
+ username = prompt_text("Please provide your Flower username")
94
136
 
95
137
  if framework is not None:
96
- framework_str = str(framework.value)
138
+ framework_str_upper = str(framework.value)
97
139
  else:
98
140
  framework_value = prompt_options(
99
141
  "Please select ML framework by typing in the number",
@@ -104,50 +146,139 @@ def new(
104
146
  for name, value in vars(MlFramework).items()
105
147
  if value == framework_value
106
148
  ]
107
- framework_str = selected_value[0]
149
+ framework_str_upper = selected_value[0]
108
150
 
109
- framework_str = framework_str.lower()
151
+ framework_str = framework_str_upper.lower()
110
152
 
111
- # Set project directory path
112
- cwd = os.getcwd()
113
- pnl = project_name.lower()
114
- project_dir = os.path.join(cwd, pnl)
153
+ llm_challenge_str = None
154
+ if framework_str == "flowertune":
155
+ llm_challenge_value = prompt_options(
156
+ "Please select LLM challenge by typing in the number",
157
+ sorted([challenge.value for challenge in LlmChallengeName]),
158
+ )
159
+ selected_value = [
160
+ name
161
+ for name, value in vars(LlmChallengeName).items()
162
+ if value == llm_challenge_value
163
+ ]
164
+ llm_challenge_str = selected_value[0]
165
+ llm_challenge_str = llm_challenge_str.lower()
115
166
 
116
- # List of files to render
117
- files = {
118
- "README.md": {"template": "app/README.md.tpl"},
119
- "requirements.txt": {"template": f"app/requirements.{framework_str}.txt.tpl"},
120
- "flower.toml": {"template": "app/flower.toml.tpl"},
121
- "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
122
- f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"},
123
- f"{pnl}/server.py": {"template": f"app/code/server.{framework_str}.py.tpl"},
124
- f"{pnl}/client.py": {"template": f"app/code/client.{framework_str}.py.tpl"},
167
+ print(
168
+ typer.style(
169
+ f"\n🔨 Creating Flower App {app_name}...",
170
+ fg=typer.colors.GREEN,
171
+ bold=True,
172
+ )
173
+ )
174
+
175
+ context = {
176
+ "framework_str": framework_str_upper,
177
+ "import_name": import_name.replace("-", "_"),
178
+ "package_name": package_name,
179
+ "project_name": app_name,
180
+ "username": username,
125
181
  }
126
182
 
127
- # In case framework is MlFramework.PYTORCH generate additionally the task.py file
128
- if framework_str == MlFramework.PYTORCH.value.lower():
129
- files[f"{pnl}/task.py"] = {"template": f"app/code/task.{framework_str}.py.tpl"}
183
+ # List of files to render
184
+ if llm_challenge_str:
185
+ files = {
186
+ ".gitignore": {"template": "app/.gitignore.tpl"},
187
+ "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
188
+ "README.md": {"template": f"app/README.{framework_str}.md.tpl"},
189
+ f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
190
+ f"{import_name}/server.py": {
191
+ "template": "app/code/flwr_tune/server.py.tpl"
192
+ },
193
+ f"{import_name}/client.py": {
194
+ "template": "app/code/flwr_tune/client.py.tpl"
195
+ },
196
+ f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"},
197
+ f"{import_name}/models.py": {
198
+ "template": "app/code/flwr_tune/models.py.tpl"
199
+ },
200
+ f"{import_name}/dataset.py": {
201
+ "template": "app/code/flwr_tune/dataset.py.tpl"
202
+ },
203
+ f"{import_name}/conf/config.yaml": {
204
+ "template": "app/code/flwr_tune/config.yaml.tpl"
205
+ },
206
+ f"{import_name}/conf/static_config.yaml": {
207
+ "template": "app/code/flwr_tune/static_config.yaml.tpl"
208
+ },
209
+ }
210
+
211
+ # Challenge specific context
212
+ fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
213
+ if llm_challenge_str == "generalnlp":
214
+ challenge_name = "General NLP"
215
+ num_clients = "20"
216
+ dataset_name = "vicgalle/alpaca-gpt4"
217
+ elif llm_challenge_str == "finance":
218
+ challenge_name = "Finance"
219
+ num_clients = "50"
220
+ dataset_name = "FinGPT/fingpt-sentiment-train"
221
+ elif llm_challenge_str == "medical":
222
+ challenge_name = "Medical"
223
+ num_clients = "20"
224
+ dataset_name = "medalpaca/medical_meadow_medical_flashcards"
225
+ else:
226
+ challenge_name = "Code"
227
+ num_clients = "10"
228
+ dataset_name = "lucasmccabe-lmi/CodeAlpaca-20k"
130
229
 
131
- context = {"project_name": project_name}
230
+ context["llm_challenge_str"] = llm_challenge_str
231
+ context["fraction_fit"] = fraction_fit
232
+ context["challenge_name"] = challenge_name
233
+ context["num_clients"] = num_clients
234
+ context["dataset_name"] = dataset_name
235
+ else:
236
+ files = {
237
+ ".gitignore": {"template": "app/.gitignore.tpl"},
238
+ "README.md": {"template": "app/README.md.tpl"},
239
+ "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
240
+ f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
241
+ f"{import_name}/server_app.py": {
242
+ "template": f"app/code/server.{framework_str}.py.tpl"
243
+ },
244
+ f"{import_name}/client_app.py": {
245
+ "template": f"app/code/client.{framework_str}.py.tpl"
246
+ },
247
+ }
248
+
249
+ # Depending on the framework, generate task.py file
250
+ frameworks_with_tasks = [
251
+ MlFramework.PYTORCH.value.lower(),
252
+ MlFramework.JAX.value.lower(),
253
+ MlFramework.HUGGINGFACE.value.lower(),
254
+ MlFramework.MLX.value.lower(),
255
+ MlFramework.TENSORFLOW.value.lower(),
256
+ ]
257
+ if framework_str in frameworks_with_tasks:
258
+ files[f"{import_name}/task.py"] = {
259
+ "template": f"app/code/task.{framework_str}.py.tpl"
260
+ }
132
261
 
133
262
  for file_path, value in files.items():
134
263
  render_and_create(
135
- file_path=os.path.join(project_dir, file_path),
264
+ file_path=project_dir / file_path,
136
265
  template=value["template"],
137
266
  context=context,
138
267
  )
139
268
 
140
269
  print(
141
270
  typer.style(
142
- "🎊 Project creation successful.\n\n"
143
- "Use the following command to run your project:\n",
271
+ "🎊 Flower App creation successful.\n\n"
272
+ "Use the following command to run your Flower App:\n",
144
273
  fg=typer.colors.GREEN,
145
274
  bold=True,
146
275
  )
147
276
  )
277
+
278
+ _add = " huggingface-cli login\n" if framework_str == "flowertune" else ""
148
279
  print(
149
280
  typer.style(
150
- f" cd {project_name}\n" + " pip install -e .\n flwr run\n",
281
+ f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
151
282
  fg=typer.colors.BRIGHT_CYAN,
152
283
  bold=True,
153
284
  )
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
@@ -0,0 +1,56 @@
1
+ # FlowerTune LLM on $challenge_name Dataset
2
+
3
+ This directory conducts federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on a [$challenge_name dataset](https://huggingface.co/datasets/$dataset_name).
4
+ We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the dataset.
5
+ Flower's Simulation Engine is used to simulate the LLM fine-tuning process in federated way,
6
+ which allows users to perform the training on a single GPU.
7
+
8
+
9
+ ## Methodology
10
+
11
+ This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
12
+ The clients' models are aggregated with FedAvg strategy.
13
+ This provides a baseline performance for the leaderboard of $challenge_name challenge.
14
+
15
+
16
+ ## Environments setup
17
+
18
+ Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:
19
+
20
+ ```shell
21
+ pip install -e .
22
+ ```
23
+
24
+ ## Experimental setup
25
+
26
+ The dataset is partitioned into $num_clients shards with IID fashion serving as clients.
27
+ We randomly sample $fraction_fit clients to be available for each round,
28
+ and the federated fine-tuning lasts for `200` rounds.
29
+ All settings are defined in `$project_name/conf/static_config.yaml`, which is not allowed to be modified for fair competition if you plan to participated in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard).
30
+
31
+
32
+ ## Running the challenge
33
+
34
+ First make sure that you have got the access to [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model with your Hugging-Face account. You can request access directly from the Hugging-Face website.
35
+ Then, follow the instruction [here](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) to log in your account. Note you only need to complete this stage once in your development machine:
36
+
37
+ ```bash
38
+ huggingface-cli login
39
+ ```
40
+
41
+ Run the challenge with default config values.
42
+ The configs are in `$project_name/conf/config.yaml` and `$project_name/conf/static_config.yaml`, and are loaded automatically.
43
+
44
+ ```bash
45
+ flwr run
46
+ ```
47
+
48
+ ## VRAM consumption
49
+
50
+ We use Mistral-7B model with 4-bit quantization as default. The estimated VRAM consumption per client for each challenge is shown below:
51
+
52
+ | Challenges | GeneralNLP | Finance | Medical | Code |
53
+ | :--------: | :--------: | :--------: | :--------: | :--------: |
54
+ | VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |
55
+
56
+ You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which is specified with `flower.engine.simulation` in `pyproject.toml`.
@@ -1,13 +1,9 @@
1
- # $project_name
1
+ # $project_name: A Flower / $framework_str app
2
2
 
3
3
  ## Install dependencies
4
4
 
5
5
  ```bash
6
- # Using pip
7
6
  pip install .
8
-
9
- # Or using Poetry
10
- poetry install
11
7
  ```
12
8
 
13
9
  ## Run (Simulation Engine)
@@ -1 +1 @@
1
- """$project_name."""
1
+ """$project_name: A Flower / $framework_str app."""
@@ -0,0 +1,65 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.client import ClientApp, NumPyClient
4
+ from flwr.common import Context
5
+ from transformers import AutoModelForSequenceClassification
6
+
7
+ from $import_name.task import (
8
+ get_weights,
9
+ load_data,
10
+ set_weights,
11
+ train,
12
+ test,
13
+ CHECKPOINT,
14
+ DEVICE,
15
+ )
16
+
17
+
18
+ # Flower client
19
+ class FlowerClient(NumPyClient):
20
+ def __init__(self, net, trainloader, testloader, local_epochs):
21
+ self.net = net
22
+ self.trainloader = trainloader
23
+ self.testloader = testloader
24
+ self.local_epochs = local_epochs
25
+
26
+ def get_parameters(self, config):
27
+ return get_weights(self.net)
28
+
29
+ def set_parameters(self, parameters):
30
+ set_weights(self.net, parameters)
31
+
32
+ def fit(self, parameters, config):
33
+ self.set_parameters(parameters)
34
+ train(
35
+ self.net,
36
+ self.trainloader,
37
+ epochs=self.local_epochs,
38
+ )
39
+ return self.get_parameters(config={}), len(self.trainloader), {}
40
+
41
+ def evaluate(self, parameters, config):
42
+ self.set_parameters(parameters)
43
+ loss, accuracy = test(self.net, self.testloader)
44
+ return float(loss), len(self.testloader), {"accuracy": accuracy}
45
+
46
+
47
+ def client_fn(context: Context):
48
+ # Load model and data
49
+ net = AutoModelForSequenceClassification.from_pretrained(
50
+ CHECKPOINT, num_labels=2
51
+ ).to(DEVICE)
52
+
53
+ partition_id = context.node_config["partition-id"]
54
+ num_partitions = context.node_config["num-partitions"]
55
+ trainloader, valloader = load_data(partition_id, num_partitions)
56
+ local_epochs = context.run_config["local-epochs"]
57
+
58
+ # Return Client instance
59
+ return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
60
+
61
+
62
+ # Flower ClientApp
63
+ app = ClientApp(
64
+ client_fn,
65
+ )
@@ -0,0 +1,56 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import jax
4
+ from flwr.client import NumPyClient, ClientApp
5
+ from flwr.common import Context
6
+
7
+ from $import_name.task import (
8
+ evaluation,
9
+ get_params,
10
+ load_data,
11
+ load_model,
12
+ loss_fn,
13
+ set_params,
14
+ train,
15
+ )
16
+
17
+
18
+ # Define Flower Client and client_fn
19
+ class FlowerClient(NumPyClient):
20
+ def __init__(self):
21
+ self.train_x, self.train_y, self.test_x, self.test_y = load_data()
22
+ self.grad_fn = jax.grad(loss_fn)
23
+ model_shape = self.train_x.shape[1:]
24
+
25
+ self.params = load_model(model_shape)
26
+
27
+ def get_parameters(self, config):
28
+ return get_params(self.params)
29
+
30
+ def set_parameters(self, parameters):
31
+ set_params(self.params, parameters)
32
+
33
+ def fit(self, parameters, config):
34
+ self.set_parameters(parameters)
35
+ self.params, loss, num_examples = train(
36
+ self.params, self.grad_fn, self.train_x, self.train_y
37
+ )
38
+ parameters = self.get_parameters(config={})
39
+ return parameters, num_examples, {"loss": float(loss)}
40
+
41
+ def evaluate(self, parameters, config):
42
+ self.set_parameters(parameters)
43
+ loss, num_examples = evaluation(
44
+ self.params, self.grad_fn, self.test_x, self.test_y
45
+ )
46
+ return float(loss), num_examples, {"loss": float(loss)}
47
+
48
+ def client_fn(context: Context):
49
+ # Return Client instance
50
+ return FlowerClient().to_client()
51
+
52
+
53
+ # Flower ClientApp
54
+ app = ClientApp(
55
+ client_fn,
56
+ )