flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly has been flagged as potentially problematic; see the registry's release page for more details.

Files changed (92):
  1. flwr/cli/build.py +2 -2
  2. flwr/cli/config_utils.py +97 -0
  3. flwr/cli/log.py +63 -97
  4. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  7. flwr/cli/run/run.py +34 -88
  8. flwr/client/app.py +23 -20
  9. flwr/client/clientapp/app.py +22 -18
  10. flwr/client/nodestate/__init__.py +25 -0
  11. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  12. flwr/client/nodestate/nodestate.py +30 -0
  13. flwr/client/nodestate/nodestate_factory.py +37 -0
  14. flwr/client/{node_state.py → run_info_store.py} +4 -3
  15. flwr/client/supernode/app.py +6 -8
  16. flwr/common/args.py +83 -0
  17. flwr/common/config.py +10 -0
  18. flwr/common/constant.py +39 -5
  19. flwr/common/context.py +9 -4
  20. flwr/common/date.py +3 -3
  21. flwr/common/logger.py +108 -1
  22. flwr/common/object_ref.py +47 -16
  23. flwr/common/serde.py +24 -0
  24. flwr/common/telemetry.py +0 -6
  25. flwr/common/typing.py +10 -1
  26. flwr/proto/exec_pb2.py +14 -17
  27. flwr/proto/exec_pb2.pyi +14 -22
  28. flwr/proto/log_pb2.py +29 -0
  29. flwr/proto/log_pb2.pyi +39 -0
  30. flwr/proto/log_pb2_grpc.py +4 -0
  31. flwr/proto/log_pb2_grpc.pyi +4 -0
  32. flwr/proto/message_pb2.py +8 -8
  33. flwr/proto/message_pb2.pyi +4 -1
  34. flwr/proto/run_pb2.py +32 -27
  35. flwr/proto/run_pb2.pyi +26 -0
  36. flwr/proto/serverappio_pb2.py +52 -0
  37. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
  38. flwr/proto/serverappio_pb2_grpc.py +376 -0
  39. flwr/proto/serverappio_pb2_grpc.pyi +147 -0
  40. flwr/proto/simulationio_pb2.py +38 -0
  41. flwr/proto/simulationio_pb2.pyi +65 -0
  42. flwr/proto/simulationio_pb2_grpc.py +205 -0
  43. flwr/proto/simulationio_pb2_grpc.pyi +81 -0
  44. flwr/server/app.py +272 -105
  45. flwr/server/driver/driver.py +15 -1
  46. flwr/server/driver/grpc_driver.py +25 -36
  47. flwr/server/driver/inmemory_driver.py +6 -16
  48. flwr/server/run_serverapp.py +29 -23
  49. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  50. flwr/server/serverapp/app.py +214 -0
  51. flwr/server/strategy/aggregate.py +4 -4
  52. flwr/server/strategy/fedadam.py +11 -1
  53. flwr/server/superlink/driver/__init__.py +1 -1
  54. flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
  55. flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
  56. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  57. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
  58. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  59. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  60. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  61. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  62. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  63. flwr/server/superlink/linkstate/__init__.py +28 -0
  64. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +184 -36
  65. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +149 -19
  66. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  67. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +306 -65
  68. flwr/server/superlink/{state → linkstate}/utils.py +81 -30
  69. flwr/server/superlink/simulation/__init__.py +15 -0
  70. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  71. flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
  72. flwr/simulation/__init__.py +5 -1
  73. flwr/simulation/app.py +273 -345
  74. flwr/simulation/legacy_app.py +382 -0
  75. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  76. flwr/simulation/run_simulation.py +57 -131
  77. flwr/simulation/simulationio_connection.py +86 -0
  78. flwr/superexec/app.py +6 -134
  79. flwr/superexec/deployment.py +61 -66
  80. flwr/superexec/exec_grpc.py +15 -8
  81. flwr/superexec/exec_servicer.py +36 -65
  82. flwr/superexec/executor.py +26 -7
  83. flwr/superexec/simulation.py +54 -107
  84. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
  85. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +88 -69
  86. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +2 -0
  87. flwr/client/node_state_tests.py +0 -66
  88. flwr/proto/driver_pb2.py +0 -42
  89. flwr/proto/driver_pb2_grpc.py +0 -239
  90. flwr/proto/driver_pb2_grpc.pyi +0 -94
  91. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
  92. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0
@@ -23,7 +23,7 @@ from typing import Optional, cast
23
23
  import grpc
24
24
 
25
25
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
26
- from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
26
+ from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS
27
27
  from flwr.common.grpc import create_channel
28
28
  from flwr.common.logger import log
29
29
  from flwr.common.serde import (
@@ -32,7 +32,9 @@ from flwr.common.serde import (
32
32
  user_config_from_proto,
33
33
  )
34
34
  from flwr.common.typing import Run
35
- from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
35
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
36
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
37
+ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
36
38
  GetNodesRequest,
37
39
  GetNodesResponse,
38
40
  PullTaskResRequest,
@@ -40,9 +42,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
40
42
  PushTaskInsRequest,
41
43
  PushTaskInsResponse,
42
44
  )
43
- from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
44
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
45
- from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
45
+ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
46
46
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
47
47
 
48
48
  from .driver import Driver
@@ -56,14 +56,12 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
56
56
 
57
57
 
58
58
  class GrpcDriver(Driver):
59
- """`GrpcDriver` provides an interface to the Driver API.
59
+ """`GrpcDriver` provides an interface to the ServerAppIo API.
60
60
 
61
61
  Parameters
62
62
  ----------
63
- run_id : int
64
- The identifier of the run.
65
- driver_service_address : str (default: "[::]:9091")
66
- The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
63
+ serverappio_service_address : str (default: "[::]:9091")
64
+ The address (URL, IPv6, IPv4) of the SuperLink ServerAppIo API service.
67
65
  root_certificates : Optional[bytes] (default: None)
68
66
  The PEM-encoded root certificates as a byte string.
69
67
  If provided, a secure connection using the certificates will be
@@ -72,25 +70,23 @@ class GrpcDriver(Driver):
72
70
 
73
71
  def __init__( # pylint: disable=too-many-arguments
74
72
  self,
75
- run_id: int,
76
- driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
73
+ serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_ADDRESS,
77
74
  root_certificates: Optional[bytes] = None,
78
75
  ) -> None:
79
- self._run_id = run_id
80
- self._addr = driver_service_address
76
+ self._addr = serverappio_service_address
81
77
  self._cert = root_certificates
82
78
  self._run: Optional[Run] = None
83
- self._grpc_stub: Optional[DriverStub] = None
79
+ self._grpc_stub: Optional[ServerAppIoStub] = None
84
80
  self._channel: Optional[grpc.Channel] = None
85
81
  self.node = Node(node_id=0, anonymous=True)
86
82
 
87
83
  @property
88
84
  def _is_connected(self) -> bool:
89
- """Check if connected to the Driver API server."""
85
+ """Check if connected to the ServerAppIo API server."""
90
86
  return self._channel is not None
91
87
 
92
88
  def _connect(self) -> None:
93
- """Connect to the Driver API.
89
+ """Connect to the ServerAppIo API.
94
90
 
95
91
  This will not call GetRun.
96
92
  """
@@ -102,11 +98,11 @@ class GrpcDriver(Driver):
102
98
  insecure=(self._cert is None),
103
99
  root_certificates=self._cert,
104
100
  )
105
- self._grpc_stub = DriverStub(self._channel)
101
+ self._grpc_stub = ServerAppIoStub(self._channel)
106
102
  log(DEBUG, "[Driver] Connected to %s", self._addr)
107
103
 
108
104
  def _disconnect(self) -> None:
109
- """Disconnect from the Driver API."""
105
+ """Disconnect from the ServerAppIo API."""
110
106
  if not self._is_connected:
111
107
  log(DEBUG, "Already disconnected")
112
108
  return
@@ -116,15 +112,13 @@ class GrpcDriver(Driver):
116
112
  channel.close()
117
113
  log(DEBUG, "[Driver] Disconnected")
118
114
 
119
- def _init_run(self) -> None:
120
- # Check if is initialized
121
- if self._run is not None:
122
- return
115
+ def set_run(self, run_id: int) -> None:
116
+ """Set the run."""
123
117
  # Get the run info
124
- req = GetRunRequest(run_id=self._run_id)
118
+ req = GetRunRequest(run_id=run_id)
125
119
  res: GetRunResponse = self._stub.GetRun(req)
126
120
  if not res.HasField("run"):
127
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
121
+ raise RuntimeError(f"Cannot find the run with ID: {run_id}")
128
122
  self._run = Run(
129
123
  run_id=res.run.run_id,
130
124
  fab_id=res.run.fab_id,
@@ -136,21 +130,20 @@ class GrpcDriver(Driver):
136
130
  @property
137
131
  def run(self) -> Run:
138
132
  """Run information."""
139
- self._init_run()
140
133
  return Run(**vars(self._run))
141
134
 
142
135
  @property
143
- def _stub(self) -> DriverStub:
144
- """Driver stub."""
136
+ def _stub(self) -> ServerAppIoStub:
137
+ """ServerAppIo stub."""
145
138
  if not self._is_connected:
146
139
  self._connect()
147
- return cast(DriverStub, self._grpc_stub)
140
+ return cast(ServerAppIoStub, self._grpc_stub)
148
141
 
149
142
  def _check_message(self, message: Message) -> None:
150
143
  # Check if the message is valid
151
144
  if not (
152
145
  # Assume self._run being initialized
153
- message.metadata.run_id == self._run_id
146
+ message.metadata.run_id == cast(Run, self._run).run_id
154
147
  and message.metadata.src_node_id == self.node.node_id
155
148
  and message.metadata.message_id == ""
156
149
  and message.metadata.reply_to_message == ""
@@ -171,7 +164,6 @@ class GrpcDriver(Driver):
171
164
  This method constructs a new `Message` with given content and metadata.
172
165
  The `run_id` and `src_node_id` will be set automatically.
173
166
  """
174
- self._init_run()
175
167
  if ttl:
176
168
  warnings.warn(
177
169
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -182,7 +174,7 @@ class GrpcDriver(Driver):
182
174
 
183
175
  ttl_ = DEFAULT_TTL if ttl is None else ttl
184
176
  metadata = Metadata(
185
- run_id=self._run_id,
177
+ run_id=cast(Run, self._run).run_id,
186
178
  message_id="", # Will be set by the server
187
179
  src_node_id=self.node.node_id,
188
180
  dst_node_id=dst_node_id,
@@ -195,10 +187,9 @@ class GrpcDriver(Driver):
195
187
 
196
188
  def get_node_ids(self) -> list[int]:
197
189
  """Get node IDs."""
198
- self._init_run()
199
190
  # Call GrpcDriverStub method
200
191
  res: GetNodesResponse = self._stub.GetNodes(
201
- GetNodesRequest(run_id=self._run_id)
192
+ GetNodesRequest(run_id=cast(Run, self._run).run_id)
202
193
  )
203
194
  return [node.node_id for node in res.nodes]
204
195
 
@@ -208,7 +199,6 @@ class GrpcDriver(Driver):
208
199
  This method takes an iterable of messages and sends each message
209
200
  to the node specified in `dst_node_id`.
210
201
  """
211
- self._init_run()
212
202
  # Construct TaskIns
213
203
  task_ins_list: list[TaskIns] = []
214
204
  for msg in messages:
@@ -230,7 +220,6 @@ class GrpcDriver(Driver):
230
220
  This method is used to collect messages from the SuperLink that correspond to a
231
221
  set of given message IDs.
232
222
  """
233
- self._init_run()
234
223
  # Pull TaskRes
235
224
  res: PullTaskResResponse = self._stub.PullTaskRes(
236
225
  PullTaskResRequest(node=self.node, task_ids=message_ids)
@@ -25,18 +25,16 @@ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
25
  from flwr.common.serde import message_from_taskres, message_to_taskins
26
26
  from flwr.common.typing import Run
27
27
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
28
- from flwr.server.superlink.state import StateFactory
28
+ from flwr.server.superlink.linkstate import LinkStateFactory
29
29
 
30
30
  from .driver import Driver
31
31
 
32
32
 
33
33
  class InMemoryDriver(Driver):
34
- """`InMemoryDriver` class provides an interface to the Driver API.
34
+ """`InMemoryDriver` class provides an interface to the ServerAppIo API.
35
35
 
36
36
  Parameters
37
37
  ----------
38
- run_id : int
39
- The identifier of the run.
40
38
  state_factory : StateFactory
41
39
  A StateFactory embedding a state that this driver can interface with.
42
40
  pull_interval : float (default=0.1)
@@ -45,18 +43,15 @@ class InMemoryDriver(Driver):
45
43
 
46
44
  def __init__(
47
45
  self,
48
- run_id: int,
49
- state_factory: StateFactory,
46
+ state_factory: LinkStateFactory,
50
47
  pull_interval: float = 0.1,
51
48
  ) -> None:
52
- self._run_id = run_id
53
49
  self._run: Optional[Run] = None
54
50
  self.state = state_factory.state()
55
51
  self.pull_interval = pull_interval
56
52
  self.node = Node(node_id=0, anonymous=True)
57
53
 
58
54
  def _check_message(self, message: Message) -> None:
59
- self._init_run()
60
55
  # Check if the message is valid
61
56
  if not (
62
57
  message.metadata.run_id == cast(Run, self._run).run_id
@@ -67,19 +62,16 @@ class InMemoryDriver(Driver):
67
62
  ):
68
63
  raise ValueError(f"Invalid message: {message}")
69
64
 
70
- def _init_run(self) -> None:
65
+ def set_run(self, run_id: int) -> None:
71
66
  """Initialize the run."""
72
- if self._run is not None:
73
- return
74
- run = self.state.get_run(self._run_id)
67
+ run = self.state.get_run(run_id)
75
68
  if run is None:
76
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
69
+ raise RuntimeError(f"Cannot find the run with ID: {run_id}")
77
70
  self._run = run
78
71
 
79
72
  @property
80
73
  def run(self) -> Run:
81
74
  """Run ID."""
82
- self._init_run()
83
75
  return Run(**vars(cast(Run, self._run)))
84
76
 
85
77
  def create_message( # pylint: disable=too-many-arguments,R0917
@@ -95,7 +87,6 @@ class InMemoryDriver(Driver):
95
87
  This method constructs a new `Message` with given content and metadata.
96
88
  The `run_id` and `src_node_id` will be set automatically.
97
89
  """
98
- self._init_run()
99
90
  if ttl:
100
91
  warnings.warn(
101
92
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -119,7 +110,6 @@ class InMemoryDriver(Driver):
119
110
 
120
111
  def get_node_ids(self) -> list[int]:
121
112
  """Get node IDs."""
122
- self._init_run()
123
113
  return list(self.state.get_nodes(cast(Run, self._run).run_id))
124
114
 
125
115
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -31,10 +31,9 @@ from flwr.common.config import (
31
31
  get_project_config,
32
32
  get_project_dir,
33
33
  )
34
- from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
34
+ from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS
35
35
  from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
36
36
  from flwr.common.object_ref import load_app
37
- from flwr.common.typing import UserConfig
38
37
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
39
38
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
40
39
  CreateRunRequest,
@@ -48,11 +47,11 @@ from .server_app import LoadServerAppError, ServerApp
48
47
 
49
48
  def run(
50
49
  driver: Driver,
50
+ context: Context,
51
51
  server_app_dir: str,
52
- server_app_run_config: UserConfig,
53
52
  server_app_attr: Optional[str] = None,
54
53
  loaded_server_app: Optional[ServerApp] = None,
55
- ) -> None:
54
+ ) -> Context:
56
55
  """Run ServerApp with a given Driver."""
57
56
  if not (server_app_attr is None) ^ (loaded_server_app is None):
58
57
  raise ValueError(
@@ -78,15 +77,11 @@ def run(
78
77
 
79
78
  server_app = _load()
80
79
 
81
- # Initialize Context
82
- context = Context(
83
- node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
84
- )
85
-
86
80
  # Call ServerApp
87
81
  server_app(driver=driver, context=context)
88
82
 
89
83
  log(DEBUG, "ServerApp finished running.")
84
+ return context
90
85
 
91
86
 
92
87
  # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
@@ -111,18 +106,18 @@ def run_server_app() -> None:
111
106
  "app by executing `flwr new` and following the prompt."
112
107
  )
113
108
 
114
- if args.server != DRIVER_API_DEFAULT_ADDRESS:
109
+ if args.server != SERVERAPPIO_API_DEFAULT_ADDRESS:
115
110
  warn = "Passing flag --server is deprecated. Use --superlink instead."
116
111
  warn_deprecated_feature(warn)
117
112
 
118
- if args.superlink != DRIVER_API_DEFAULT_ADDRESS:
113
+ if args.superlink != SERVERAPPIO_API_DEFAULT_ADDRESS:
119
114
  # if `--superlink` also passed, then
120
115
  # warn user that this argument overrides what was passed with `--server`
121
116
  log(
122
117
  WARN,
123
118
  "Both `--server` and `--superlink` were passed. "
124
- "`--server` will be ignored. Connecting to the Superlink Driver API "
125
- "at %s.",
119
+ "`--server` will be ignored. Connecting to the "
120
+ "SuperLink ServerAppIo API at %s.",
126
121
  args.superlink,
127
122
  )
128
123
  else:
@@ -175,11 +170,11 @@ def run_server_app() -> None:
175
170
  if app_path is None:
176
171
  # User provided `--run-id`, but not `app_dir`
177
172
  driver = GrpcDriver(
178
- run_id=args.run_id,
179
- driver_service_address=args.superlink,
173
+ serverappio_service_address=args.superlink,
180
174
  root_certificates=root_certificates,
181
175
  )
182
176
  flwr_dir = get_flwr_dir(args.flwr_dir)
177
+ driver.set_run(args.run_id)
183
178
  run_ = driver.run
184
179
  if not run_.fab_hash:
185
180
  raise ValueError("FAB hash not provided.")
@@ -193,12 +188,12 @@ def run_server_app() -> None:
193
188
 
194
189
  app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
195
190
  config = get_project_config(app_path)
191
+ run_id = run_.run_id
196
192
  else:
197
193
  # User provided `app_dir`, but not `--run-id`
198
194
  # Create run if run_id is not provided
199
195
  driver = GrpcDriver(
200
- run_id=0, # Will be overwritten
201
- driver_service_address=args.superlink,
196
+ serverappio_service_address=args.superlink,
202
197
  root_certificates=root_certificates,
203
198
  )
204
199
  # Load config from the project directory
@@ -208,8 +203,9 @@ def run_server_app() -> None:
208
203
  # Create run
209
204
  req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
210
205
  res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
211
- # Overwrite driver._run_id
212
- driver._run_id = res.run_id # pylint: disable=W0212
206
+ # Fetch full `Run` using `run_id`
207
+ driver.set_run(res.run_id) # pylint: disable=W0212
208
+ run_id = res.run_id
213
209
 
214
210
  # Obtain server app reference and the run config
215
211
  server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
@@ -225,11 +221,20 @@ def run_server_app() -> None:
225
221
  root_certificates,
226
222
  )
227
223
 
224
+ # Initialize Context
225
+ context = Context(
226
+ run_id=run_id,
227
+ node_id=0,
228
+ node_config={},
229
+ state=RecordSet(),
230
+ run_config=server_app_run_config,
231
+ )
232
+
228
233
  # Run the ServerApp with the Driver
229
234
  run(
230
235
  driver=driver,
236
+ context=context,
231
237
  server_app_dir=app_path,
232
- server_app_run_config=server_app_run_config,
233
238
  server_app_attr=server_app_attr,
234
239
  )
235
240
 
@@ -272,13 +277,14 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
272
277
  )
273
278
  parser.add_argument(
274
279
  "--server",
275
- default=DRIVER_API_DEFAULT_ADDRESS,
280
+ default=SERVERAPPIO_API_DEFAULT_ADDRESS,
276
281
  help="Server address",
277
282
  )
278
283
  parser.add_argument(
279
284
  "--superlink",
280
- default=DRIVER_API_DEFAULT_ADDRESS,
281
- help="SuperLink Driver API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
285
+ default=SERVERAPPIO_API_DEFAULT_ADDRESS,
286
+ help="SuperLink ServerAppIo API (gRPC-rere) address "
287
+ "(IPv4, IPv6, or a domain name)",
282
288
  )
283
289
  parser.add_argument(
284
290
  "--run-id",
@@ -12,17 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower server state."""
15
+ """Flower AppIO service."""
16
16
 
17
17
 
18
- from .in_memory_state import InMemoryState as InMemoryState
19
- from .sqlite_state import SqliteState as SqliteState
20
- from .state import State as State
21
- from .state_factory import StateFactory as StateFactory
18
+ from .app import flwr_serverapp as flwr_serverapp
22
19
 
23
20
  __all__ = [
24
- "InMemoryState",
25
- "SqliteState",
26
- "State",
27
- "StateFactory",
21
+ "flwr_serverapp",
28
22
  ]
@@ -0,0 +1,214 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower ServerApp process."""
16
+
17
+ import argparse
18
+ from logging import DEBUG, ERROR, INFO
19
+ from pathlib import Path
20
+ from queue import Queue
21
+ from time import sleep
22
+ from typing import Optional
23
+
24
+ from flwr.cli.config_utils import get_fab_metadata
25
+ from flwr.cli.install import install_from_fab
26
+ from flwr.common.args import add_args_flwr_app_common, try_obtain_certificates
27
+ from flwr.common.config import (
28
+ get_flwr_dir,
29
+ get_fused_config_from_dir,
30
+ get_project_config,
31
+ get_project_dir,
32
+ )
33
+ from flwr.common.constant import Status, SubStatus
34
+ from flwr.common.logger import (
35
+ log,
36
+ mirror_output_to_queue,
37
+ restore_output,
38
+ start_log_uploader,
39
+ stop_log_uploader,
40
+ )
41
+ from flwr.common.serde import (
42
+ context_from_proto,
43
+ context_to_proto,
44
+ fab_from_proto,
45
+ run_from_proto,
46
+ run_status_to_proto,
47
+ )
48
+ from flwr.common.typing import RunStatus
49
+ from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
50
+ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
51
+ PullServerAppInputsRequest,
52
+ PullServerAppInputsResponse,
53
+ PushServerAppOutputsRequest,
54
+ )
55
+ from flwr.server.driver.grpc_driver import GrpcDriver
56
+ from flwr.server.run_serverapp import run as run_
57
+
58
+
59
+ def flwr_serverapp() -> None:
60
+ """Run process-isolated Flower ServerApp."""
61
+ # Capture stdout/stderr
62
+ log_queue: Queue[Optional[str]] = Queue()
63
+ mirror_output_to_queue(log_queue)
64
+
65
+ parser = argparse.ArgumentParser(
66
+ description="Run a Flower ServerApp",
67
+ )
68
+ parser.add_argument(
69
+ "--superlink",
70
+ type=str,
71
+ help="Address of SuperLink's ServerAppIo API",
72
+ )
73
+ parser.add_argument(
74
+ "--run-once",
75
+ action="store_true",
76
+ help="When set, this process will start a single ServerApp for a pending Run. "
77
+ "If there is no pending Run, the process will exit.",
78
+ )
79
+ add_args_flwr_app_common(parser=parser)
80
+ args = parser.parse_args()
81
+
82
+ log(INFO, "Starting Flower ServerApp")
83
+ certificates = try_obtain_certificates(args)
84
+
85
+ log(
86
+ DEBUG,
87
+ "Starting isolated `ServerApp` connected to SuperLink's ServerAppIo API at %s",
88
+ args.superlink,
89
+ )
90
+ run_serverapp(
91
+ superlink=args.superlink,
92
+ log_queue=log_queue,
93
+ run_once=args.run_once,
94
+ flwr_dir=args.flwr_dir,
95
+ certificates=certificates,
96
+ )
97
+
98
+ # Restore stdout/stderr
99
+ restore_output()
100
+
101
+
102
+ def run_serverapp( # pylint: disable=R0914, disable=W0212
103
+ superlink: str,
104
+ log_queue: Queue[Optional[str]],
105
+ run_once: bool,
106
+ flwr_dir: Optional[str] = None,
107
+ certificates: Optional[bytes] = None,
108
+ ) -> None:
109
+ """Run Flower ServerApp process."""
110
+ driver = GrpcDriver(
111
+ serverappio_service_address=superlink,
112
+ root_certificates=certificates,
113
+ )
114
+
115
+ # Resolve directory where FABs are installed
116
+ flwr_dir_ = get_flwr_dir(flwr_dir)
117
+ log_uploader = None
118
+
119
+ while True:
120
+
121
+ try:
122
+ # Pull ServerAppInputs from LinkState
123
+ req = PullServerAppInputsRequest()
124
+ res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
125
+ if not res.HasField("run"):
126
+ sleep(3)
127
+ run_status = None
128
+ continue
129
+
130
+ context = context_from_proto(res.context)
131
+ run = run_from_proto(res.run)
132
+ fab = fab_from_proto(res.fab)
133
+
134
+ driver.set_run(run.run_id)
135
+
136
+ # Start log uploader for this run
137
+ log_uploader = start_log_uploader(
138
+ log_queue=log_queue,
139
+ node_id=0,
140
+ run_id=run.run_id,
141
+ stub=driver._stub,
142
+ )
143
+
144
+ log(DEBUG, "ServerApp process starts FAB installation.")
145
+ install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
146
+
147
+ fab_id, fab_version = get_fab_metadata(fab.content)
148
+
149
+ app_path = str(
150
+ get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir_)
151
+ )
152
+ config = get_project_config(app_path)
153
+
154
+ # Obtain server app reference and the run config
155
+ server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
156
+ server_app_run_config = get_fused_config_from_dir(
157
+ Path(app_path), run.override_config
158
+ )
159
+
160
+ # Update run_config in context
161
+ context.run_config = server_app_run_config
162
+
163
+ log(
164
+ DEBUG,
165
+ "Flower will load ServerApp `%s` in %s",
166
+ server_app_attr,
167
+ app_path,
168
+ )
169
+
170
+ # Change status to Running
171
+ run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
172
+ driver._stub.UpdateRunStatus(
173
+ UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
174
+ )
175
+
176
+ # Load and run the ServerApp with the Driver
177
+ updated_context = run_(
178
+ driver=driver,
179
+ server_app_dir=app_path,
180
+ server_app_attr=server_app_attr,
181
+ context=context,
182
+ )
183
+
184
+ # Send resulting context
185
+ context_proto = context_to_proto(updated_context)
186
+ out_req = PushServerAppOutputsRequest(
187
+ run_id=run.run_id, context=context_proto
188
+ )
189
+ _ = driver._stub.PushServerAppOutputs(out_req)
190
+
191
+ run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
192
+
193
+ except Exception as ex: # pylint: disable=broad-exception-caught
194
+ exc_entity = "ServerApp"
195
+ log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
196
+ run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
197
+
198
+ finally:
199
+ if run_status:
200
+ run_status_proto = run_status_to_proto(run_status)
201
+ driver._stub.UpdateRunStatus(
202
+ UpdateRunStatusRequest(
203
+ run_id=run.run_id, run_status=run_status_proto
204
+ )
205
+ )
206
+
207
+ # Stop log uploader for this run
208
+ if log_uploader:
209
+ stop_log_uploader(log_queue, log_uploader)
210
+ log_uploader = None
211
+
212
+ # Stop the loop if `flwr-serverapp` is expected to process a single run
213
+ if run_once:
214
+ break
@@ -48,12 +48,12 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
48
48
  num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
49
49
 
50
50
  # Compute scaling factors for each result
51
- scaling_factors = [
52
- fit_res.num_examples / num_examples_total for _, fit_res in results
53
- ]
51
+ scaling_factors = np.asarray(
52
+ [fit_res.num_examples / num_examples_total for _, fit_res in results]
53
+ )
54
54
 
55
55
  def _try_inplace(
56
- x: NDArray, y: Union[NDArray, float], np_binary_op: np.ufunc
56
+ x: NDArray, y: Union[NDArray, np.float64], np_binary_op: np.ufunc
57
57
  ) -> NDArray:
58
58
  return ( # type: ignore[no-any-return]
59
59
  np_binary_op(x, y, out=x)
@@ -170,8 +170,18 @@ class FedAdam(FedOpt):
170
170
  for x, y in zip(self.v_t, delta_t)
171
171
  ]
172
172
 
173
+ # Compute the bias-corrected learning rate, `eta_norm` for improving convergence
174
+ # in the early rounds of FL training. This `eta_norm` is `\alpha_t` in Kingma &
175
+ # Ba, 2014 (http://arxiv.org/abs/1412.6980) "Adam: A Method for Stochastic
176
+ # Optimization" in the formula line right before Section 2.1.
177
+ eta_norm = (
178
+ self.eta
179
+ * np.sqrt(1 - np.power(self.beta_2, server_round + 1.0))
180
+ / (1 - np.power(self.beta_1, server_round + 1.0))
181
+ )
182
+
173
183
  new_weights = [
174
- x + self.eta * y / (np.sqrt(z) + self.tau)
184
+ x + eta_norm * y / (np.sqrt(z) + self.tau)
175
185
  for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
176
186
  ]
177
187
 
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service."""
15
+ """Flower ServerAppIo service."""