flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. flwr/cli/app.py +17 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +196 -42
  12. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  13. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  14. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  15. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  16. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  17. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  18. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  19. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  20. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  21. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  22. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  24. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  25. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  26. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  27. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  28. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  29. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  30. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  31. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  32. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  33. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  34. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  35. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  36. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  37. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  38. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  39. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  40. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  41. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  42. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  43. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  44. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  45. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  46. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  47. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  49. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  50. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  52. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  53. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  54. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  55. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  56. flwr/cli/pull.py +100 -0
  57. flwr/cli/run/run.py +11 -7
  58. flwr/cli/stop.py +2 -2
  59. flwr/cli/supernode/__init__.py +25 -0
  60. flwr/cli/supernode/ls.py +260 -0
  61. flwr/cli/supernode/register.py +185 -0
  62. flwr/cli/supernode/unregister.py +138 -0
  63. flwr/cli/utils.py +109 -69
  64. flwr/client/__init__.py +2 -1
  65. flwr/client/grpc_adapter_client/connection.py +6 -8
  66. flwr/client/grpc_rere_client/connection.py +59 -31
  67. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  68. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  69. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  70. flwr/client/rest_client/connection.py +82 -37
  71. flwr/clientapp/__init__.py +1 -2
  72. flwr/clientapp/mod/__init__.py +4 -1
  73. flwr/clientapp/mod/centraldp_mods.py +156 -40
  74. flwr/clientapp/mod/localdp_mod.py +169 -0
  75. flwr/clientapp/typing.py +22 -0
  76. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  77. flwr/common/constant.py +56 -13
  78. flwr/common/exit/exit_code.py +24 -10
  79. flwr/common/inflatable_utils.py +10 -10
  80. flwr/common/record/array.py +3 -3
  81. flwr/common/record/arrayrecord.py +10 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  84. flwr/common/serde.py +4 -2
  85. flwr/common/typing.py +7 -6
  86. flwr/compat/client/app.py +1 -1
  87. flwr/compat/client/grpc_client/connection.py +2 -2
  88. flwr/proto/control_pb2.py +48 -31
  89. flwr/proto/control_pb2.pyi +95 -5
  90. flwr/proto/control_pb2_grpc.py +136 -0
  91. flwr/proto/control_pb2_grpc.pyi +52 -0
  92. flwr/proto/fab_pb2.py +11 -7
  93. flwr/proto/fab_pb2.pyi +21 -1
  94. flwr/proto/fleet_pb2.py +31 -23
  95. flwr/proto/fleet_pb2.pyi +63 -23
  96. flwr/proto/fleet_pb2_grpc.py +98 -28
  97. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  98. flwr/proto/node_pb2.py +3 -1
  99. flwr/proto/node_pb2.pyi +48 -0
  100. flwr/server/app.py +152 -114
  101. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  102. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  103. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  104. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  106. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  110. flwr/server/superlink/linkstate/linkstate.py +107 -24
  111. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  112. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  113. flwr/server/superlink/linkstate/utils.py +3 -54
  114. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  115. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  116. flwr/server/utils/validator.py +2 -3
  117. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  118. flwr/serverapp/strategy/__init__.py +26 -0
  119. flwr/serverapp/strategy/bulyan.py +238 -0
  120. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  121. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  122. flwr/serverapp/strategy/fedadagrad.py +0 -3
  123. flwr/serverapp/strategy/fedadam.py +0 -3
  124. flwr/serverapp/strategy/fedavg.py +89 -64
  125. flwr/serverapp/strategy/fedavgm.py +198 -0
  126. flwr/serverapp/strategy/fedmedian.py +105 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +0 -3
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/strategy_utils.py +48 -0
  136. flwr/simulation/app.py +1 -1
  137. flwr/simulation/ray_transport/ray_actor.py +1 -1
  138. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  139. flwr/simulation/run_simulation.py +28 -32
  140. flwr/supercore/cli/flower_superexec.py +26 -1
  141. flwr/supercore/constant.py +41 -0
  142. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  143. flwr/supercore/object_store/object_store_factory.py +26 -6
  144. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  145. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  146. flwr/supercore/primitives/asymmetric.py +117 -0
  147. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  148. flwr/supercore/sqlite_mixin.py +156 -0
  149. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  150. flwr/supercore/superexec/run_superexec.py +16 -2
  151. flwr/supercore/utils.py +20 -0
  152. flwr/superlink/artifact_provider/__init__.py +22 -0
  153. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  154. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  155. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  156. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  157. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  158. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  159. flwr/superlink/servicer/control/control_grpc.py +16 -11
  160. flwr/superlink/servicer/control/control_servicer.py +207 -58
  161. flwr/supernode/cli/flower_supernode.py +19 -26
  162. flwr/supernode/runtime/run_clientapp.py +2 -2
  163. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  164. flwr/supernode/start_client_internal.py +17 -9
  165. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
  166. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
  167. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  168. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  169. flwr/common/auth_plugin/auth_plugin.py +0 -149
  170. flwr/serverapp/dp_fixed_clipping.py +0 -352
  171. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  172. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  173. /flwr/{client → clientapp}/client_app.py +0 -0
  174. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  175. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
flwr/cli/login/login.py CHANGED
@@ -20,6 +20,7 @@ from typing import Annotated, Optional
20
20
 
21
21
  import typer
22
22
 
23
+ from flwr.cli.auth_plugin import LoginError, NoOpCliAuthPlugin
23
24
  from flwr.cli.config_utils import (
24
25
  exit_if_no_address,
25
26
  get_insecure_flag,
@@ -28,14 +29,19 @@ from flwr.cli.config_utils import (
28
29
  validate_federation_in_project_config,
29
30
  )
30
31
  from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
31
- from flwr.common.typing import UserAuthLoginDetails
32
+ from flwr.common.typing import AccountAuthLoginDetails
32
33
  from flwr.proto.control_pb2 import ( # pylint: disable=E0611
33
34
  GetLoginDetailsRequest,
34
35
  GetLoginDetailsResponse,
35
36
  )
36
37
  from flwr.proto.control_pb2_grpc import ControlStub
37
38
 
38
- from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
39
+ from ..utils import (
40
+ account_auth_enabled,
41
+ flwr_cli_grpc_exc_handler,
42
+ init_channel,
43
+ load_cli_auth_plugin,
44
+ )
39
45
 
40
46
 
41
47
  def login( # pylint: disable=R0914
@@ -67,12 +73,13 @@ def login( # pylint: disable=R0914
67
73
  )
68
74
  exit_if_no_address(federation_config, "login")
69
75
 
70
- # Check if `enable-user-auth` is set to `true`
71
- if not federation_config.get("enable-user-auth", False):
76
+ # Check if `enable-account-auth` is set to `true`
77
+
78
+ if not account_auth_enabled(federation_config):
72
79
  typer.secho(
73
- f"❌ User authentication is not enabled for the federation '{federation}'. "
74
- "To enable it, set `enable-user-auth = true` in the federation "
75
- "configuration.",
80
+ "❌ Account authentication is not enabled for the federation "
81
+ f"'{federation}'. To enable it, set `enable-account-auth = true` "
82
+ "in the federation configuration.",
76
83
  fg=typer.colors.RED,
77
84
  bold=True,
78
85
  )
@@ -88,7 +95,7 @@ def login( # pylint: disable=R0914
88
95
  )
89
96
  raise typer.Exit(code=1)
90
97
 
91
- channel = init_channel(app, federation_config, None)
98
+ channel = init_channel(app, federation_config, NoOpCliAuthPlugin(Path()))
92
99
  stub = ControlStub(channel)
93
100
 
94
101
  login_request = GetLoginDetailsRequest()
@@ -96,28 +103,32 @@ def login( # pylint: disable=R0914
96
103
  login_response: GetLoginDetailsResponse = stub.GetLoginDetails(login_request)
97
104
 
98
105
  # Get the auth plugin
99
- auth_type = login_response.auth_type
100
- auth_plugin = try_obtain_cli_auth_plugin(
101
- app, federation, federation_config, auth_type
102
- )
103
- if auth_plugin is None:
104
- typer.secho(
105
- f'❌ Authentication type "{auth_type}" not found',
106
- fg=typer.colors.RED,
107
- bold=True,
108
- )
109
- raise typer.Exit(code=1)
106
+ authn_type = login_response.authn_type
107
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config, authn_type)
110
108
 
111
109
  # Login
112
- details = UserAuthLoginDetails(
113
- auth_type=login_response.auth_type,
110
+ details = AccountAuthLoginDetails(
111
+ authn_type=login_response.authn_type,
114
112
  device_code=login_response.device_code,
115
113
  verification_uri_complete=login_response.verification_uri_complete,
116
114
  expires_in=login_response.expires_in,
117
115
  interval=login_response.interval,
118
116
  )
119
- with flwr_cli_grpc_exc_handler():
120
- credentials = auth_plugin.login(details, stub)
117
+ try:
118
+ with flwr_cli_grpc_exc_handler():
119
+ credentials = auth_plugin.login(details, stub)
120
+ typer.secho(
121
+ "✅ Login successful.",
122
+ fg=typer.colors.GREEN,
123
+ bold=False,
124
+ )
125
+ except LoginError as e:
126
+ typer.secho(
127
+ f"❌ Login failed: {e.message}",
128
+ fg=typer.colors.RED,
129
+ bold=True,
130
+ )
131
+ raise typer.Exit(code=1) from None
121
132
 
122
133
  # Store the tokens
123
134
  auth_plugin.store_tokens(credentials)
flwr/cli/ls.py CHANGED
@@ -19,7 +19,7 @@ import io
19
19
  import json
20
20
  from datetime import datetime, timedelta
21
21
  from pathlib import Path
22
- from typing import Annotated, Optional
22
+ from typing import Annotated, Optional, cast
23
23
 
24
24
  import typer
25
25
  from rich.console import Console
@@ -44,12 +44,13 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
44
44
  )
45
45
  from flwr.proto.control_pb2_grpc import ControlStub
46
46
 
47
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
47
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
48
48
 
49
49
  _RunListType = tuple[int, str, str, str, str, str, str, str, str]
50
50
 
51
51
 
52
52
  def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
53
+ ctx: typer.Context,
53
54
  app: Annotated[
54
55
  Path,
55
56
  typer.Argument(help="Path of the Flower project"),
@@ -102,6 +103,9 @@ def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
102
103
 
103
104
  All timestamps follow ISO 8601, UTC and are formatted as ``YYYY-MM-DD HH:MM:SSZ``.
104
105
  """
106
+ # Resolve command used (list or ls)
107
+ command_name = cast(str, ctx.command.name) if ctx.command else "list"
108
+
105
109
  suppress_output = output_format == CliOutputFormat.JSON
106
110
  captured_output = io.StringIO()
107
111
  try:
@@ -116,14 +120,14 @@ def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
116
120
  federation, federation_config = validate_federation_in_project_config(
117
121
  federation, config, federation_config_overrides
118
122
  )
119
- exit_if_no_address(federation_config, "ls")
123
+ exit_if_no_address(federation_config, command_name)
120
124
  channel = None
121
125
  try:
122
126
  if runs and run_id is not None:
123
127
  raise ValueError(
124
128
  "The options '--runs' and '--run-id' are mutually exclusive."
125
129
  )
126
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
130
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
127
131
  channel = init_channel(app, federation_config, auth_plugin)
128
132
  stub = ControlStub(channel)
129
133
 
@@ -216,14 +220,14 @@ def _to_table(run_list: list[_RunListType]) -> Table:
216
220
 
217
221
  # Add columns
218
222
  table.add_column(
219
- Text("Run ID", justify="center"), style="bright_white", overflow="fold"
223
+ Text("Run ID", justify="center"), style="bright_black", no_wrap=True
220
224
  )
221
- table.add_column(Text("FAB", justify="center"), style="dim white")
225
+ table.add_column(Text("FAB", justify="center"), style="bright_black")
222
226
  table.add_column(Text("Status", justify="center"))
223
227
  table.add_column(Text("Elapsed", justify="center"), style="blue")
224
- table.add_column(Text("Created At", justify="center"), style="dim white")
225
- table.add_column(Text("Running At", justify="center"), style="dim white")
226
- table.add_column(Text("Finished At", justify="center"), style="dim white")
228
+ table.add_column(Text("Created At", justify="center"), style="bright_black")
229
+ table.add_column(Text("Running At", justify="center"), style="bright_black")
230
+ table.add_column(Text("Finished At", justify="center"), style="bright_black")
227
231
 
228
232
  for row in run_list:
229
233
  (
flwr/cli/new/new.py CHANGED
@@ -15,14 +15,20 @@
15
15
  """Flower command line interface `new` command."""
16
16
 
17
17
 
18
+ import io
19
+ import json
18
20
  import re
21
+ import zipfile
19
22
  from enum import Enum
20
23
  from pathlib import Path
21
24
  from string import Template
22
25
  from typing import Annotated, Optional
23
26
 
27
+ import requests
24
28
  import typer
25
29
 
30
+ from flwr.supercore.constant import APP_ID_PATTERN, PLATFORM_API_URL
31
+
26
32
  from ..utils import (
27
33
  is_valid_project_name,
28
34
  prompt_options,
@@ -35,15 +41,16 @@ class MlFramework(str, Enum):
35
41
  """Available frameworks."""
36
42
 
37
43
  PYTORCH = "PyTorch"
38
- PYTORCH_MSG_API = "PyTorch (Message API)"
39
44
  TENSORFLOW = "TensorFlow"
40
45
  SKLEARN = "sklearn"
41
46
  HUGGINGFACE = "HuggingFace"
42
47
  JAX = "JAX"
43
48
  MLX = "MLX"
44
49
  NUMPY = "NumPy"
50
+ XGBOOST = "XGBoost"
45
51
  FLOWERTUNE = "FlowerTune"
46
52
  BASELINE = "Flower Baseline"
53
+ PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"
47
54
 
48
55
 
49
56
  class LlmChallengeName(str, Enum):
@@ -92,6 +99,180 @@ def render_and_create(file_path: Path, template: str, context: dict[str, str]) -
92
99
  create_file(file_path, content)
93
100
 
94
101
 
102
+ def print_success_prompt(
103
+ package_name: str, llm_challenge_str: Optional[str] = None
104
+ ) -> None:
105
+ """Print styled setup instructions for running a new Flower App after creation."""
106
+ prompt = typer.style(
107
+ "🎊 Flower App creation successful.\n\n"
108
+ "To run your Flower App, first install its dependencies:\n\n",
109
+ fg=typer.colors.GREEN,
110
+ bold=True,
111
+ )
112
+
113
+ _add = " huggingface-cli login\n" if llm_challenge_str else ""
114
+
115
+ prompt += typer.style(
116
+ f" cd {package_name} && pip install -e .\n" + _add + "\n",
117
+ fg=typer.colors.BRIGHT_CYAN,
118
+ bold=True,
119
+ )
120
+
121
+ prompt += typer.style(
122
+ "then, run the app:\n\n ",
123
+ fg=typer.colors.GREEN,
124
+ bold=True,
125
+ )
126
+
127
+ prompt += typer.style(
128
+ "\tflwr run .\n\n",
129
+ fg=typer.colors.BRIGHT_CYAN,
130
+ bold=True,
131
+ )
132
+
133
+ prompt += typer.style(
134
+ "💡 Check the README in your app directory to learn how to\n"
135
+ "customize it and how to run it using the Deployment Runtime.\n",
136
+ fg=typer.colors.GREEN,
137
+ bold=True,
138
+ )
139
+
140
+ print(prompt)
141
+
142
+
143
+ # Security: prevent zip-slip
144
+ def _safe_extract_zip(zf: zipfile.ZipFile, dest_dir: Path) -> None:
145
+ """Extract ZIP file into destination directory."""
146
+ dest_dir = dest_dir.resolve()
147
+
148
+ def _is_within_directory(base: Path, target: Path) -> bool:
149
+ try:
150
+ target.relative_to(base)
151
+ return True
152
+ except ValueError:
153
+ return False
154
+
155
+ for member in zf.infolist():
156
+ # Skip directory placeholders;
157
+ # ZipInfo can represent them as names ending with '/'.
158
+ if member.is_dir():
159
+ target_path = (dest_dir / member.filename).resolve()
160
+ if not _is_within_directory(dest_dir, target_path):
161
+ raise ValueError(f"Unsafe path in zip: {member.filename}")
162
+ target_path.mkdir(parents=True, exist_ok=True)
163
+ continue
164
+
165
+ # Files
166
+ target_path = (dest_dir / member.filename).resolve()
167
+ if not _is_within_directory(dest_dir, target_path):
168
+ raise ValueError(f"Unsafe path in zip: {member.filename}")
169
+
170
+ # Ensure parent exists
171
+ target_path.parent.mkdir(parents=True, exist_ok=True)
172
+
173
+ # Extract
174
+ with zf.open(member, "r") as src, open(target_path, "wb") as dst:
175
+ dst.write(src.read())
176
+
177
+
178
+ def _download_zip_to_memory(presigned_url: str) -> io.BytesIO:
179
+ """Download ZIP file from Platform API to memory."""
180
+ try:
181
+ r = requests.get(presigned_url, timeout=60)
182
+ r.raise_for_status()
183
+ except requests.RequestException as e:
184
+ raise typer.BadParameter(f"ZIP download failed: {e}") from e
185
+
186
+ buf = io.BytesIO(r.content)
187
+ # Validate it's a zip
188
+ if not zipfile.is_zipfile(buf):
189
+ raise typer.BadParameter("Downloaded file is not a valid ZIP")
190
+ buf.seek(0)
191
+ return buf
192
+
193
+
194
+ def _request_download_link(identifier: str) -> str:
195
+ """Request download link from Flower platform API."""
196
+ url = f"{PLATFORM_API_URL}/hub/fetch-zip"
197
+ headers = {
198
+ "Content-Type": "application/json",
199
+ "Accept": "application/json",
200
+ }
201
+ body = {
202
+ "identifier": identifier, # send raw string of identifier
203
+ }
204
+
205
+ try:
206
+ resp = requests.post(url, headers=headers, data=json.dumps(body), timeout=20)
207
+ except requests.RequestException as e:
208
+ raise typer.BadParameter(f"Unable to connect to Platform API: {e}") from e
209
+
210
+ if resp.status_code == 404:
211
+ raise typer.BadParameter(f"'{identifier}' not found in Platform API")
212
+ if not resp.ok:
213
+ raise typer.BadParameter(
214
+ f"Platform API request failed with "
215
+ f"status {resp.status_code}. Details: {resp.text}"
216
+ )
217
+
218
+ data = resp.json()
219
+ if "zip_url" not in data:
220
+ raise typer.BadParameter("Invalid response from Platform API")
221
+ return str(data["zip_url"])
222
+
223
+
224
+ def download_remote_app_via_api(identifier: str) -> None:
225
+ """Download App from Platform API."""
226
+ # Parse @user/app just to derive local dir name
227
+ m = re.match(APP_ID_PATTERN, identifier)
228
+ if not m:
229
+ raise typer.BadParameter(
230
+ "Invalid remote app ID. Expected format: '@user_name/app_name'."
231
+ )
232
+ app_name = m.group("app")
233
+
234
+ project_dir = Path.cwd() / app_name
235
+ if project_dir.exists():
236
+ if not typer.confirm(
237
+ typer.style(
238
+ f"\n💬 {app_name} already exists, do you want to override it?",
239
+ fg=typer.colors.MAGENTA,
240
+ bold=True,
241
+ )
242
+ ):
243
+ return
244
+
245
+ print(
246
+ typer.style(
247
+ f"\n🔗 Requesting download link for {identifier}...",
248
+ fg=typer.colors.GREEN,
249
+ bold=True,
250
+ )
251
+ )
252
+ presigned_url = _request_download_link(identifier)
253
+
254
+ print(
255
+ typer.style(
256
+ "⬇️ Downloading ZIP into memory...",
257
+ fg=typer.colors.GREEN,
258
+ bold=True,
259
+ )
260
+ )
261
+ zip_buf = _download_zip_to_memory(presigned_url)
262
+
263
+ print(
264
+ typer.style(
265
+ f"📦 Unpacking into {project_dir}...",
266
+ fg=typer.colors.GREEN,
267
+ bold=True,
268
+ )
269
+ )
270
+ with zipfile.ZipFile(zip_buf) as zf:
271
+ _safe_extract_zip(zf, Path.cwd())
272
+
273
+ print_success_prompt(app_name)
274
+
275
+
95
276
  # pylint: disable=too-many-locals,too-many-branches,too-many-statements
96
277
  def new(
97
278
  app_name: Annotated[
@@ -110,6 +291,12 @@ def new(
110
291
  """Create new Flower App."""
111
292
  if app_name is None:
112
293
  app_name = prompt_text("Please provide the app name")
294
+
295
+ # Download remote app
296
+ if app_name and app_name.startswith("@"):
297
+ download_remote_app_via_api(app_name)
298
+ return
299
+
113
300
  if not is_valid_project_name(app_name):
114
301
  app_name = prompt_text(
115
302
  "Please provide a name that only contains "
@@ -155,8 +342,8 @@ def new(
155
342
  if framework_str == MlFramework.BASELINE:
156
343
  framework_str = "baseline"
157
344
 
158
- if framework_str == MlFramework.PYTORCH_MSG_API:
159
- framework_str = "pytorch_msg_api"
345
+ if framework_str == MlFramework.PYTORCH_LEGACY_API:
346
+ framework_str = "pytorch_legacy_api"
160
347
 
161
348
  print(
162
349
  typer.style(
@@ -201,7 +388,7 @@ def new(
201
388
  }
202
389
 
203
390
  # Challenge specific context
204
- fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
391
+ fraction_train = "0.2" if llm_challenge_str == "code" else "0.1"
205
392
  if llm_challenge_str == "generalnlp":
206
393
  challenge_name = "General NLP"
207
394
  num_clients = "20"
@@ -220,7 +407,7 @@ def new(
220
407
  dataset_name = "flwrlabs/code-alpaca-20k"
221
408
 
222
409
  context["llm_challenge_str"] = llm_challenge_str
223
- context["fraction_fit"] = fraction_fit
410
+ context["fraction_train"] = fraction_train
224
411
  context["challenge_name"] = challenge_name
225
412
  context["num_clients"] = num_clients
226
413
  context["dataset_name"] = dataset_name
@@ -247,14 +434,15 @@ def new(
247
434
  MlFramework.TENSORFLOW.value,
248
435
  MlFramework.SKLEARN.value,
249
436
  MlFramework.NUMPY.value,
250
- "pytorch_msg_api",
437
+ MlFramework.XGBOOST.value,
438
+ "pytorch_legacy_api",
251
439
  ]
252
440
  if framework_str in frameworks_with_tasks:
253
441
  files[f"{import_name}/task.py"] = {
254
442
  "template": f"app/code/task.{template_name}.py.tpl"
255
443
  }
256
444
 
257
- if framework_str == "pytorch_msg_api":
445
+ if framework_str == "pytorch_legacy_api":
258
446
  # Use custom __init__ that better captures name of framework
259
447
  files[f"{import_name}/__init__.py"] = {
260
448
  "template": f"app/code/__init__.{framework_str}.py.tpl"
@@ -280,38 +468,4 @@ def new(
280
468
  context=context,
281
469
  )
282
470
 
283
- prompt = typer.style(
284
- "🎊 Flower App creation successful.\n\n"
285
- "To run your Flower App, first install its dependencies:\n\n",
286
- fg=typer.colors.GREEN,
287
- bold=True,
288
- )
289
-
290
- _add = " huggingface-cli login\n" if llm_challenge_str else ""
291
-
292
- prompt += typer.style(
293
- f" cd {package_name} && pip install -e .\n" + _add + "\n",
294
- fg=typer.colors.BRIGHT_CYAN,
295
- bold=True,
296
- )
297
-
298
- prompt += typer.style(
299
- "then, run the app:\n\n ",
300
- fg=typer.colors.GREEN,
301
- bold=True,
302
- )
303
-
304
- prompt += typer.style(
305
- "\tflwr run .\n\n",
306
- fg=typer.colors.BRIGHT_CYAN,
307
- bold=True,
308
- )
309
-
310
- prompt += typer.style(
311
- "💡 Check the README in your app directory to learn how to\n"
312
- "customize it and how to run it using the Deployment Runtime.\n",
313
- fg=typer.colors.GREEN,
314
- bold=True,
315
- )
316
-
317
- print(prompt)
471
+ print_success_prompt(package_name, llm_challenge_str)
@@ -26,7 +26,7 @@ pip install -e .
26
26
  ## Experimental setup
27
27
 
28
28
  The dataset is divided into $num_clients partitions in an IID fashion, a partition is assigned to each ClientApp.
29
- We randomly sample a fraction ($fraction_fit) of the total nodes to participate in each round, for a total of `200` rounds.
29
+ We randomly sample a fraction ($fraction_train) of the total nodes to participate in each round, for a total of `200` rounds.
30
30
  All settings are defined in `pyproject.toml`.
31
31
 
32
32
  > [!IMPORTANT]
@@ -1,58 +1,75 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
3
  import torch
4
- from flwr.client import ClientApp, NumPyClient
5
- from flwr.common import Context
4
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
+ from flwr.clientapp import ClientApp
6
6
 
7
7
  from $import_name.dataset import load_data
8
- from $import_name.model import Net, get_weights, set_weights, test, train
9
-
10
-
11
- class FlowerClient(NumPyClient):
12
- """A class defining the client."""
13
-
14
- def __init__(self, net, trainloader, valloader, local_epochs):
15
- self.net = net
16
- self.trainloader = trainloader
17
- self.valloader = valloader
18
- self.local_epochs = local_epochs
19
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
- self.net.to(self.device)
21
-
22
- def fit(self, parameters, config):
23
- """Traim model using this client's data."""
24
- set_weights(self.net, parameters)
25
- train_loss = train(
26
- self.net,
27
- self.trainloader,
28
- self.local_epochs,
29
- self.device,
30
- )
31
- return (
32
- get_weights(self.net),
33
- len(self.trainloader.dataset),
34
- {"train_loss": train_loss},
35
- )
36
-
37
- def evaluate(self, parameters, config):
38
- """Evaluate model using this client's data."""
39
- set_weights(self.net, parameters)
40
- loss, accuracy = test(self.net, self.valloader, self.device)
41
- return loss, len(self.valloader.dataset), {"accuracy": accuracy}
42
-
43
-
44
- def client_fn(context: Context):
45
- """Construct a Client that will be run in a ClientApp."""
46
- # Load model and data
47
- net = Net()
8
+ from $import_name.model import Net
9
+ from $import_name.model import test as test_fn
10
+ from $import_name.model import train as train_fn
11
+
12
+ # Flower ClientApp
13
+ app = ClientApp()
14
+
15
+
16
+ @app.train()
17
+ def train(msg: Message, context: Context):
18
+ """Train the model on local data."""
19
+
20
+ # Load the model and initialize it with the received weights
21
+ model = Net()
22
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
23
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
+
25
+ # Load the data
48
26
  partition_id = int(context.node_config["partition-id"])
49
27
  num_partitions = int(context.node_config["num-partitions"])
50
- trainloader, valloader = load_data(partition_id, num_partitions)
28
+ trainloader, _ = load_data(partition_id, num_partitions)
51
29
  local_epochs = context.run_config["local-epochs"]
52
30
 
53
- # Return Client instance
54
- return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
31
+ # Call the training function
32
+ train_loss = train_fn(
33
+ model,
34
+ trainloader,
35
+ local_epochs,
36
+ device,
37
+ )
55
38
 
39
+ # Construct and return reply Message
40
+ model_record = ArrayRecord(model.state_dict())
41
+ metrics = {
42
+ "train_loss": train_loss,
43
+ "num-examples": len(trainloader.dataset),
44
+ }
45
+ metric_record = MetricRecord(metrics)
46
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
+ return Message(content=content, reply_to=msg)
56
48
 
57
- # Flower ClientApp
58
- app = ClientApp(client_fn)
49
+
50
+ @app.evaluate()
51
+ def evaluate(msg: Message, context: Context):
52
+ """Evaluate the model on local data."""
53
+
54
+ # Load the model and initialize it with the received weights
55
+ model = Net()
56
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
+
59
+ # Load the data
60
+ partition_id = int(context.node_config["partition-id"])
61
+ num_partitions = int(context.node_config["num-partitions"])
62
+ _, valloader = load_data(partition_id, num_partitions)
63
+
64
+ # Call the evaluation function
65
+ eval_loss, eval_acc = test_fn(model, valloader, device)
66
+
67
+ # Construct and return reply Message
68
+ metrics = {
69
+ "eval_loss": eval_loss,
70
+ "eval_acc": eval_acc,
71
+ "num-examples": len(valloader.dataset),
72
+ }
73
+ metric_record = MetricRecord(metrics)
74
+ content = RecordDict({"metrics": metric_record})
75
+ return Message(content=content, reply_to=msg)