apherisfold-cli 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ from .version import __version__
2
+
3
+ __all__ = ["__version__"]  # public names of this package: only the version string imported above
@@ -0,0 +1,3 @@
1
+ from . import benchmark, fine_tune, jobs, login, predict, weights, workflow
2
+
3
+ __all__ = ["benchmark", "fine_tune", "jobs", "login", "predict", "weights", "workflow"]  # public names: the command submodules imported above
@@ -0,0 +1,310 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Annotated, Any
5
+
6
+ import typer
7
+ from rich import box
8
+ from rich.console import Group
9
+ from rich.panel import Panel
10
+ from rich.table import Table
11
+ from rich.text import Text
12
+
13
+ from apherisfold_cli.commands.predict import _resolve_prediction_capability, load_model_params_schema
14
+ from apherisfold_cli.core.config import load_config, resolve_command_config
15
+ from apherisfold_cli.core.errors import EXIT_USAGE, CLIError
16
+ from apherisfold_cli.core.hub_models import resolve_weight
17
+ from apherisfold_cli.core.hub_requests import (
18
+ create_benchmark,
19
+ start_benchmark,
20
+ upload_benchmark_structure,
21
+ upload_benchmark_structure_msa,
22
+ )
23
+ from apherisfold_cli.core.io import ensure_output_dir, write_result, write_status
24
+ from apherisfold_cli.core.model_schema import validate_schema_object
25
+ from apherisfold_cli.core.predict_inputs import load_json_object
26
+ from apherisfold_cli.handlers import get_ctx, run_simple_listing
27
+
28
+ _STRUCTURE_SUFFIXES = {".cif", ".mmcif"}  # accepted structure-file suffixes; compared against the lower-cased suffix of each input path
29
+
30
+
31
def benchmark(
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output",
            "--output-dir",
            "-o",
            dir_okay=True,
            file_okay=False,
            help="Directory for workflow artifacts and downloaded outputs.",
        ),
    ],
    name: Annotated[str, typer.Option("--name", "-n", help="Human-friendly name for this benchmark run.")],
    model: Annotated[str, typer.Option("--model", "-m", help="Installed model identifier, for example 'openfold3'.")],
    weight: Annotated[
        str,
        typer.Option("--weight", "-w", help="Weight version shown in 'apherisfold weights list', for example '3.0.0'."),
    ],
    input_paths: Annotated[
        list[Path] | None,
        typer.Option("--input", help="Repeat for CIF files or pass a directory containing CIF files."),
    ] = None,
    model_version: Annotated[
        str | None,
        typer.Option(
            "--model-version",
            help="Installed model version/build identifier when selection is ambiguous, for example 'openfold3:1.2.3'.",
        ),
    ] = None,
    model_params: Annotated[
        str | None, typer.Option("--model-params", help="Hub-style modelParams JSON object, inline or as @file.json.")
    ] = None,
    model_param_overrides: dict[str, Any] | None = None,
) -> None:
    # Documented with comments instead of a docstring on purpose: a docstring
    # on a Typer command function can be picked up as its help text.
    #
    # Flow:
    #   1. Resolve the weight and load the model-parameter schema for it.
    #   2. Merge config-file values with the CLI options.
    #   3. Validate the inputs and the merged model parameters.
    #   4. Dry run: write/render the planned request and stop.
    #      Otherwise: create the benchmark, upload every structure and every
    #      paired MSA, start the benchmark, then write/render the result.

    # --- 1. weight + schema -------------------------------------------------
    weight_info = resolve_weight(model, model_version, weight, task="benchmark")
    capabilities = weight_info["capabilities"]
    if "inference" in capabilities:
        capability = "inference"
    else:
        # No plain "inference" capability: let the predict helper choose one.
        capability = _resolve_prediction_capability(available_capabilities=capabilities)
    model_version_id = str(weight_info["model_version"])
    weight_version = str(weight_info["weight"])
    params_schema = load_model_params_schema(
        model_version_id=model_version_id,
        weight_version=weight_version,
        capability=capability,
        task="benchmark",
    )

    # --- 2. configuration merge --------------------------------------------
    loaded_config = load_config(get_ctx().config_raw)
    cli_model_params = None
    if model_params is not None:
        cli_model_params = load_json_object(model_params, option_name="--model-params", task="benchmark")
    settings = resolve_command_config(
        command_name="benchmark",
        defaults={
            "name": name,
            "model": model,
            "weight": weight_version,
            "input": [],
            "model_params": {},
        },
        loaded_config=loaded_config,
        cli_overrides={
            "name": name,
            "model": model,
            "weight": weight_version,
            "input": input_paths,
            "model_params": cli_model_params,
        },
        allowed_keys={"name", "model", "weight", "input", "model_params"},
    )

    # --- 3. validation ------------------------------------------------------
    raw_inputs = settings.get("input")
    if not (isinstance(raw_inputs, list) and raw_inputs):
        raise CLIError(
            code=EXIT_USAGE,
            category="usage",
            message="Benchmark requires at least one '--input' path.",
            task="benchmark",
        )
    normalized_inputs = [Path(entry) for entry in raw_inputs]
    resolved_inputs = _resolve_benchmark_inputs(normalized_inputs)

    ctx = get_ctx()
    ctx.output = output_dir
    ctx.extra["write_errors"] = True
    ensure_output_dir(output_dir, overwrite=ctx.overwrite)

    # Explicit overrides are applied on top of the config/CLI JSON parameters.
    effective_params = dict(settings.get("model_params") or {})
    effective_params.update(model_param_overrides or {})
    validate_schema_object(
        effective_params,
        params_schema,
        task="benchmark",
        label="model parameter",
    )

    result_path = (ctx.output / "result.json").expanduser().resolve()

    # Fragments that appear in both the dry-run and the submitted payload.
    structure_paths = resolved_inputs["structures"]
    msa_pairs = resolved_inputs["paired_msas"]
    structure_strings = [str(structure) for structure in structure_paths]
    msa_records = [{"structure": str(structure), "msa": str(msa)} for structure, msa in msa_pairs]
    summary = {
        "queued_structures": len(structure_paths),
        "paired_msa_count": len(msa_pairs),
    }

    # --- 4a. dry run --------------------------------------------------------
    if ctx.dry_run:
        dry_run_result = {
            "schema_version": "1.0",
            "kind": "benchmark",
            "status": "dry-run",
            "request": {
                "name": settings["name"],
                "model": settings["model"],
                "weight": settings["weight"],
                "input": [str(entry) for entry in normalized_inputs],
                "structures": structure_strings,
                "paired_msas": msa_records,
                "model_params": effective_params,
            },
            "summary": summary,
            "result_path": str(result_path),
        }
        write_status(ctx, task="benchmark", phase="validation", state="complete")
        write_result(ctx, dry_run_result)
        run_simple_listing(
            task_name="benchmark",
            payload=dry_run_result,
            renderable_output=_benchmark_renderable("Benchmark Dry Run", dry_run_result),
        )
        return

    # --- 4b. real submission -----------------------------------------------
    write_status(ctx, task="benchmark", phase="creating-benchmark", state="running")
    created = create_benchmark(
        name=str(settings["name"]),
        model_version_id=model_version_id,
        weight_version=weight_version,
        model_params=effective_params,
    )
    benchmark_id = str(created["id"])

    # Each upload is keyed by the structure file's stem.
    for cif_file in structure_paths:
        query_id = cif_file.stem
        write_status(ctx, task="benchmark", phase=f"uploading-structure:{query_id}", state="running")
        upload_benchmark_structure(
            benchmark_id=benchmark_id,
            query_id=query_id,
            cif_path=cif_file,
        )

    for cif_file, msa_file in msa_pairs:
        query_id = cif_file.stem
        write_status(ctx, task="benchmark", phase=f"uploading-msa:{query_id}", state="running")
        upload_benchmark_structure_msa(
            benchmark_id=benchmark_id,
            query_id=query_id,
            msa_path=msa_file,
        )

    write_status(ctx, task="benchmark", phase="starting", state="running")
    start_benchmark(benchmark_id)

    write_status(ctx, task="benchmark", phase="submitted", state="complete")
    submitted_result = {
        "schema_version": "1.0",
        "kind": "benchmark",
        "status": "submitted",
        "job_id": benchmark_id,
        "request": created,
        "submission": {
            "name": settings["name"],
            "model": settings["model"],
            "weight": settings["weight"],
            "structures": structure_strings,
            "paired_msas": msa_records,
            "model_params": effective_params,
        },
        "summary": summary,
        "result_path": str(result_path),
    }
    write_result(ctx, submitted_result)
    follow_up = f"apherisfold workflow benchmark get --id {benchmark_id}"
    run_simple_listing(
        task_name="benchmark",
        payload=submitted_result,
        renderable_output=_benchmark_renderable("Benchmark Submitted", submitted_result, next_cmd=follow_up),
    )
215
+
216
+
217
def _resolve_benchmark_inputs(input_paths: list[Path]) -> dict[str, list]:
    """Expand the ``--input`` paths into structure files and paired MSAs.

    Each input may be a CIF/mmCIF file or a directory. Directories are
    scanned one level deep (not recursively) and their structure files are
    added in sorted order. Collection itself is delegated to
    ``_append_structure``.

    Args:
        input_paths: Paths as given by the user; ``~`` is expanded here.

    Returns:
        ``{"structures": [...], "paired_msas": [(structure, msa), ...]}``.

    Raises:
        CLIError: (usage) when a path does not exist, is neither a regular
            file nor a directory, has an unsupported suffix, is a directory
            without structure files, or when two collected structures share
            the same file stem.
    """
    structures: list[Path] = []
    paired_msas: list[tuple[Path, Path]] = []
    seen_structures: set[Path] = set()

    for raw_path in input_paths:
        path = raw_path.expanduser()
        if not path.exists():
            raise CLIError(
                code=EXIT_USAGE,
                category="usage",
                message=f"Benchmark input path does not exist: {raw_path}",
                task="benchmark",
            )
        if path.is_dir():
            directory_structures = sorted(
                child for child in path.iterdir() if child.is_file() and child.suffix.lower() in _STRUCTURE_SUFFIXES
            )
            if not directory_structures:
                raise CLIError(
                    code=EXIT_USAGE,
                    category="usage",
                    message=f"Benchmark input directory does not contain any CIF files: {raw_path}",
                    task="benchmark",
                )
            for structure_path in directory_structures:
                _append_structure(structure_path, structures, paired_msas, seen_structures)
            continue
        if not path.is_file():
            raise CLIError(
                code=EXIT_USAGE,
                category="usage",
                message=f"Benchmark input is not a regular file or directory: {raw_path}",
                task="benchmark",
            )
        if path.suffix.lower() not in _STRUCTURE_SUFFIXES:
            raise CLIError(
                code=EXIT_USAGE,
                category="usage",
                message=f"Benchmark input files must be CIF or mmCIF: {raw_path}",
                task="benchmark",
            )
        _append_structure(path, structures, paired_msas, seen_structures)

    # The benchmark command uses each structure's file stem as its upload
    # query ID, so two different files with the same stem (for example
    # "a/x.cif" and "b/x.cif", or "x.cif" and "x.mmcif") would collide.
    # Reject that here, during validation, instead of part-way through upload.
    first_by_stem: dict[str, Path] = {}
    for structure_path in structures:
        earlier = first_by_stem.setdefault(structure_path.stem, structure_path)
        if earlier is not structure_path:
            raise CLIError(
                code=EXIT_USAGE,
                category="usage",
                message=(
                    "Benchmark input files must have unique names (ignoring the extension); "
                    f"'{structure_path.stem}' is shared by {earlier} and {structure_path}"
                ),
                task="benchmark",
            )

    return {"structures": structures, "paired_msas": paired_msas}
262
+
263
+
264
+ def _append_structure(
265
+ structure_path: Path,
266
+ structures: list[Path],
267
+ paired_msas: list[tuple[Path, Path]],
268
+ seen_structures: set[Path],
269
+ ) -> None:
270
+ resolved_path = structure_path.resolve()
271
+ if resolved_path in seen_structures:
272
+ return
273
+ seen_structures.add(resolved_path)
274
+ structures.append(structure_path)
275
+
276
+ candidate_msa = structure_path.with_suffix(".a3m")
277
+ if candidate_msa.exists():
278
+ paired_msas.append((structure_path, candidate_msa))
279
+
280
+
281
def _benchmark_renderable(title: str, payload: dict[str, Any], *, next_cmd: str | None = None) -> Group:
    """Build the console view for a benchmark result payload.

    Produces a FIELD/VALUE table wrapped in a panel titled ``title`` and,
    when ``next_cmd`` is given, a second "Next" panel showing that command.

    Args:
        title: Title of the main panel.
        payload: Result dictionary. Request details are read from
            ``"submission"``, falling back to ``"request"``; counts are read
            from ``"summary"``.
        next_cmd: Optional follow-up command to display.

    Returns:
        A ``Group`` holding one or two panels.
    """
    details = payload.get("submission") or payload.get("request") or {}
    counts = payload.get("summary") or {}

    rows: list[tuple[str, str]] = []

    job_id = payload.get("job_id")
    if job_id:
        rows.append(("JOB ID", str(job_id)))
    rows.append(("STATUS", str(payload.get("status", ""))))

    # Request details are shown only when truthy.
    if isinstance(details, dict):
        for label, key in (("NAME", "name"), ("MODEL", "model"), ("WEIGHT", "weight")):
            value = details.get(key)
            if value:
                rows.append((label, str(value)))

    # Counts are shown whenever present, so a count of 0 is still displayed.
    if isinstance(counts, dict):
        for label, key in (("STRUCTURES", "queued_structures"), ("PAIRED MSAs", "paired_msa_count")):
            value = counts.get(key)
            if value is not None:
                rows.append((label, str(value)))

    result_location = payload.get("result_path")
    if result_location:
        rows.append(("RESULT", str(result_location)))

    table = Table(box=box.SIMPLE_HEAD, show_header=True, padding=(0, 1), expand=False)
    table.add_column("FIELD", style="bold", no_wrap=True)
    table.add_column("VALUE", overflow="fold")
    for field, value in rows:
        table.add_row(field, value)

    panels: list = [Panel(table, title=title, border_style="dim")]
    if not next_cmd:
        return Group(*panels)

    next_table = Table(box=box.SIMPLE_HEAD, show_header=False, padding=(0, 1), expand=False)
    next_table.add_column("command", style="cyan", overflow="fold")
    next_table.add_row(Text(next_cmd, style="cyan"))
    panels.append(Panel(next_table, title="Next", border_style="dim"))
    return Group(*panels)