flwr-nightly 1.14.0.dev20241202__py3-none-any.whl → 1.14.0.dev20241214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (108)
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -0
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +1 -0
  7. flwr/cli/log.py +11 -31
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +83 -0
  10. flwr/cli/ls.py +198 -102
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +2 -1
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  14. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
  15. flwr/cli/run/__init__.py +1 -0
  16. flwr/cli/run/run.py +96 -39
  17. flwr/cli/stop.py +91 -0
  18. flwr/cli/utils.py +109 -1
  19. flwr/client/app.py +3 -2
  20. flwr/client/client.py +1 -0
  21. flwr/client/clientapp/app.py +1 -0
  22. flwr/client/clientapp/utils.py +1 -0
  23. flwr/client/grpc_adapter_client/connection.py +1 -1
  24. flwr/client/grpc_client/connection.py +1 -1
  25. flwr/client/grpc_rere_client/connection.py +3 -3
  26. flwr/client/message_handler/message_handler.py +1 -0
  27. flwr/client/mod/comms_mods.py +1 -0
  28. flwr/client/mod/localdp_mod.py +1 -1
  29. flwr/client/nodestate/__init__.py +1 -0
  30. flwr/client/nodestate/nodestate.py +1 -0
  31. flwr/client/nodestate/nodestate_factory.py +1 -0
  32. flwr/client/rest_client/connection.py +3 -3
  33. flwr/client/supernode/app.py +1 -0
  34. flwr/common/address.py +1 -0
  35. flwr/common/args.py +1 -0
  36. flwr/common/auth_plugin/__init__.py +24 -0
  37. flwr/common/auth_plugin/auth_plugin.py +111 -0
  38. flwr/common/config.py +3 -1
  39. flwr/common/constant.py +17 -1
  40. flwr/common/logger.py +25 -0
  41. flwr/common/message.py +1 -0
  42. flwr/common/object_ref.py +57 -54
  43. flwr/common/pyproject.py +1 -0
  44. flwr/common/record/__init__.py +1 -0
  45. flwr/common/record/parametersrecord.py +1 -0
  46. flwr/common/retry_invoker.py +75 -0
  47. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  48. flwr/common/telemetry.py +2 -1
  49. flwr/common/typing.py +12 -0
  50. flwr/common/version.py +1 -0
  51. flwr/proto/exec_pb2.py +38 -14
  52. flwr/proto/exec_pb2.pyi +107 -2
  53. flwr/proto/exec_pb2_grpc.py +102 -0
  54. flwr/proto/exec_pb2_grpc.pyi +39 -0
  55. flwr/proto/fab_pb2.py +4 -4
  56. flwr/proto/fab_pb2.pyi +4 -1
  57. flwr/proto/serverappio_pb2.py +18 -18
  58. flwr/proto/serverappio_pb2.pyi +8 -2
  59. flwr/proto/serverappio_pb2_grpc.py +34 -0
  60. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  61. flwr/proto/simulationio_pb2.py +2 -2
  62. flwr/proto/simulationio_pb2_grpc.py +34 -0
  63. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  64. flwr/server/app.py +53 -1
  65. flwr/server/compat/app_utils.py +7 -1
  66. flwr/server/driver/grpc_driver.py +11 -63
  67. flwr/server/driver/inmemory_driver.py +5 -1
  68. flwr/server/serverapp/app.py +9 -2
  69. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  70. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  71. flwr/server/superlink/driver/serverappio_servicer.py +73 -23
  72. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  73. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  74. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +31 -2
  77. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  78. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  79. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  80. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  81. flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
  82. flwr/server/superlink/linkstate/linkstate.py +13 -2
  83. flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
  84. flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
  85. flwr/server/superlink/utils.py +65 -0
  86. flwr/simulation/app.py +1 -0
  87. flwr/simulation/ray_transport/ray_actor.py +1 -0
  88. flwr/simulation/ray_transport/utils.py +1 -0
  89. flwr/simulation/run_simulation.py +1 -0
  90. flwr/superexec/app.py +1 -0
  91. flwr/superexec/deployment.py +1 -0
  92. flwr/superexec/exec_grpc.py +19 -1
  93. flwr/superexec/exec_servicer.py +76 -2
  94. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  95. flwr/superexec/executor.py +1 -0
  96. {flwr_nightly-1.14.0.dev20241202.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/METADATA +8 -7
  97. {flwr_nightly-1.14.0.dev20241202.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/RECORD +100 -100
  98. flwr/proto/common_pb2.py +0 -36
  99. flwr/proto/common_pb2.pyi +0 -121
  100. flwr/proto/common_pb2_grpc.py +0 -4
  101. flwr/proto/common_pb2_grpc.pyi +0 -4
  102. flwr/proto/control_pb2.py +0 -27
  103. flwr/proto/control_pb2.pyi +0 -7
  104. flwr/proto/control_pb2_grpc.py +0 -135
  105. flwr/proto/control_pb2_grpc.pyi +0 -53
  106. {flwr_nightly-1.14.0.dev20241202.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/LICENSE +0 -0
  107. {flwr_nightly-1.14.0.dev20241202.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/WHEEL +0 -0
  108. {flwr_nightly-1.14.0.dev20241202.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/entry_points.txt +0 -0
flwr/cli/ls.py CHANGED
@@ -15,27 +15,27 @@
15
15
  """Flower command line interface `ls` command."""
16
16
 
17
17
 
18
+ import io
19
+ import json
18
20
  from datetime import datetime, timedelta
19
- from logging import DEBUG
20
21
  from pathlib import Path
21
- from typing import Annotated, Any, Optional
22
+ from typing import Annotated, Optional, Union
22
23
 
23
- import grpc
24
24
  import typer
25
25
  from rich.console import Console
26
26
  from rich.table import Table
27
27
  from rich.text import Text
28
+ from typer import Exit
28
29
 
29
30
  from flwr.cli.config_utils import (
31
+ exit_if_no_address,
30
32
  load_and_validate,
31
- validate_certificate_in_federation_config,
33
+ process_loaded_project_config,
32
34
  validate_federation_in_project_config,
33
- validate_project_config,
34
35
  )
35
- from flwr.common.constant import FAB_CONFIG_FILE, SubStatus
36
+ from flwr.common.constant import FAB_CONFIG_FILE, CliOutputFormat, SubStatus
36
37
  from flwr.common.date import format_timedelta, isoformat8601_utc
37
- from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
38
- from flwr.common.logger import log
38
+ from flwr.common.logger import redirect_output, remove_emojis, restore_output
39
39
  from flwr.common.serde import run_from_proto
40
40
  from flwr.common.typing import Run
41
41
  from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
@@ -44,8 +44,12 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
44
44
  )
45
45
  from flwr.proto.exec_pb2_grpc import ExecStub
46
46
 
47
+ from .utils import init_channel, try_obtain_cli_auth_plugin
47
48
 
48
- def ls(
49
+ _RunListType = tuple[int, str, str, str, str, str, str, str, str]
50
+
51
+
52
+ def ls( # pylint: disable=too-many-locals, too-many-branches
49
53
  app: Annotated[
50
54
  Path,
51
55
  typer.Argument(help="Path of the Flower project"),
@@ -68,94 +72,86 @@ def ls(
68
72
  help="Specific run ID to display",
69
73
  ),
70
74
  ] = None,
75
+ output_format: Annotated[
76
+ str,
77
+ typer.Option(
78
+ "--format",
79
+ case_sensitive=False,
80
+ help="Format output using 'default' view or 'json'",
81
+ ),
82
+ ] = CliOutputFormat.DEFAULT,
71
83
  ) -> None:
72
84
  """List runs."""
73
- # Load and validate federation config
74
- typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
75
-
76
- pyproject_path = app / FAB_CONFIG_FILE if app else None
77
- config, errors, warnings = load_and_validate(path=pyproject_path)
78
- config = validate_project_config(config, errors, warnings)
79
- federation, federation_config = validate_federation_in_project_config(
80
- federation, config
81
- )
82
-
83
- if "address" not in federation_config:
84
- typer.secho(
85
- "❌ `flwr ls` currently works with Exec API. Ensure that the correct"
86
- "Exec API address is provided in the `pyproject.toml`.",
87
- fg=typer.colors.RED,
88
- bold=True,
89
- )
90
- raise typer.Exit(code=1)
91
-
85
+ suppress_output = output_format == CliOutputFormat.JSON
86
+ captured_output = io.StringIO()
92
87
  try:
93
- if runs and run_id is not None:
94
- raise ValueError(
95
- "The options '--runs' and '--run-id' are mutually exclusive."
96
- )
88
+ if suppress_output:
89
+ redirect_output(captured_output)
97
90
 
98
- channel = _init_channel(app, federation_config)
99
- stub = ExecStub(channel)
91
+ # Load and validate federation config
92
+ typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
100
93
 
101
- # Display information about a specific run ID
102
- if run_id is not None:
103
- typer.echo(f"🔍 Displaying information for run ID {run_id}...")
104
- _display_one_run(stub, run_id)
105
- # By default, list all runs
106
- else:
107
- typer.echo("📄 Listing all runs...")
108
- _list_runs(stub)
109
-
110
- except ValueError as err:
111
- typer.secho(
112
- f"❌ {err}",
113
- fg=typer.colors.RED,
114
- bold=True,
94
+ pyproject_path = app / FAB_CONFIG_FILE if app else None
95
+ config, errors, warnings = load_and_validate(path=pyproject_path)
96
+ config = process_loaded_project_config(config, errors, warnings)
97
+ federation, federation_config = validate_federation_in_project_config(
98
+ federation, config
115
99
  )
116
- raise typer.Exit(code=1) from err
100
+ exit_if_no_address(federation_config, "ls")
101
+
102
+ try:
103
+ if runs and run_id is not None:
104
+ raise ValueError(
105
+ "The options '--runs' and '--run-id' are mutually exclusive."
106
+ )
107
+ auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
108
+ channel = init_channel(app, federation_config, auth_plugin)
109
+ stub = ExecStub(channel)
110
+
111
+ # Display information about a specific run ID
112
+ if run_id is not None:
113
+ typer.echo(f"🔍 Displaying information for run ID {run_id}...")
114
+ restore_output()
115
+ _display_one_run(stub, run_id, output_format)
116
+ # By default, list all runs
117
+ else:
118
+ typer.echo("📄 Listing all runs...")
119
+ restore_output()
120
+ _list_runs(stub, output_format)
121
+
122
+ except ValueError as err:
123
+ typer.secho(
124
+ f"❌ {err}",
125
+ fg=typer.colors.RED,
126
+ bold=True,
127
+ )
128
+ raise typer.Exit(code=1) from err
129
+ finally:
130
+ channel.close()
131
+ except (typer.Exit, Exception) as err: # pylint: disable=broad-except
132
+ if suppress_output:
133
+ restore_output()
134
+ e_message = captured_output.getvalue()
135
+ _print_json_error(e_message, err)
136
+ else:
137
+ typer.secho(
138
+ f"{err}",
139
+ fg=typer.colors.RED,
140
+ bold=True,
141
+ )
117
142
  finally:
118
- channel.close()
119
-
120
-
121
- def on_channel_state_change(channel_connectivity: str) -> None:
122
- """Log channel connectivity."""
123
- log(DEBUG, channel_connectivity)
124
-
125
-
126
- def _init_channel(app: Path, federation_config: dict[str, Any]) -> grpc.Channel:
127
- """Initialize gRPC channel to the Exec API."""
128
- insecure, root_certificates_bytes = validate_certificate_in_federation_config(
129
- app, federation_config
130
- )
131
- channel = create_channel(
132
- server_address=federation_config["address"],
133
- insecure=insecure,
134
- root_certificates=root_certificates_bytes,
135
- max_message_length=GRPC_MAX_MESSAGE_LENGTH,
136
- interceptors=None,
137
- )
138
- channel.subscribe(on_channel_state_change)
139
- return channel
143
+ if suppress_output:
144
+ restore_output()
145
+ captured_output.close()
140
146
 
141
147
 
142
- def _format_run_table(run_dict: dict[int, Run], now_isoformat: str) -> Table:
143
- """Format run status as a rich Table."""
144
- table = Table(header_style="bold cyan", show_lines=True)
148
+ def _format_runs(run_dict: dict[int, Run], now_isoformat: str) -> list[_RunListType]:
149
+ """Format runs to a list."""
145
150
 
146
151
  def _format_datetime(dt: Optional[datetime]) -> str:
147
152
  return isoformat8601_utc(dt).replace("T", " ") if dt else "N/A"
148
153
 
149
- # Add columns
150
- table.add_column(
151
- Text("Run ID", justify="center"), style="bright_white", overflow="fold"
152
- )
153
- table.add_column(Text("FAB", justify="center"), style="dim white")
154
- table.add_column(Text("Status", justify="center"))
155
- table.add_column(Text("Elapsed", justify="center"), style="blue")
156
- table.add_column(Text("Created At", justify="center"), style="dim white")
157
- table.add_column(Text("Running At", justify="center"), style="dim white")
158
- table.add_column(Text("Finished At", justify="center"), style="dim white")
154
+ run_list: list[_RunListType] = []
159
155
 
160
156
  # Add rows
161
157
  for run in sorted(
@@ -167,15 +163,6 @@ def _format_run_table(run_dict: dict[int, Run], now_isoformat: str) -> Table:
167
163
  else:
168
164
  status_text = f"{run.status.status}:{run.status.sub_status}"
169
165
 
170
- # Style the status based on its value
171
- sub_status = run.status.sub_status
172
- if sub_status == SubStatus.COMPLETED:
173
- status_style = "green"
174
- elif sub_status == SubStatus.FAILED:
175
- status_style = "red"
176
- else:
177
- status_style = "yellow"
178
-
179
166
  # Convert isoformat to datetime
180
167
  pending_at = datetime.fromisoformat(run.pending_at) if run.pending_at else None
181
168
  running_at = datetime.fromisoformat(run.running_at) if run.running_at else None
@@ -192,31 +179,124 @@ def _format_run_table(run_dict: dict[int, Run], now_isoformat: str) -> Table:
192
179
  end_time = datetime.fromisoformat(now_isoformat)
193
180
  elapsed_time = end_time - running_at
194
181
 
195
- table.add_row(
196
- f"[bold]{run.run_id}[/bold]",
197
- f"{run.fab_id} (v{run.fab_version})",
182
+ run_list.append(
183
+ (
184
+ run.run_id,
185
+ run.fab_id,
186
+ run.fab_version,
187
+ run.fab_hash,
188
+ status_text,
189
+ format_timedelta(elapsed_time),
190
+ _format_datetime(pending_at),
191
+ _format_datetime(running_at),
192
+ _format_datetime(finished_at),
193
+ )
194
+ )
195
+ return run_list
196
+
197
+
198
+ def _to_table(run_list: list[_RunListType]) -> Table:
199
+ """Format the provided run list to a rich Table."""
200
+ table = Table(header_style="bold cyan", show_lines=True)
201
+
202
+ # Add columns
203
+ table.add_column(
204
+ Text("Run ID", justify="center"), style="bright_white", overflow="fold"
205
+ )
206
+ table.add_column(Text("FAB", justify="center"), style="dim white")
207
+ table.add_column(Text("Status", justify="center"))
208
+ table.add_column(Text("Elapsed", justify="center"), style="blue")
209
+ table.add_column(Text("Created At", justify="center"), style="dim white")
210
+ table.add_column(Text("Running At", justify="center"), style="dim white")
211
+ table.add_column(Text("Finished At", justify="center"), style="dim white")
212
+
213
+ for row in run_list:
214
+ (
215
+ run_id,
216
+ fab_id,
217
+ fab_version,
218
+ _,
219
+ status_text,
220
+ elapsed,
221
+ created_at,
222
+ running_at,
223
+ finished_at,
224
+ ) = row
225
+ # Style the status based on its value
226
+ sub_status = status_text.rsplit(":", maxsplit=1)[-1]
227
+ if sub_status == SubStatus.COMPLETED:
228
+ status_style = "green"
229
+ elif sub_status == SubStatus.FAILED:
230
+ status_style = "red"
231
+ else:
232
+ status_style = "yellow"
233
+
234
+ formatted_row = (
235
+ f"[bold]{run_id}[/bold]",
236
+ f"{fab_id} (v{fab_version})",
198
237
  f"[{status_style}]{status_text}[/{status_style}]",
199
- format_timedelta(elapsed_time),
200
- _format_datetime(pending_at),
201
- _format_datetime(running_at),
202
- _format_datetime(finished_at),
238
+ elapsed,
239
+ created_at,
240
+ running_at,
241
+ finished_at,
203
242
  )
243
+ table.add_row(*formatted_row)
244
+
204
245
  return table
205
246
 
206
247
 
248
+ def _to_json(run_list: list[_RunListType]) -> str:
249
+ """Format run status list to a JSON formatted string."""
250
+ runs_list = []
251
+ for row in run_list:
252
+ (
253
+ run_id,
254
+ fab_id,
255
+ fab_version,
256
+ fab_hash,
257
+ status_text,
258
+ elapsed,
259
+ created_at,
260
+ running_at,
261
+ finished_at,
262
+ ) = row
263
+ runs_list.append(
264
+ {
265
+ "run-id": run_id,
266
+ "fab-id": fab_id,
267
+ "fab-name": fab_id.split("/")[-1],
268
+ "fab-version": fab_version,
269
+ "fab-hash": fab_hash[:8],
270
+ "status": status_text,
271
+ "elapsed": elapsed,
272
+ "created-at": created_at,
273
+ "running-at": running_at,
274
+ "finished-at": finished_at,
275
+ }
276
+ )
277
+
278
+ return json.dumps({"success": True, "runs": runs_list})
279
+
280
+
207
281
  def _list_runs(
208
282
  stub: ExecStub,
283
+ output_format: str = CliOutputFormat.DEFAULT,
209
284
  ) -> None:
210
285
  """List all runs."""
211
286
  res: ListRunsResponse = stub.ListRuns(ListRunsRequest())
212
287
  run_dict = {run_id: run_from_proto(proto) for run_id, proto in res.run_dict.items()}
213
288
 
214
- Console().print(_format_run_table(run_dict, res.now))
289
+ formatted_runs = _format_runs(run_dict, res.now)
290
+ if output_format == CliOutputFormat.JSON:
291
+ Console().print_json(_to_json(formatted_runs))
292
+ else:
293
+ Console().print(_to_table(formatted_runs))
215
294
 
216
295
 
217
296
  def _display_one_run(
218
297
  stub: ExecStub,
219
298
  run_id: int,
299
+ output_format: str = CliOutputFormat.DEFAULT,
220
300
  ) -> None:
221
301
  """Display information about a specific run."""
222
302
  res: ListRunsResponse = stub.ListRuns(ListRunsRequest(run_id=run_id))
@@ -225,4 +305,20 @@ def _display_one_run(
225
305
 
226
306
  run_dict = {run_id: run_from_proto(proto) for run_id, proto in res.run_dict.items()}
227
307
 
228
- Console().print(_format_run_table(run_dict, res.now))
308
+ formatted_runs = _format_runs(run_dict, res.now)
309
+ if output_format == CliOutputFormat.JSON:
310
+ Console().print_json(_to_json(formatted_runs))
311
+ else:
312
+ Console().print(_to_table(formatted_runs))
313
+
314
+
315
+ def _print_json_error(msg: str, e: Union[Exit, Exception]) -> None:
316
+ """Print error message as JSON."""
317
+ Console().print_json(
318
+ json.dumps(
319
+ {
320
+ "success": False,
321
+ "error-message": remove_emojis(str(msg) + "\n" + str(e)),
322
+ }
323
+ )
324
+ )
flwr/cli/new/__init__.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `new` command."""
16
16
 
17
+
17
18
  from .new import new as new
18
19
 
19
20
  __all__ = [
flwr/cli/new/new.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `new` command."""
16
16
 
17
+
17
18
  import re
18
19
  from enum import Enum
19
20
  from pathlib import Path
@@ -81,7 +82,7 @@ def render_template(template: str, data: dict[str, str]) -> str:
81
82
  def create_file(file_path: Path, content: str) -> None:
82
83
  """Create file including all nessecary directories and write content into file."""
83
84
  file_path.parent.mkdir(exist_ok=True)
84
- file_path.write_text(content)
85
+ file_path.write_text(content, encoding="utf-8")
85
86
 
86
87
 
87
88
  def render_and_create(file_path: Path, template: str, context: dict[str, str]) -> None:
flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl CHANGED
@@ -12,10 +12,10 @@ dependencies = [
12
12
  "flwr-datasets>=0.3.0",
13
13
  "torch==2.3.1",
14
14
  "trl==0.8.1",
15
- "bitsandbytes==0.43.0",
15
+ "bitsandbytes==0.45.0",
16
16
  "scipy==1.13.0",
17
17
  "peft==0.6.2",
18
- "transformers==4.43.1",
18
+ "transformers==4.47.0",
19
19
  "sentencepiece==0.2.0",
20
20
  "omegaconf==2.3.0",
21
21
  "hf_transfer==0.1.8",
flwr/cli/new/templates/app/pyproject.mlx.toml.tpl CHANGED
@@ -10,8 +10,7 @@ license = "Apache-2.0"
10
10
  dependencies = [
11
11
  "flwr[simulation]>=1.13.1",
12
12
  "flwr-datasets[vision]>=0.3.0",
13
- "mlx==0.16.1",
14
- "numpy==1.24.4",
13
+ "mlx==0.21.1",
15
14
  ]
16
15
 
17
16
  [tool.hatch.build.targets.wheel]
flwr/cli/run/__init__.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `run` command."""
16
16
 
17
+
17
18
  from .run import run as run
18
19
 
19
20
  __all__ = [
flwr/cli/run/run.py CHANGED
@@ -14,28 +14,30 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `run` command."""
16
16
 
17
+
18
+ import io
17
19
  import json
18
20
  import subprocess
19
- from logging import DEBUG
20
21
  from pathlib import Path
21
- from typing import Annotated, Any, Optional
22
+ from typing import Annotated, Any, Optional, Union
22
23
 
23
24
  import typer
25
+ from rich.console import Console
24
26
 
25
27
  from flwr.cli.build import build
26
28
  from flwr.cli.config_utils import (
29
+ get_fab_metadata,
27
30
  load_and_validate,
28
- validate_certificate_in_federation_config,
31
+ process_loaded_project_config,
29
32
  validate_federation_in_project_config,
30
- validate_project_config,
31
33
  )
32
34
  from flwr.common.config import (
33
35
  flatten_dict,
34
36
  parse_config_args,
35
37
  user_config_to_configsrecord,
36
38
  )
37
- from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
38
- from flwr.common.logger import log
39
+ from flwr.common.constant import CliOutputFormat
40
+ from flwr.common.logger import redirect_output, remove_emojis, restore_output
39
41
  from flwr.common.serde import (
40
42
  configs_record_to_proto,
41
43
  fab_to_proto,
@@ -46,15 +48,11 @@ from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
46
48
  from flwr.proto.exec_pb2_grpc import ExecStub
47
49
 
48
50
  from ..log import start_stream
51
+ from ..utils import init_channel, try_obtain_cli_auth_plugin
49
52
 
50
53
  CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
51
54
 
52
55
 
53
- def on_channel_state_change(channel_connectivity: str) -> None:
54
- """Log channel connectivity."""
55
- log(DEBUG, channel_connectivity)
56
-
57
-
58
56
  # pylint: disable-next=too-many-locals
59
57
  def run(
60
58
  app: Annotated[
@@ -85,46 +83,74 @@ def run(
85
83
  "logs are not streamed by default.",
86
84
  ),
87
85
  ] = False,
86
+ output_format: Annotated[
87
+ str,
88
+ typer.Option(
89
+ "--format",
90
+ case_sensitive=False,
91
+ help="Format output using 'default' view or 'json'",
92
+ ),
93
+ ] = CliOutputFormat.DEFAULT,
88
94
  ) -> None:
89
95
  """Run Flower App."""
90
- typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
91
-
92
- pyproject_path = app / "pyproject.toml" if app else None
93
- config, errors, warnings = load_and_validate(path=pyproject_path)
94
- config = validate_project_config(config, errors, warnings)
95
- federation, federation_config = validate_federation_in_project_config(
96
- federation, config
97
- )
98
-
99
- if "address" in federation_config:
100
- _run_with_exec_api(app, federation_config, config_overrides, stream)
101
- else:
102
- _run_without_exec_api(app, federation_config, config_overrides, federation)
103
-
96
+ suppress_output = output_format == CliOutputFormat.JSON
97
+ captured_output = io.StringIO()
98
+ try:
99
+ if suppress_output:
100
+ redirect_output(captured_output)
101
+ typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
102
+
103
+ pyproject_path = app / "pyproject.toml" if app else None
104
+ config, errors, warnings = load_and_validate(path=pyproject_path)
105
+ config = process_loaded_project_config(config, errors, warnings)
106
+ federation, federation_config = validate_federation_in_project_config(
107
+ federation, config
108
+ )
104
109
 
105
- # pylint: disable-next=too-many-locals
110
+ if "address" in federation_config:
111
+ _run_with_exec_api(
112
+ app,
113
+ federation,
114
+ federation_config,
115
+ config_overrides,
116
+ stream,
117
+ output_format,
118
+ )
119
+ else:
120
+ _run_without_exec_api(app, federation_config, config_overrides, federation)
121
+ except (typer.Exit, Exception) as err: # pylint: disable=broad-except
122
+ if suppress_output:
123
+ restore_output()
124
+ e_message = captured_output.getvalue()
125
+ _print_json_error(e_message, err)
126
+ else:
127
+ typer.secho(
128
+ f"{err}",
129
+ fg=typer.colors.RED,
130
+ bold=True,
131
+ )
132
+ finally:
133
+ if suppress_output:
134
+ restore_output()
135
+ captured_output.close()
136
+
137
+
138
+ # pylint: disable-next=R0913, R0914, R0917
106
139
  def _run_with_exec_api(
107
140
  app: Path,
141
+ federation: str,
108
142
  federation_config: dict[str, Any],
109
143
  config_overrides: Optional[list[str]],
110
144
  stream: bool,
145
+ output_format: str,
111
146
  ) -> None:
112
-
113
- insecure, root_certificates_bytes = validate_certificate_in_federation_config(
114
- app, federation_config
115
- )
116
- channel = create_channel(
117
- server_address=federation_config["address"],
118
- insecure=insecure,
119
- root_certificates=root_certificates_bytes,
120
- max_message_length=GRPC_MAX_MESSAGE_LENGTH,
121
- interceptors=None,
122
- )
123
- channel.subscribe(on_channel_state_change)
147
+ auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
148
+ channel = init_channel(app, federation_config, auth_plugin)
124
149
  stub = ExecStub(channel)
125
150
 
126
151
  fab_path, fab_hash = build(app)
127
152
  content = Path(fab_path).read_bytes()
153
+ fab_id, fab_version = get_fab_metadata(Path(fab_path))
128
154
 
129
155
  # Delete FAB file once the bytes is computed
130
156
  Path(fab_path).unlink()
@@ -142,7 +168,26 @@ def _run_with_exec_api(
142
168
  )
143
169
  res = stub.StartRun(req)
144
170
 
145
- typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
171
+ if res.HasField("run_id"):
172
+ typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
173
+ else:
174
+ typer.secho("❌ Failed to start run", fg=typer.colors.RED)
175
+ raise typer.Exit(code=1)
176
+
177
+ if output_format == CliOutputFormat.JSON:
178
+ run_output = json.dumps(
179
+ {
180
+ "success": res.HasField("run_id"),
181
+ "run-id": res.run_id if res.HasField("run_id") else None,
182
+ "fab-id": fab_id,
183
+ "fab-name": fab_id.rsplit("/", maxsplit=1)[-1],
184
+ "fab-version": fab_version,
185
+ "fab-hash": fab_hash[:8],
186
+ "fab-filename": fab_path,
187
+ }
188
+ )
189
+ restore_output()
190
+ Console().print_json(run_output)
146
191
 
147
192
  if stream:
148
193
  start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)
@@ -194,3 +239,15 @@ def _run_without_exec_api(
194
239
  check=True,
195
240
  text=True,
196
241
  )
242
+
243
+
244
+ def _print_json_error(msg: str, e: Union[typer.Exit, Exception]) -> None:
245
+ """Print error message as JSON."""
246
+ Console().print_json(
247
+ json.dumps(
248
+ {
249
+ "success": False,
250
+ "error-message": remove_emojis(str(msg) + "\n" + str(e)),
251
+ }
252
+ )
253
+ )