flwr-nightly 1.11.0.dev20240823__py3-none-any.whl → 1.12.0.dev20240906__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly may be problematic; see the package's registry page for more details.

Files changed (48)
  1. flwr/cli/app.py +0 -2
  2. flwr/cli/new/new.py +24 -10
  3. flwr/cli/new/templates/app/LICENSE.tpl +202 -0
  4. flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
  5. flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
  6. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
  7. flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
  8. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
  9. flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
  10. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
  11. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
  13. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
  14. flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
  15. flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
  16. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
  17. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
  18. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
  19. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
  20. flwr/cli/run/run.py +2 -2
  21. flwr/client/__init__.py +0 -4
  22. flwr/client/grpc_rere_client/client_interceptor.py +13 -4
  23. flwr/client/supernode/app.py +3 -1
  24. flwr/common/config.py +14 -11
  25. flwr/common/telemetry.py +36 -30
  26. flwr/server/__init__.py +0 -4
  27. flwr/server/app.py +13 -13
  28. flwr/server/compat/app.py +0 -5
  29. flwr/server/driver/grpc_driver.py +1 -3
  30. flwr/server/run_serverapp.py +15 -1
  31. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
  32. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +11 -11
  33. flwr/server/superlink/state/in_memory_state.py +15 -15
  34. flwr/server/superlink/state/sqlite_state.py +10 -10
  35. flwr/server/superlink/state/state.py +8 -8
  36. flwr/simulation/run_simulation.py +23 -6
  37. flwr/superexec/__init__.py +0 -6
  38. flwr/superexec/app.py +3 -1
  39. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/METADATA +3 -3
  40. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/RECORD +43 -35
  41. flwr_nightly-1.12.0.dev20240906.dist-info/entry_points.txt +10 -0
  42. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
  43. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
  44. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
  45. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
  46. flwr_nightly-1.11.0.dev20240823.dist-info/entry_points.txt +0 -10
  47. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/LICENSE +0 -0
  48. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/WHEEL +0 -0
@@ -8,15 +8,15 @@ version = "1.0.0"
8
8
  description = ""
9
9
  license = "Apache-2.0"
10
10
  dependencies = [
11
- "flwr[simulation]>=1.9.0,<2.0",
12
- "flwr-datasets>=0.1.0,<1.0.0",
13
- "hydra-core==1.3.2",
11
+ "flwr[simulation]>=1.10.0",
12
+ "flwr-datasets>=0.3.0",
14
13
  "trl==0.8.1",
15
14
  "bitsandbytes==0.43.0",
16
15
  "scipy==1.13.0",
17
16
  "peft==0.6.2",
18
17
  "transformers==4.39.3",
19
18
  "sentencepiece==0.2.0",
19
+ "omegaconf==2.3.0",
20
20
  ]
21
21
 
22
22
  [tool.hatch.build.targets.wheel]
@@ -26,14 +26,41 @@ packages = ["."]
26
26
  publisher = "$username"
27
27
 
28
28
  [tool.flwr.app.components]
29
- serverapp = "$import_name.app:server"
30
- clientapp = "$import_name.app:client"
29
+ serverapp = "$import_name.server_app:app"
30
+ clientapp = "$import_name.client_app:app"
31
31
 
32
32
  [tool.flwr.app.config]
33
- num-server-rounds = 3
33
+ model.name = "mistralai/Mistral-7B-v0.3"
34
+ model.quantization = 4
35
+ model.gradient-checkpointing = true
36
+ model.lora.peft-lora-r = 32
37
+ model.lora.peft-lora-alpha = 64
38
+ train.save-every-round = 5
39
+ train.learning-rate-max = 5e-5
40
+ train.learning-rate-min = 1e-6
41
+ train.seq-length = 512
42
+ train.training-arguments.output-dir = ""
43
+ train.training-arguments.learning-rate = ""
44
+ train.training-arguments.per-device-train-batch-size = 16
45
+ train.training-arguments.gradient-accumulation-steps = 1
46
+ train.training-arguments.logging-steps = 10
47
+ train.training-arguments.num-train-epochs = 3
48
+ train.training-arguments.max-steps = 10
49
+ train.training-arguments.save-steps = 1000
50
+ train.training-arguments.save-total-limit = 10
51
+ train.training-arguments.gradient-checkpointing = true
52
+ train.training-arguments.lr-scheduler-type = "constant"
53
+ strategy.fraction-fit = $fraction_fit
54
+ strategy.fraction-evaluate = 0.0
55
+ num-server-rounds = 200
56
+
57
+ [tool.flwr.app.config.static]
58
+ dataset.name = "$dataset_name"
34
59
 
35
60
  [tool.flwr.federations]
36
61
  default = "local-simulation"
37
62
 
38
63
  [tool.flwr.federations.local-simulation]
39
- options.num-supernodes = 10
64
+ options.num-supernodes = $num_clients
65
+ options.backend.client-resources.num-cpus = 6
66
+ options.backend.client-resources.num-gpus = 1.0
flwr/cli/run/run.py CHANGED
@@ -124,14 +124,14 @@ def run(
124
124
 
125
125
 
126
126
  def _run_with_superexec(
127
- app: Optional[Path],
127
+ app: Path,
128
128
  federation_config: Dict[str, Any],
129
129
  config_overrides: Optional[List[str]],
130
130
  ) -> None:
131
131
 
132
132
  insecure_str = federation_config.get("insecure")
133
133
  if root_certificates := federation_config.get("root-certificates"):
134
- root_certificates_bytes = Path(root_certificates).read_bytes()
134
+ root_certificates_bytes = (app / root_certificates).read_bytes()
135
135
  if insecure := bool(insecure_str):
136
136
  typer.secho(
137
137
  "❌ `root_certificates` were provided but the `insecure` parameter"
flwr/client/__init__.py CHANGED
@@ -20,8 +20,6 @@ from .app import start_numpy_client as start_numpy_client
20
20
  from .client import Client as Client
21
21
  from .client_app import ClientApp as ClientApp
22
22
  from .numpy_client import NumPyClient as NumPyClient
23
- from .supernode import run_client_app as run_client_app
24
- from .supernode import run_supernode as run_supernode
25
23
  from .typing import ClientFn as ClientFn
26
24
  from .typing import ClientFnExt as ClientFnExt
27
25
 
@@ -32,8 +30,6 @@ __all__ = [
32
30
  "ClientFnExt",
33
31
  "NumPyClient",
34
32
  "mod",
35
- "run_client_app",
36
- "run_supernode",
37
33
  "start_client",
38
34
  "start_numpy_client",
39
35
  ]
@@ -17,11 +17,13 @@
17
17
 
18
18
  import base64
19
19
  import collections
20
+ from logging import WARNING
20
21
  from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
22
 
22
23
  import grpc
23
24
  from cryptography.hazmat.primitives.asymmetric import ec
24
25
 
26
+ from flwr.common.logger import log
25
27
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
26
28
  bytes_to_public_key,
27
29
  compute_hmac,
@@ -151,8 +153,15 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
151
153
  server_public_key_bytes = base64.urlsafe_b64decode(
152
154
  _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
153
155
  )
154
- self.server_public_key = bytes_to_public_key(server_public_key_bytes)
155
- self.shared_secret = generate_shared_key(
156
- self.private_key, self.server_public_key
157
- )
156
+
157
+ if server_public_key_bytes != b"":
158
+ self.server_public_key = bytes_to_public_key(server_public_key_bytes)
159
+ else:
160
+ log(WARNING, "Can't get server public key, SuperLink may be offline")
161
+
162
+ if self.server_public_key is not None:
163
+ self.shared_secret = generate_shared_key(
164
+ self.private_key, self.server_public_key
165
+ )
166
+
158
167
  return response
@@ -77,7 +77,9 @@ def run_supernode() -> None:
77
77
  authentication_keys=authentication_keys,
78
78
  max_retries=args.max_retries,
79
79
  max_wait_time=args.max_wait_time,
80
- node_config=parse_config_args([args.node_config]),
80
+ node_config=parse_config_args(
81
+ [args.node_config] if args.node_config else args.node_config
82
+ ),
81
83
  isolation=args.isolation,
82
84
  supernode_address=args.supernode_address,
83
85
  )
flwr/common/config.py CHANGED
@@ -185,23 +185,26 @@ def parse_config_args(
185
185
  if config is None:
186
186
  return overrides
187
187
 
188
+ # Handle if .toml file is passed
189
+ if len(config) == 1 and config[0].endswith(".toml"):
190
+ with Path(config[0]).open("rb") as config_file:
191
+ overrides = flatten_dict(tomli.load(config_file))
192
+ return overrides
193
+
188
194
  # Regular expression to capture key-value pairs with possible quoted values
189
195
  pattern = re.compile(r"(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)")
190
196
 
191
197
  for config_line in config:
192
198
  if config_line:
193
- matches = pattern.findall(config_line)
199
+ # .toml files aren't allowed alongside other configs
200
+ if config_line.endswith(".toml"):
201
+ raise ValueError(
202
+ "TOML files cannot be passed alongside key-value pairs."
203
+ )
194
204
 
195
- if (
196
- len(matches) == 1
197
- and "=" not in matches[0][0]
198
- and matches[0][0].endswith(".toml")
199
- ):
200
- with Path(matches[0][0]).open("rb") as config_file:
201
- overrides = flatten_dict(tomli.load(config_file))
202
- else:
203
- toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
204
- overrides.update(tomli.loads(toml_str))
205
+ matches = pattern.findall(config_line)
206
+ toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
207
+ overrides.update(tomli.loads(toml_str))
205
208
 
206
209
  return overrides
207
210
 
flwr/common/telemetry.py CHANGED
@@ -132,53 +132,59 @@ class EventType(str, Enum):
132
132
  # Ping
133
133
  PING = auto()
134
134
 
135
- # Client: start_client
135
+ # --- LEGACY FUNCTIONS -------------------------------------------------------------
136
+
137
+ # Legacy: `start_client` function
136
138
  START_CLIENT_ENTER = auto()
137
139
  START_CLIENT_LEAVE = auto()
138
140
 
139
- # Server: start_server
141
+ # Legacy: `start_server` function
140
142
  START_SERVER_ENTER = auto()
141
143
  START_SERVER_LEAVE = auto()
142
144
 
143
- # Driver API
144
- RUN_DRIVER_API_ENTER = auto()
145
- RUN_DRIVER_API_LEAVE = auto()
145
+ # Legacy: `start_simulation` function
146
+ START_SIMULATION_ENTER = auto()
147
+ START_SIMULATION_LEAVE = auto()
146
148
 
147
- # Fleet API
148
- RUN_FLEET_API_ENTER = auto()
149
- RUN_FLEET_API_LEAVE = auto()
149
+ # --- `flwr` CLI -------------------------------------------------------------------
150
150
 
151
- # Driver API and Fleet API
152
- RUN_SUPERLINK_ENTER = auto()
153
- RUN_SUPERLINK_LEAVE = auto()
151
+ # Not yet implemented
154
152
 
155
- # Simulation
156
- START_SIMULATION_ENTER = auto()
157
- START_SIMULATION_LEAVE = auto()
153
+ # --- SuperExec --------------------------------------------------------------------
158
154
 
159
- # Driver: Driver
160
- DRIVER_CONNECT = auto()
161
- DRIVER_DISCONNECT = auto()
155
+ # SuperExec
156
+ RUN_SUPEREXEC_ENTER = auto()
157
+ RUN_SUPEREXEC_LEAVE = auto()
162
158
 
163
- # Driver: start_driver
164
- START_DRIVER_ENTER = auto()
165
- START_DRIVER_LEAVE = auto()
159
+ # --- Simulation Engine ------------------------------------------------------------
166
160
 
167
- # flower-client-app
168
- RUN_CLIENT_APP_ENTER = auto()
169
- RUN_CLIENT_APP_LEAVE = auto()
161
+ # CLI: flower-simulation
162
+ CLI_FLOWER_SIMULATION_ENTER = auto()
163
+ CLI_FLOWER_SIMULATION_LEAVE = auto()
170
164
 
171
- # flower-server-app
172
- RUN_SERVER_APP_ENTER = auto()
173
- RUN_SERVER_APP_LEAVE = auto()
165
+ # Python API: `run_simulation`
166
+ PYTHON_API_RUN_SIMULATION_ENTER = auto()
167
+ PYTHON_API_RUN_SIMULATION_LEAVE = auto()
174
168
 
175
- # SuperNode
169
+ # --- Deployment Engine ------------------------------------------------------------
170
+
171
+ # CLI: `flower-superlink`
172
+ RUN_SUPERLINK_ENTER = auto()
173
+ RUN_SUPERLINK_LEAVE = auto()
174
+
175
+ # CLI: `flower-supernode`
176
176
  RUN_SUPERNODE_ENTER = auto()
177
177
  RUN_SUPERNODE_LEAVE = auto()
178
178
 
179
- # SuperExec
180
- RUN_SUPEREXEC_ENTER = auto()
181
- RUN_SUPEREXEC_LEAVE = auto()
179
+ # CLI: `flower-server-app`
180
+ RUN_SERVER_APP_ENTER = auto()
181
+ RUN_SERVER_APP_LEAVE = auto()
182
+
183
+ # --- DEPRECATED -------------------------------------------------------------------
184
+
185
+ # [DEPRECATED] CLI: `flower-client-app`
186
+ RUN_CLIENT_APP_ENTER = auto()
187
+ RUN_CLIENT_APP_LEAVE = auto()
182
188
 
183
189
 
184
190
  # Use the ThreadPoolExecutor with max_workers=1 to have a queue
flwr/server/__init__.py CHANGED
@@ -17,14 +17,12 @@
17
17
 
18
18
  from . import strategy
19
19
  from . import workflow as workflow
20
- from .app import run_superlink as run_superlink
21
20
  from .app import start_server as start_server
22
21
  from .client_manager import ClientManager as ClientManager
23
22
  from .client_manager import SimpleClientManager as SimpleClientManager
24
23
  from .compat import LegacyContext as LegacyContext
25
24
  from .driver import Driver as Driver
26
25
  from .history import History as History
27
- from .run_serverapp import run_server_app as run_server_app
28
26
  from .server import Server as Server
29
27
  from .server_app import ServerApp as ServerApp
30
28
  from .server_config import ServerConfig as ServerConfig
@@ -40,8 +38,6 @@ __all__ = [
40
38
  "ServerAppComponents",
41
39
  "ServerConfig",
42
40
  "SimpleClientManager",
43
- "run_server_app",
44
- "run_superlink",
45
41
  "start_server",
46
42
  "strategy",
47
43
  "workflow",
flwr/server/app.py CHANGED
@@ -278,24 +278,24 @@ def run_superlink() -> None:
278
278
  fleet_thread.start()
279
279
  bckg_threads.append(fleet_thread)
280
280
  elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
281
- maybe_keys = _try_setup_client_authentication(args, certificates)
281
+ maybe_keys = _try_setup_node_authentication(args, certificates)
282
282
  interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
283
283
  if maybe_keys is not None:
284
284
  (
285
- client_public_keys,
285
+ node_public_keys,
286
286
  server_private_key,
287
287
  server_public_key,
288
288
  ) = maybe_keys
289
289
  state = state_factory.state()
290
- state.store_client_public_keys(client_public_keys)
290
+ state.store_node_public_keys(node_public_keys)
291
291
  state.store_server_private_public_key(
292
292
  private_key_to_bytes(server_private_key),
293
293
  public_key_to_bytes(server_public_key),
294
294
  )
295
295
  log(
296
296
  INFO,
297
- "Client authentication enabled with %d known public keys",
298
- len(client_public_keys),
297
+ "Node authentication enabled with %d known public keys",
298
+ len(node_public_keys),
299
299
  )
300
300
  interceptors = [AuthenticateServerInterceptor(state)]
301
301
 
@@ -344,7 +344,7 @@ def _format_address(address: str) -> Tuple[str, str, int]:
344
344
  return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
345
345
 
346
346
 
347
- def _try_setup_client_authentication(
347
+ def _try_setup_node_authentication(
348
348
  args: argparse.Namespace,
349
349
  certificates: Optional[Tuple[bytes, bytes, bytes]],
350
350
  ) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
@@ -373,16 +373,16 @@ def _try_setup_client_authentication(
373
373
  "`--ssl-keyfile`, and `—-ssl-ca-certfile` and try again."
374
374
  )
375
375
 
376
- client_keys_file_path = Path(args.auth_list_public_keys)
377
- if not client_keys_file_path.exists():
376
+ node_keys_file_path = Path(args.auth_list_public_keys)
377
+ if not node_keys_file_path.exists():
378
378
  sys.exit(
379
379
  "The provided path to the known public keys CSV file does not exist: "
380
- f"{client_keys_file_path}. "
380
+ f"{node_keys_file_path}. "
381
381
  "Please provide the CSV file path containing known public keys "
382
382
  "to '--auth-list-public-keys'."
383
383
  )
384
384
 
385
- client_public_keys: Set[bytes] = set()
385
+ node_public_keys: Set[bytes] = set()
386
386
 
387
387
  try:
388
388
  ssh_private_key = load_ssh_private_key(
@@ -413,13 +413,13 @@ def _try_setup_client_authentication(
413
413
  "path points to a valid public key file and try again."
414
414
  )
415
415
 
416
- with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
416
+ with open(node_keys_file_path, newline="", encoding="utf-8") as csvfile:
417
417
  reader = csv.reader(csvfile)
418
418
  for row in reader:
419
419
  for element in row:
420
420
  public_key = load_ssh_public_key(element.encode())
421
421
  if isinstance(public_key, ec.EllipticCurvePublicKey):
422
- client_public_keys.add(public_key_to_bytes(public_key))
422
+ node_public_keys.add(public_key_to_bytes(public_key))
423
423
  else:
424
424
  sys.exit(
425
425
  "Error: Unable to parse the public keys in the CSV "
@@ -427,7 +427,7 @@ def _try_setup_client_authentication(
427
427
  "known SSH public keys files and try again."
428
428
  )
429
429
  return (
430
- client_public_keys,
430
+ node_public_keys,
431
431
  ssh_private_key,
432
432
  ssh_public_key,
433
433
  )
flwr/server/compat/app.py CHANGED
@@ -18,7 +18,6 @@
18
18
  from logging import INFO
19
19
  from typing import Optional
20
20
 
21
- from flwr.common import EventType, event
22
21
  from flwr.common.logger import log
23
22
  from flwr.server.client_manager import ClientManager
24
23
  from flwr.server.history import History
@@ -65,8 +64,6 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
65
64
  hist : flwr.server.history.History
66
65
  Object containing training and evaluation metrics.
67
66
  """
68
- event(EventType.START_DRIVER_ENTER)
69
-
70
67
  # Initialize the Driver API server and config
71
68
  initialized_server, initialized_config = init_defaults(
72
69
  server=server,
@@ -96,6 +93,4 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
96
93
  f_stop.set()
97
94
  thread.join()
98
95
 
99
- event(EventType.START_SERVER_LEAVE)
100
-
101
96
  return hist
@@ -21,7 +21,7 @@ from typing import Iterable, List, Optional, cast
21
21
 
22
22
  import grpc
23
23
 
24
- from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
24
+ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
25
  from flwr.common.grpc import create_channel
26
26
  from flwr.common.logger import log
27
27
  from flwr.common.serde import (
@@ -94,7 +94,6 @@ class GrpcDriver(Driver):
94
94
 
95
95
  This will not call GetRun.
96
96
  """
97
- event(EventType.DRIVER_CONNECT)
98
97
  if self._is_connected:
99
98
  log(WARNING, "Already connected")
100
99
  return
@@ -108,7 +107,6 @@ class GrpcDriver(Driver):
108
107
 
109
108
  def _disconnect(self) -> None:
110
109
  """Disconnect from the Driver API."""
111
- event(EventType.DRIVER_DISCONNECT)
112
110
  if not self._is_connected:
113
111
  log(DEBUG, "Already disconnected")
114
112
  return
@@ -97,6 +97,21 @@ def run_server_app() -> None:
97
97
 
98
98
  args = _parse_args_run_server_app().parse_args()
99
99
 
100
+ # Check if the server app reference is passed.
101
+ # Since Flower 1.11, passing a reference is not allowed.
102
+ app_path: Optional[str] = args.app
103
+ # If the provided app_path doesn't exist, and contains a ":",
104
+ # it is likely to be a server app reference instead of a path.
105
+ if app_path is not None and not Path(app_path).exists() and ":" in app_path:
106
+ sys.exit(
107
+ "It appears you've passed a reference like `server:app`.\n\n"
108
+ "Note that since version `1.11.0`, `flower-server-app` no longer supports "
109
+ "passing a reference to a `ServerApp` attribute. Instead, you need to pass "
110
+ "the path to Flower app via the argument `--app`. This is the path to a "
111
+ "directory containing a `pyproject.toml`. You can create a valid Flower "
112
+ "app by executing `flwr new` and following the prompt."
113
+ )
114
+
100
115
  if args.server != ADDRESS_DRIVER_API:
101
116
  warn = "Passing flag --server is deprecated. Use --superlink instead."
102
117
  warn_deprecated_feature(warn)
@@ -151,7 +166,6 @@ def run_server_app() -> None:
151
166
  cert_path,
152
167
  )
153
168
 
154
- app_path: Optional[str] = args.app
155
169
  if not (app_path is None) ^ (args.run_id is None):
156
170
  raise sys.exit(
157
171
  "Please provide either a Flower App path or a Run ID, but not both. "
@@ -51,19 +51,22 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
51
51
  self, request: CreateNodeRequest, context: grpc.ServicerContext
52
52
  ) -> CreateNodeResponse:
53
53
  """."""
54
- log(INFO, "FleetServicer.CreateNode")
54
+ log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
55
+ log(DEBUG, "[Fleet.CreateNode] Request: %s", request)
55
56
  response = message_handler.create_node(
56
57
  request=request,
57
58
  state=self.state_factory.state(),
58
59
  )
59
- log(INFO, "FleetServicer: Created node_id=%s", response.node.node_id)
60
+ log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
61
+ log(DEBUG, "[Fleet.CreateNode] Response: %s", response)
60
62
  return response
61
63
 
62
64
  def DeleteNode(
63
65
  self, request: DeleteNodeRequest, context: grpc.ServicerContext
64
66
  ) -> DeleteNodeResponse:
65
67
  """."""
66
- log(INFO, "FleetServicer.DeleteNode")
68
+ log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
69
+ log(DEBUG, "[Fleet.DeleteNode] Request: %s", request)
67
70
  return message_handler.delete_node(
68
71
  request=request,
69
72
  state=self.state_factory.state(),
@@ -71,7 +74,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
71
74
 
72
75
  def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
73
76
  """."""
74
- log(DEBUG, "FleetServicer.Ping")
77
+ log(DEBUG, "[Fleet.Ping] Request: %s", request)
75
78
  return message_handler.ping(
76
79
  request=request,
77
80
  state=self.state_factory.state(),
@@ -81,7 +84,8 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
81
84
  self, request: PullTaskInsRequest, context: grpc.ServicerContext
82
85
  ) -> PullTaskInsResponse:
83
86
  """Pull TaskIns."""
84
- log(INFO, "FleetServicer.PullTaskIns")
87
+ log(INFO, "[Fleet.PullTaskIns] node_id=%s", request.node.node_id)
88
+ log(DEBUG, "[Fleet.PullTaskIns] Request: %s", request)
85
89
  return message_handler.pull_task_ins(
86
90
  request=request,
87
91
  state=self.state_factory.state(),
@@ -91,7 +95,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
91
95
  self, request: PushTaskResRequest, context: grpc.ServicerContext
92
96
  ) -> PushTaskResResponse:
93
97
  """Push TaskRes."""
94
- log(INFO, "FleetServicer.PushTaskRes")
98
+ if request.task_res_list:
99
+ log(
100
+ INFO,
101
+ "[Fleet.PushTaskRes] Push results from node_id=%s",
102
+ request.task_res_list[0].task.producer.node_id,
103
+ )
104
+ else:
105
+ log(INFO, "[Fleet.PushTaskRes] No task results to push")
95
106
  return message_handler.push_task_res(
96
107
  request=request,
97
108
  state=self.state_factory.state(),
@@ -101,7 +112,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
101
112
  self, request: GetRunRequest, context: grpc.ServicerContext
102
113
  ) -> GetRunResponse:
103
114
  """Get run information."""
104
- log(INFO, "FleetServicer.GetRun")
115
+ log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
105
116
  return message_handler.get_run(
106
117
  request=request,
107
118
  state=self.state_factory.state(),
@@ -111,7 +122,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
111
122
  self, request: GetFabRequest, context: grpc.ServicerContext
112
123
  ) -> GetFabResponse:
113
124
  """Get FAB."""
114
- log(DEBUG, "DriverServicer.GetFab")
125
+ log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
115
126
  return message_handler.get_fab(
116
127
  request=request,
117
128
  ffs=self.ffs_factory.ffs(),
@@ -78,13 +78,13 @@ def _get_value_from_tuples(
78
78
 
79
79
 
80
80
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
81
- """Server interceptor for client authentication."""
81
+ """Server interceptor for node authentication."""
82
82
 
83
83
  def __init__(self, state: State):
84
84
  self.state = state
85
85
 
86
- self.client_public_keys = state.get_client_public_keys()
87
- if len(self.client_public_keys) == 0:
86
+ self.node_public_keys = state.get_node_public_keys()
87
+ if len(self.node_public_keys) == 0:
88
88
  log(WARNING, "Authentication enabled, but no known public keys configured")
89
89
 
90
90
  private_key = self.state.get_server_private_key()
@@ -103,9 +103,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
103
103
  ) -> grpc.RpcMethodHandler:
104
104
  """Flower server interceptor authentication logic.
105
105
 
106
- Intercept all unary calls from clients and authenticate clients by validating
107
- auth metadata sent by the client. Continue RPC call if client is authenticated,
108
- else, terminate RPC call by setting context to abort.
106
+ Intercept all unary calls from nodes and authenticate nodes by validating auth
107
+ metadata sent by the node. Continue RPC call if node is authenticated, else,
108
+ terminate RPC call by setting context to abort.
109
109
  """
110
110
  # One of the method handlers in
111
111
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
@@ -119,17 +119,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
119
119
  request: Request,
120
120
  context: grpc.ServicerContext,
121
121
  ) -> Response:
122
- client_public_key_bytes = base64.urlsafe_b64decode(
122
+ node_public_key_bytes = base64.urlsafe_b64decode(
123
123
  _get_value_from_tuples(
124
124
  _PUBLIC_KEY_HEADER, context.invocation_metadata()
125
125
  )
126
126
  )
127
- if client_public_key_bytes not in self.client_public_keys:
127
+ if node_public_key_bytes not in self.node_public_keys:
128
128
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
129
129
 
130
130
  if isinstance(request, CreateNodeRequest):
131
131
  response = self._create_authenticated_node(
132
- client_public_key_bytes, request, context
132
+ node_public_key_bytes, request, context
133
133
  )
134
134
  log(
135
135
  INFO,
@@ -144,13 +144,13 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
144
144
  _AUTH_TOKEN_HEADER, context.invocation_metadata()
145
145
  )
146
146
  )
147
- public_key = bytes_to_public_key(client_public_key_bytes)
147
+ public_key = bytes_to_public_key(node_public_key_bytes)
148
148
 
149
149
  if not self._verify_hmac(public_key, request, hmac_value):
150
150
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
151
151
 
152
152
  # Verify node_id
153
- node_id = self.state.get_node_id(client_public_key_bytes)
153
+ node_id = self.state.get_node_id(node_public_key_bytes)
154
154
 
155
155
  if not self._verify_node_id(node_id, request):
156
156
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")