flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/app.py +2 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +15 -2
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  14. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
  15. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  16. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  17. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  18. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  19. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  20. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  21. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  23. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  24. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  26. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  27. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  28. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  29. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  30. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  31. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  32. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  33. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  34. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  35. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  36. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  37. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  38. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  39. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  40. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  41. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  42. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  43. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  44. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
  45. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  46. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  47. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  49. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  50. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  52. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  53. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  54. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
  55. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  56. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  57. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  58. flwr/cli/pull.py +100 -0
  59. flwr/cli/run/run.py +9 -13
  60. flwr/cli/stop.py +7 -4
  61. flwr/cli/utils.py +36 -8
  62. flwr/client/grpc_rere_client/connection.py +1 -12
  63. flwr/client/rest_client/connection.py +3 -0
  64. flwr/clientapp/__init__.py +10 -0
  65. flwr/clientapp/mod/__init__.py +29 -0
  66. flwr/clientapp/mod/centraldp_mods.py +248 -0
  67. flwr/clientapp/mod/localdp_mod.py +169 -0
  68. flwr/clientapp/typing.py +22 -0
  69. flwr/common/args.py +20 -6
  70. flwr/common/auth_plugin/__init__.py +4 -4
  71. flwr/common/auth_plugin/auth_plugin.py +7 -7
  72. flwr/common/constant.py +26 -4
  73. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  74. flwr/common/exit/__init__.py +4 -0
  75. flwr/common/exit/exit.py +8 -1
  76. flwr/common/exit/exit_code.py +30 -7
  77. flwr/common/exit/exit_handler.py +62 -0
  78. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  79. flwr/common/grpc.py +0 -11
  80. flwr/common/inflatable_utils.py +1 -1
  81. flwr/common/logger.py +1 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/retry_invoker.py +30 -11
  84. flwr/common/telemetry.py +4 -0
  85. flwr/compat/server/app.py +2 -2
  86. flwr/proto/appio_pb2.py +25 -17
  87. flwr/proto/appio_pb2.pyi +46 -2
  88. flwr/proto/clientappio_pb2.py +3 -11
  89. flwr/proto/clientappio_pb2.pyi +0 -47
  90. flwr/proto/clientappio_pb2_grpc.py +19 -20
  91. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  92. flwr/proto/control_pb2.py +66 -0
  93. flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
  94. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
  95. flwr/proto/control_pb2_grpc.pyi +106 -0
  96. flwr/proto/serverappio_pb2.py +2 -2
  97. flwr/proto/serverappio_pb2_grpc.py +68 -0
  98. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  99. flwr/proto/simulationio_pb2.py +4 -11
  100. flwr/proto/simulationio_pb2.pyi +0 -58
  101. flwr/proto/simulationio_pb2_grpc.py +129 -27
  102. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  103. flwr/server/app.py +142 -152
  104. flwr/server/grid/grpc_grid.py +3 -0
  105. flwr/server/grid/inmemory_grid.py +1 -0
  106. flwr/server/serverapp/app.py +157 -146
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  110. flwr/server/superlink/linkstate/linkstate.py +2 -1
  111. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  112. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  113. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  114. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  115. flwr/serverapp/__init__.py +12 -0
  116. flwr/serverapp/exception.py +38 -0
  117. flwr/serverapp/strategy/__init__.py +64 -0
  118. flwr/serverapp/strategy/bulyan.py +238 -0
  119. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  120. flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
  121. flwr/serverapp/strategy/fedadagrad.py +159 -0
  122. flwr/serverapp/strategy/fedadam.py +178 -0
  123. flwr/serverapp/strategy/fedavg.py +320 -0
  124. flwr/serverapp/strategy/fedavgm.py +198 -0
  125. flwr/serverapp/strategy/fedmedian.py +105 -0
  126. flwr/serverapp/strategy/fedopt.py +218 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +170 -0
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/result.py +105 -0
  136. flwr/serverapp/strategy/strategy.py +285 -0
  137. flwr/serverapp/strategy/strategy_utils.py +299 -0
  138. flwr/simulation/app.py +161 -164
  139. flwr/simulation/run_simulation.py +25 -30
  140. flwr/supercore/app_utils.py +58 -0
  141. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  142. flwr/supercore/cli/flower_superexec.py +166 -0
  143. flwr/supercore/constant.py +19 -0
  144. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  145. flwr/supercore/corestate/corestate.py +81 -0
  146. flwr/supercore/grpc_health/__init__.py +3 -0
  147. flwr/supercore/grpc_health/health_server.py +53 -0
  148. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  149. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  150. flwr/supercore/superexec/plugin/__init__.py +28 -0
  151. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  152. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  153. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
  154. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  155. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  156. flwr/supercore/superexec/run_superexec.py +199 -0
  157. flwr/superlink/artifact_provider/__init__.py +22 -0
  158. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  159. flwr/superlink/servicer/__init__.py +15 -0
  160. flwr/superlink/servicer/control/__init__.py +22 -0
  161. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  162. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
  163. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  164. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
  165. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  166. flwr/supernode/cli/flower_supernode.py +3 -0
  167. flwr/supernode/cli/flwr_clientapp.py +18 -21
  168. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  169. flwr/supernode/nodestate/nodestate.py +3 -59
  170. flwr/supernode/runtime/run_clientapp.py +39 -102
  171. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  172. flwr/supernode/start_client_internal.py +35 -76
  173. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
  174. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
  175. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
  176. flwr/proto/exec_pb2.py +0 -62
  177. flwr/proto/exec_pb2_grpc.pyi +0 -93
  178. flwr/superexec/app.py +0 -45
  179. flwr/superexec/deployment.py +0 -191
  180. flwr/superexec/executor.py +0 -100
  181. flwr/superexec/simulation.py +0 -129
  182. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
flwr/__init__.py CHANGED
@@ -17,12 +17,15 @@
17
17
 
18
18
  from flwr.common.version import package_version as _package_version
19
19
 
20
- from . import client, common, server, simulation
20
+ from . import app, client, clientapp, common, server, serverapp, simulation
21
21
 
22
22
  __all__ = [
23
+ "app",
23
24
  "client",
25
+ "clientapp",
24
26
  "common",
25
27
  "server",
28
+ "serverapp",
26
29
  "simulation",
27
30
  ]
28
31
 
flwr/app/__init__.py CHANGED
@@ -13,3 +13,31 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  """Public Flower App APIs."""
16
+
17
+
18
+ from flwr.common.constant import MessageType
19
+ from flwr.common.context import Context
20
+ from flwr.common.message import Message
21
+ from flwr.common.record import (
22
+ Array,
23
+ ArrayRecord,
24
+ ConfigRecord,
25
+ MetricRecord,
26
+ RecordDict,
27
+ )
28
+
29
+ from .error import Error
30
+ from .metadata import Metadata
31
+
32
+ __all__ = [
33
+ "Array",
34
+ "ArrayRecord",
35
+ "ConfigRecord",
36
+ "Context",
37
+ "Error",
38
+ "Message",
39
+ "MessageType",
40
+ "Metadata",
41
+ "MetricRecord",
42
+ "RecordDict",
43
+ ]
flwr/app/exception.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower application exceptions."""
16
+
17
+
18
+ class AppExitException(BaseException):
19
+ """Base exception for all application-level errors in ServerApp and ClientApp.
20
+
21
+ When raised, the process will exit and report a telemetry event with the associated
22
+ exit code. This is not intended to be caught by user code.
23
+ """
24
+
25
+ # Default exit code — subclasses must override
26
+ exit_code = -1
27
+
28
+ def __init_subclass__(cls) -> None:
29
+ """Ensure subclasses override the exit_code attribute."""
30
+ if cls.exit_code == -1:
31
+ raise ValueError("Subclasses must override the exit_code attribute.")
flwr/cli/app.py CHANGED
@@ -25,6 +25,7 @@ from .log import log
25
25
  from .login import login
26
26
  from .ls import ls
27
27
  from .new import new
28
+ from .pull import pull
28
29
  from .run import run
29
30
  from .stop import stop
30
31
 
@@ -46,6 +47,7 @@ app.command()(log)
46
47
  app.command()(ls)
47
48
  app.command()(stop)
48
49
  app.command()(login)
50
+ app.command()(pull)
49
51
 
50
52
  typer_click_object = get_command(app)
51
53
 
flwr/cli/auth_plugin/oidc_cli_plugin.py CHANGED
@@ -31,11 +31,11 @@ from flwr.common.constant import (
31
31
  AuthType,
32
32
  )
33
33
  from flwr.common.typing import UserAuthCredentials, UserAuthLoginDetails
34
- from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
34
+ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
35
35
  GetAuthTokensRequest,
36
36
  GetAuthTokensResponse,
37
37
  )
38
- from flwr.proto.exec_pb2_grpc import ExecStub
38
+ from flwr.proto.control_pb2_grpc import ControlStub
39
39
 
40
40
 
41
41
  class OidcCliPlugin(CliAuthPlugin):
@@ -49,7 +49,7 @@ class OidcCliPlugin(CliAuthPlugin):
49
49
  @staticmethod
50
50
  def login(
51
51
  login_details: UserAuthLoginDetails,
52
- exec_stub: ExecStub,
52
+ control_stub: ControlStub,
53
53
  ) -> UserAuthCredentials:
54
54
  """Authenticate the user and retrieve authentication credentials."""
55
55
  typer.secho(
@@ -61,7 +61,7 @@ class OidcCliPlugin(CliAuthPlugin):
61
61
  time.sleep(login_details.interval)
62
62
 
63
63
  while (time.time() - start_time) < login_details.expires_in:
64
- res: GetAuthTokensResponse = exec_stub.GetAuthTokens(
64
+ res: GetAuthTokensResponse = control_stub.GetAuthTokens(
65
65
  GetAuthTokensRequest(device_code=login_details.device_code)
66
66
  )
67
67
 
flwr/cli/cli_user_auth_interceptor.py CHANGED
@@ -20,7 +20,7 @@ from typing import Any, Callable, Union
20
20
  import grpc
21
21
 
22
22
  from flwr.common.auth_plugin import CliAuthPlugin
23
- from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
23
+ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
24
24
  StartRunRequest,
25
25
  StreamLogsRequest,
26
26
  )
flwr/cli/config_utils.py CHANGED
@@ -143,7 +143,7 @@ def validate_federation_in_project_config(
143
143
  if federation is None:
144
144
  typer.secho(
145
145
  "❌ No federation name was provided and the project's `pyproject.toml` "
146
- "doesn't declare a default federation (with an Exec API address or an "
146
+ "doesn't declare a default federation (with an Control API address or an "
147
147
  "`options.num-supernodes` value).",
148
148
  fg=typer.colors.RED,
149
149
  bold=True,
@@ -230,8 +230,8 @@ def exit_if_no_address(federation_config: dict[str, Any], cmd: str) -> None:
230
230
  """Exit if the provided federation_config has no "address" key."""
231
231
  if "address" not in federation_config:
232
232
  typer.secho(
233
- f"❌ `flwr {cmd}` currently works with a SuperLink. Ensure that the correct"
234
- "SuperLink (Exec API) address is provided in `pyproject.toml`.",
233
+ f"❌ `flwr {cmd}` currently works with a SuperLink. Ensure that the "
234
+ "correct SuperLink (Control API) address is provided in `pyproject.toml`.",
235
235
  fg=typer.colors.RED,
236
236
  bold=True,
237
237
  )
flwr/cli/constant.py CHANGED
@@ -15,13 +15,30 @@
15
15
  """Constants for CLI commands."""
16
16
 
17
17
 
18
+ # General help message for config overrides
19
+ CONFIG_HELP_MESSAGE = (
20
+ "Override {0} values using one of the following formats:\n\n"
21
+ "--{1} '<k1>=<v1> <k2>=<v2>' | --{1} '<k1>=<v1>' --{1} '<k2>=<v2>'{2}\n\n"
22
+ "When providing key-value pairs, values can be of any type supported by TOML "
23
+ "(e.g., bool, int, float, string). The specified keys (<k1> and <k2> in the "
24
+ "example) must exist in the {0} under the `{3}` section of `pyproject.toml` to be "
25
+ "overridden.{4}"
26
+ )
27
+
28
+ # The help message for `--run-config` option
29
+ RUN_CONFIG_HELP_MESSAGE = CONFIG_HELP_MESSAGE.format(
30
+ "run configuration",
31
+ "run-config",
32
+ " | --run-config <path/to/your/toml>",
33
+ "[tool.flwr.app.config]",
34
+ " Alternatively, provide a TOML file containing overrides.",
35
+ )
36
+
18
37
  # The help message for `--federation-config` option
19
- FEDERATION_CONFIG_HELP_MESSAGE = (
20
- "Override federation configuration values in the format:\n\n"
21
- "`--federation-config 'key1=value1 key2=value2' --federation-config "
22
- "'key3=value3'`\n\nValues can be of any type supported in TOML, such as "
23
- "bool, int, float, or string. Ensure that the keys (`key1`, `key2`, `key3` "
24
- "in this example) exist in the federation configuration under the "
25
- "`[tool.flwr.federations.<YOUR_FEDERATION>]` table of the `pyproject.toml` "
26
- "for proper overriding."
38
+ FEDERATION_CONFIG_HELP_MESSAGE = CONFIG_HELP_MESSAGE.format(
39
+ "federation configuration",
40
+ "federation-config",
41
+ "",
42
+ "[tool.flwr.federations.<YOUR-FEDERATION>]",
43
+ "",
27
44
  )
flwr/cli/log.py CHANGED
@@ -32,8 +32,8 @@ from flwr.cli.config_utils import (
32
32
  from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
33
33
  from flwr.common.constant import CONN_RECONNECT_INTERVAL, CONN_REFRESH_PERIOD
34
34
  from flwr.common.logger import log as logger
35
- from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
36
- from flwr.proto.exec_pb2_grpc import ExecStub
35
+ from flwr.proto.control_pb2 import StreamLogsRequest # pylint: disable=E0611
36
+ from flwr.proto.control_pb2_grpc import ControlStub
37
37
 
38
38
  from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
39
39
 
@@ -46,7 +46,7 @@ def start_stream(
46
46
  run_id: int, channel: grpc.Channel, refresh_period: int = CONN_REFRESH_PERIOD
47
47
  ) -> None:
48
48
  """Start log streaming for a given run ID."""
49
- stub = ExecStub(channel)
49
+ stub = ControlStub(channel)
50
50
  after_timestamp = 0.0
51
51
  try:
52
52
  logger(INFO, "Starting logstream for run_id `%s`", run_id)
@@ -69,7 +69,7 @@ def start_stream(
69
69
 
70
70
 
71
71
  def stream_logs(
72
- run_id: int, stub: ExecStub, duration: int, after_timestamp: float
72
+ run_id: int, stub: ControlStub, duration: int, after_timestamp: float
73
73
  ) -> float:
74
74
  """Stream logs from the beginning of a run with connection refresh.
75
75
 
@@ -77,8 +77,8 @@ def stream_logs(
77
77
  ----------
78
78
  run_id : int
79
79
  The identifier of the run.
80
- stub : ExecStub
81
- The gRPC stub to interact with the Exec service.
80
+ stub : ControlStub
81
+ The gRPC stub to interact with the Control service.
82
82
  duration : int
83
83
  The timeout duration for each stream connection in seconds.
84
84
  after_timestamp : float
@@ -112,7 +112,7 @@ def stream_logs(
112
112
 
113
113
  def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
114
114
  """Print logs from the beginning of a run."""
115
- stub = ExecStub(channel)
115
+ stub = ControlStub(channel)
116
116
  req = StreamLogsRequest(run_id=run_id, after_timestamp=0.0)
117
117
 
118
118
  try:
@@ -173,13 +173,13 @@ def log(
173
173
  exit_if_no_address(federation_config, "log")
174
174
 
175
175
  try:
176
- _log_with_exec_api(app, federation, federation_config, run_id, stream)
176
+ _log_with_control_api(app, federation, federation_config, run_id, stream)
177
177
  except Exception as err: # pylint: disable=broad-except
178
178
  typer.secho(str(err), fg=typer.colors.RED, bold=True)
179
179
  raise typer.Exit(code=1) from None
180
180
 
181
181
 
182
- def _log_with_exec_api(
182
+ def _log_with_control_api(
183
183
  app: Path,
184
184
  federation: str,
185
185
  federation_config: dict[str, Any],
flwr/cli/login/login.py CHANGED
@@ -29,11 +29,11 @@ from flwr.cli.config_utils import (
29
29
  )
30
30
  from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
31
31
  from flwr.common.typing import UserAuthLoginDetails
32
- from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
32
+ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
33
33
  GetLoginDetailsRequest,
34
34
  GetLoginDetailsResponse,
35
35
  )
36
- from flwr.proto.exec_pb2_grpc import ExecStub
36
+ from flwr.proto.control_pb2_grpc import ControlStub
37
37
 
38
38
  from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
39
39
 
@@ -89,7 +89,7 @@ def login( # pylint: disable=R0914
89
89
  raise typer.Exit(code=1)
90
90
 
91
91
  channel = init_channel(app, federation_config, None)
92
- stub = ExecStub(channel)
92
+ stub = ControlStub(channel)
93
93
 
94
94
  login_request = GetLoginDetailsRequest()
95
95
  with flwr_cli_grpc_exc_handler():
flwr/cli/ls.py CHANGED
@@ -38,11 +38,11 @@ from flwr.common.date import format_timedelta, isoformat8601_utc
38
38
  from flwr.common.logger import print_json_error, redirect_output, restore_output
39
39
  from flwr.common.serde import run_from_proto
40
40
  from flwr.common.typing import Run
41
- from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
41
+ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
42
42
  ListRunsRequest,
43
43
  ListRunsResponse,
44
44
  )
45
- from flwr.proto.exec_pb2_grpc import ExecStub
45
+ from flwr.proto.control_pb2_grpc import ControlStub
46
46
 
47
47
  from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
48
48
 
@@ -125,7 +125,7 @@ def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
125
125
  )
126
126
  auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
127
127
  channel = init_channel(app, federation_config, auth_plugin)
128
- stub = ExecStub(channel)
128
+ stub = ControlStub(channel)
129
129
 
130
130
  # Display information about a specific run ID
131
131
  if run_id is not None:
@@ -293,7 +293,7 @@ def _to_json(run_list: list[_RunListType]) -> str:
293
293
  return json.dumps({"success": True, "runs": runs_list})
294
294
 
295
295
 
296
- def _list_runs(stub: ExecStub) -> list[_RunListType]:
296
+ def _list_runs(stub: ControlStub) -> list[_RunListType]:
297
297
  """List all runs."""
298
298
  with flwr_cli_grpc_exc_handler():
299
299
  res: ListRunsResponse = stub.ListRuns(ListRunsRequest())
@@ -302,7 +302,7 @@ def _list_runs(stub: ExecStub) -> list[_RunListType]:
302
302
  return _format_runs(run_dict, res.now)
303
303
 
304
304
 
305
- def _display_one_run(stub: ExecStub, run_id: int) -> list[_RunListType]:
305
+ def _display_one_run(stub: ControlStub, run_id: int) -> list[_RunListType]:
306
306
  """Display information about a specific run."""
307
307
  with flwr_cli_grpc_exc_handler():
308
308
  res: ListRunsResponse = stub.ListRuns(ListRunsRequest(run_id=run_id))
flwr/cli/new/new.py CHANGED
@@ -41,8 +41,10 @@ class MlFramework(str, Enum):
41
41
  JAX = "JAX"
42
42
  MLX = "MLX"
43
43
  NUMPY = "NumPy"
44
+ XGBOOST = "XGBoost"
44
45
  FLOWERTUNE = "FlowerTune"
45
46
  BASELINE = "Flower Baseline"
47
+ PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"
46
48
 
47
49
 
48
50
  class LlmChallengeName(str, Enum):
@@ -154,6 +156,9 @@ def new(
154
156
  if framework_str == MlFramework.BASELINE:
155
157
  framework_str = "baseline"
156
158
 
159
+ if framework_str == MlFramework.PYTORCH_LEGACY_API:
160
+ framework_str = "pytorch_legacy_api"
161
+
157
162
  print(
158
163
  typer.style(
159
164
  f"\n🔨 Creating Flower App {app_name}...",
@@ -197,7 +202,7 @@ def new(
197
202
  }
198
203
 
199
204
  # Challenge specific context
200
- fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
205
+ fraction_train = "0.2" if llm_challenge_str == "code" else "0.1"
201
206
  if llm_challenge_str == "generalnlp":
202
207
  challenge_name = "General NLP"
203
208
  num_clients = "20"
@@ -216,7 +221,7 @@ def new(
216
221
  dataset_name = "flwrlabs/code-alpaca-20k"
217
222
 
218
223
  context["llm_challenge_str"] = llm_challenge_str
219
- context["fraction_fit"] = fraction_fit
224
+ context["fraction_train"] = fraction_train
220
225
  context["challenge_name"] = challenge_name
221
226
  context["num_clients"] = num_clients
222
227
  context["dataset_name"] = dataset_name
@@ -243,12 +248,20 @@ def new(
243
248
  MlFramework.TENSORFLOW.value,
244
249
  MlFramework.SKLEARN.value,
245
250
  MlFramework.NUMPY.value,
251
+ MlFramework.XGBOOST.value,
252
+ "pytorch_legacy_api",
246
253
  ]
247
254
  if framework_str in frameworks_with_tasks:
248
255
  files[f"{import_name}/task.py"] = {
249
256
  "template": f"app/code/task.{template_name}.py.tpl"
250
257
  }
251
258
 
259
+ if framework_str == "pytorch_legacy_api":
260
+ # Use custom __init__ that better captures name of framework
261
+ files[f"{import_name}/__init__.py"] = {
262
+ "template": f"app/code/__init__.{framework_str}.py.tpl"
263
+ }
264
+
252
265
  if framework_str == "baseline":
253
266
  # Include additional files for baseline template
254
267
  for file_name in ["model", "dataset", "strategy", "utils", "__init__"]:
flwr/cli/new/templates/app/README.flowertune.md.tpl CHANGED
@@ -26,7 +26,7 @@ pip install -e .
26
26
  ## Experimental setup
27
27
 
28
28
  The dataset is divided into $num_clients partitions in an IID fashion, a partition is assigned to each ClientApp.
29
- We randomly sample a fraction ($fraction_fit) of the total nodes to participate in each round, for a total of `200` rounds.
29
+ We randomly sample a fraction ($fraction_train) of the total nodes to participate in each round, for a total of `200` rounds.
30
30
  All settings are defined in `pyproject.toml`.
31
31
 
32
32
  > [!IMPORTANT]
flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl ADDED
@@ -0,0 +1 @@
1
+ """$project_name: A Flower / PyTorch app."""
flwr/cli/new/templates/app/code/client.baseline.py.tpl CHANGED
@@ -1,58 +1,75 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
3
  import torch
4
- from flwr.client import ClientApp, NumPyClient
5
- from flwr.common import Context
4
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
+ from flwr.clientapp import ClientApp
6
6
 
7
7
  from $import_name.dataset import load_data
8
- from $import_name.model import Net, get_weights, set_weights, test, train
9
-
10
-
11
- class FlowerClient(NumPyClient):
12
- """A class defining the client."""
13
-
14
- def __init__(self, net, trainloader, valloader, local_epochs):
15
- self.net = net
16
- self.trainloader = trainloader
17
- self.valloader = valloader
18
- self.local_epochs = local_epochs
19
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
- self.net.to(self.device)
21
-
22
- def fit(self, parameters, config):
23
- """Traim model using this client's data."""
24
- set_weights(self.net, parameters)
25
- train_loss = train(
26
- self.net,
27
- self.trainloader,
28
- self.local_epochs,
29
- self.device,
30
- )
31
- return (
32
- get_weights(self.net),
33
- len(self.trainloader.dataset),
34
- {"train_loss": train_loss},
35
- )
36
-
37
- def evaluate(self, parameters, config):
38
- """Evaluate model using this client's data."""
39
- set_weights(self.net, parameters)
40
- loss, accuracy = test(self.net, self.valloader, self.device)
41
- return loss, len(self.valloader.dataset), {"accuracy": accuracy}
42
-
43
-
44
- def client_fn(context: Context):
45
- """Construct a Client that will be run in a ClientApp."""
46
- # Load model and data
47
- net = Net()
8
+ from $import_name.model import Net
9
+ from $import_name.model import test as test_fn
10
+ from $import_name.model import train as train_fn
11
+
12
+ # Flower ClientApp
13
+ app = ClientApp()
14
+
15
+
16
+ @app.train()
17
+ def train(msg: Message, context: Context):
18
+ """Train the model on local data."""
19
+
20
+ # Load the model and initialize it with the received weights
21
+ model = Net()
22
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
23
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
+
25
+ # Load the data
48
26
  partition_id = int(context.node_config["partition-id"])
49
27
  num_partitions = int(context.node_config["num-partitions"])
50
- trainloader, valloader = load_data(partition_id, num_partitions)
28
+ trainloader, _ = load_data(partition_id, num_partitions)
51
29
  local_epochs = context.run_config["local-epochs"]
52
30
 
53
- # Return Client instance
54
- return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
31
+ # Call the training function
32
+ train_loss = train_fn(
33
+ model,
34
+ trainloader,
35
+ local_epochs,
36
+ device,
37
+ )
55
38
 
39
+ # Construct and return reply Message
40
+ model_record = ArrayRecord(model.state_dict())
41
+ metrics = {
42
+ "train_loss": train_loss,
43
+ "num-examples": len(trainloader.dataset),
44
+ }
45
+ metric_record = MetricRecord(metrics)
46
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
+ return Message(content=content, reply_to=msg)
56
48
 
57
- # Flower ClientApp
58
- app = ClientApp(client_fn)
49
+
50
+ @app.evaluate()
51
+ def evaluate(msg: Message, context: Context):
52
+ """Evaluate the model on local data."""
53
+
54
+ # Load the model and initialize it with the received weights
55
+ model = Net()
56
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
+
59
+ # Load the data
60
+ partition_id = int(context.node_config["partition-id"])
61
+ num_partitions = int(context.node_config["num-partitions"])
62
+ _, valloader = load_data(partition_id, num_partitions)
63
+
64
+ # Call the evaluation function
65
+ eval_loss, eval_acc = test_fn(model, valloader, device)
66
+
67
+ # Construct and return reply Message
68
+ metrics = {
69
+ "eval_loss": eval_loss,
70
+ "eval_acc": eval_acc,
71
+ "num-examples": len(valloader.dataset),
72
+ }
73
+ metric_record = MetricRecord(metrics)
74
+ content = RecordDict({"metrics": metric_record})
75
+ return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/client.huggingface.py.tpl CHANGED
@@ -1,41 +1,67 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import torch
4
- from flwr.client import ClientApp, NumPyClient
5
- from flwr.common import Context
4
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
+ from flwr.clientapp import ClientApp
6
6
  from transformers import AutoModelForSequenceClassification
7
7
 
8
- from $import_name.task import get_weights, load_data, set_weights, test, train
8
+ from $import_name.task import load_data
9
+ from $import_name.task import test as test_fn
10
+ from $import_name.task import train as train_fn
9
11
 
12
+ # Flower ClientApp
13
+ app = ClientApp()
14
+
15
+
16
+ @app.train()
17
+ def train(msg: Message, context: Context):
18
+ """Train the model on local data."""
19
+
20
+ # Get this client's dataset partition
21
+ partition_id = context.node_config["partition-id"]
22
+ num_partitions = context.node_config["num-partitions"]
23
+ model_name = context.run_config["model-name"]
24
+ trainloader, _ = load_data(partition_id, num_partitions, model_name)
25
+
26
+ # Load model
27
+ num_labels = context.run_config["num-labels"]
28
+ net = AutoModelForSequenceClassification.from_pretrained(
29
+ model_name, num_labels=num_labels
30
+ )
10
31
 
11
- # Flower client
12
- class FlowerClient(NumPyClient):
13
- def __init__(self, net, trainloader, testloader, local_epochs):
14
- self.net = net
15
- self.trainloader = trainloader
16
- self.testloader = testloader
17
- self.local_epochs = local_epochs
18
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
- self.net.to(self.device)
32
+ # Initialize it with the received weights
33
+ net.load_state_dict(msg.content["arrays"].to_torch_state_dict())
34
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35
+ net.to(device)
20
36
 
21
- def fit(self, parameters, config):
22
- set_weights(self.net, parameters)
23
- train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device)
24
- return get_weights(self.net), len(self.trainloader), {}
37
+ # Train the model on local data
38
+ train_loss = train_fn(
39
+ net,
40
+ trainloader,
41
+ context.run_config["local-steps"],
42
+ device,
43
+ )
25
44
 
26
- def evaluate(self, parameters, config):
27
- set_weights(self.net, parameters)
28
- loss, accuracy = test(self.net, self.testloader, self.device)
29
- return float(loss), len(self.testloader), {"accuracy": accuracy}
45
+ # Construct and return reply Message
46
+ model_record = ArrayRecord(net.state_dict())
47
+ metrics = {
48
+ "train_loss": train_loss,
49
+ "num-examples": len(trainloader.dataset),
50
+ }
51
+ metric_record = MetricRecord(metrics)
52
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
53
+ return Message(content=content, reply_to=msg)
30
54
 
31
55
 
32
- def client_fn(context: Context):
56
+ @app.evaluate()
57
+ def evaluate(msg: Message, context: Context):
58
+ """Evaluate the model on local data."""
33
59
 
34
60
  # Get this client's dataset partition
35
61
  partition_id = context.node_config["partition-id"]
36
62
  num_partitions = context.node_config["num-partitions"]
37
63
  model_name = context.run_config["model-name"]
38
- trainloader, valloader = load_data(partition_id, num_partitions, model_name)
64
+ _, valloader = load_data(partition_id, num_partitions, model_name)
39
65
 
40
66
  # Load model
41
67
  num_labels = context.run_config["num-labels"]
@@ -43,13 +69,25 @@ def client_fn(context: Context):
43
69
  model_name, num_labels=num_labels
44
70
  )
45
71
 
46
- local_epochs = context.run_config["local-epochs"]
47
-
48
- # Return Client instance
49
- return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
72
+ # Initialize it with the received weights
73
+ net.load_state_dict(msg.content["arrays"].to_torch_state_dict())
74
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
75
+ net.to(device)
50
76
 
77
+ # Evaluate the model on local data
78
+ val_loss, val_accuracy = test_fn(
79
+ net,
80
+ valloader,
81
+ device,
82
+ )
51
83
 
52
- # Flower ClientApp
53
- app = ClientApp(
54
- client_fn,
55
- )
84
+ # Construct and return reply Message
85
+ model_record = ArrayRecord(net.state_dict())
86
+ metrics = {
87
+ "val_loss": val_loss,
88
+ "val_accuracy": val_accuracy,
89
+ "num-examples": len(valloader.dataset),
90
+ }
91
+ metric_record = MetricRecord(metrics)
92
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
93
+ return Message(content=content, reply_to=msg)