apherisfold-cli 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- apherisfold_cli/__init__.py +3 -0
- apherisfold_cli/commands/__init__.py +3 -0
- apherisfold_cli/commands/benchmark.py +310 -0
- apherisfold_cli/commands/fine_tune.py +592 -0
- apherisfold_cli/commands/jobs.py +1943 -0
- apherisfold_cli/commands/login.py +369 -0
- apherisfold_cli/commands/predict.py +510 -0
- apherisfold_cli/commands/weights.py +100 -0
- apherisfold_cli/commands/workflow.py +1373 -0
- apherisfold_cli/core/__init__.py +4 -0
- apherisfold_cli/core/auth.py +1274 -0
- apherisfold_cli/core/config.py +110 -0
- apherisfold_cli/core/context.py +19 -0
- apherisfold_cli/core/errors.py +38 -0
- apherisfold_cli/core/hub_models.py +321 -0
- apherisfold_cli/core/hub_requests.py +545 -0
- apherisfold_cli/core/io.py +53 -0
- apherisfold_cli/core/login_profiles.py +48 -0
- apherisfold_cli/core/model_schema.py +201 -0
- apherisfold_cli/core/predict_inputs.py +164 -0
- apherisfold_cli/core/render.py +25 -0
- apherisfold_cli/core/result_artifacts.py +67 -0
- apherisfold_cli/core/workflow_catalog.py +364 -0
- apherisfold_cli/handlers.py +82 -0
- apherisfold_cli/main.py +119 -0
- apherisfold_cli/version.py +1 -0
- apherisfold_cli-0.1.6.dist-info/METADATA +313 -0
- apherisfold_cli-0.1.6.dist-info/RECORD +32 -0
- apherisfold_cli-0.1.6.dist-info/WHEEL +5 -0
- apherisfold_cli-0.1.6.dist-info/entry_points.txt +2 -0
- apherisfold_cli-0.1.6.dist-info/licenses/LICENSE +21 -0
- apherisfold_cli-0.1.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Annotated, Any
|
|
5
|
+
|
|
6
|
+
import typer
|
|
7
|
+
from rich import box
|
|
8
|
+
from rich.console import Group
|
|
9
|
+
from rich.panel import Panel
|
|
10
|
+
from rich.table import Table
|
|
11
|
+
from rich.text import Text
|
|
12
|
+
|
|
13
|
+
from apherisfold_cli.commands.predict import _resolve_prediction_capability, load_model_params_schema
|
|
14
|
+
from apherisfold_cli.core.config import load_config, resolve_command_config
|
|
15
|
+
from apherisfold_cli.core.errors import EXIT_USAGE, CLIError
|
|
16
|
+
from apherisfold_cli.core.hub_models import resolve_weight
|
|
17
|
+
from apherisfold_cli.core.hub_requests import (
|
|
18
|
+
create_benchmark,
|
|
19
|
+
start_benchmark,
|
|
20
|
+
upload_benchmark_structure,
|
|
21
|
+
upload_benchmark_structure_msa,
|
|
22
|
+
)
|
|
23
|
+
from apherisfold_cli.core.io import ensure_output_dir, write_result, write_status
|
|
24
|
+
from apherisfold_cli.core.model_schema import validate_schema_object
|
|
25
|
+
from apherisfold_cli.core.predict_inputs import load_json_object
|
|
26
|
+
from apherisfold_cli.handlers import get_ctx, run_simple_listing
|
|
27
|
+
|
|
28
|
+
# File suffixes accepted as benchmark structure inputs; callers compare them
# against ``Path.suffix.lower()``, so entries must stay lower-case.
_STRUCTURE_SUFFIXES = {".mmcif", ".cif"}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def benchmark(
    output_dir: Annotated[
        Path,
        typer.Option(
            "--output",
            "--output-dir",
            "-o",
            dir_okay=True,
            file_okay=False,
            help="Directory for workflow artifacts and downloaded outputs.",
        ),
    ],
    name: Annotated[str, typer.Option("--name", "-n", help="Human-friendly name for this benchmark run.")],
    model: Annotated[str, typer.Option("--model", "-m", help="Installed model identifier, for example 'openfold3'.")],
    weight: Annotated[
        str,
        typer.Option("--weight", "-w", help="Weight version shown in 'apherisfold weights list', for example '3.0.0'."),
    ],
    input_paths: Annotated[
        list[Path] | None,
        typer.Option("--input", help="Repeat for CIF files or pass a directory containing CIF files."),
    ] = None,
    model_version: Annotated[
        str | None,
        typer.Option(
            "--model-version",
            help="Installed model version/build identifier when selection is ambiguous, for example 'openfold3:1.2.3'.",
        ),
    ] = None,
    model_params: Annotated[
        str | None, typer.Option("--model-params", help="Hub-style modelParams JSON object, inline or as @file.json.")
    ] = None,
    model_param_overrides: dict[str, Any] | None = None,
) -> None:
    """Create a benchmark on the hub, upload its structures/MSAs, and start it.

    The CLI values are resolved through ``resolve_command_config`` together with
    the loaded CLI config (``benchmark`` command section) and built-in defaults.
    Each ``--input`` may be a CIF/mmCIF file or a directory of them; a sibling
    ``<stem>.a3m`` next to a structure is uploaded as its paired MSA. Every
    structure is uploaded under its file stem as the query id.

    When the context is in dry-run mode (``ctx.dry_run``) nothing is sent to the
    hub: the resolved request is written as the result and rendered instead.

    Args:
        output_dir: Directory for ``result.json`` and other artifacts.
        name: Human-friendly benchmark name.
        model: Installed model identifier.
        weight: Weight version to benchmark.
        input_paths: Structure files and/or directories of structure files.
        model_version: Disambiguating model version/build identifier.
        model_params: ``modelParams`` JSON object, inline or ``@file.json``.
        model_param_overrides: Extra model parameters applied on top of the
            resolved ``model_params``. It carries no Typer option annotation, so
            it is presumably supplied by a wrapping caller — verify.

    Raises:
        CLIError: With ``EXIT_USAGE`` when no input is given or an input path is
            invalid. Errors from weight resolution, schema validation and hub
            requests propagate from the respective helpers.
    """
    selected_weight = resolve_weight(model, model_version, weight, task="benchmark")
    # Prefer the plain "inference" capability; otherwise defer to the predict
    # command's resolver to choose among the capabilities this weight offers.
    selected_capability = "inference"
    if selected_capability not in selected_weight["capabilities"]:
        selected_capability = _resolve_prediction_capability(available_capabilities=selected_weight["capabilities"])
    model_params_schema = load_model_params_schema(
        model_version_id=str(selected_weight["model_version"]),
        weight_version=str(selected_weight["weight"]),
        capability=selected_capability,
        task="benchmark",
    )

    ctx = get_ctx()
    loaded_config = load_config(ctx.config_raw)
    resolved = resolve_command_config(
        command_name="benchmark",
        defaults={
            "name": name,
            "model": model,
            "weight": str(selected_weight["weight"]),
            "input": [],
            "model_params": {},
        },
        loaded_config=loaded_config,
        cli_overrides={
            "name": name,
            "model": model,
            "weight": str(selected_weight["weight"]),
            "input": input_paths,
            "model_params": load_json_object(model_params, option_name="--model-params", task="benchmark")
            if model_params is not None
            else None,
        },
        allowed_keys={"name", "model", "weight", "input", "model_params"},
    )
    raw_inputs = resolved.get("input")
    if not isinstance(raw_inputs, list) or not raw_inputs:
        raise CLIError(
            code=EXIT_USAGE,
            category="usage",
            message="Benchmark requires at least one '--input' path.",
            task="benchmark",
        )
    normalized_inputs = [Path(value) for value in raw_inputs]
    resolved_inputs = _resolve_benchmark_inputs(normalized_inputs)

    ctx.output = output_dir
    # Ask the error handler to persist failures into the output directory.
    ctx.extra["write_errors"] = True
    ensure_output_dir(output_dir, overwrite=ctx.overwrite)

    # Explicit overrides win over config/CLI model parameters.
    resolved_model_params = dict(resolved.get("model_params") or {})
    resolved_model_params.update(model_param_overrides or {})
    validate_schema_object(
        resolved_model_params,
        model_params_schema,
        task="benchmark",
        label="model parameter",
    )

    result_path = (ctx.output / "result.json").expanduser().resolve()
    input_listing = _benchmark_input_listing(resolved_inputs)
    summary = _benchmark_summary(resolved_inputs)

    if ctx.dry_run:
        result = {
            "schema_version": "1.0",
            "kind": "benchmark",
            "status": "dry-run",
            "request": {
                "name": resolved["name"],
                "model": resolved["model"],
                "weight": resolved["weight"],
                "input": [str(path) for path in normalized_inputs],
                **input_listing,
                "model_params": resolved_model_params,
            },
            "summary": summary,
            "result_path": str(result_path),
        }
        write_status(ctx, task="benchmark", phase="validation", state="complete")
        write_result(ctx, result)
        run_simple_listing(
            task_name="benchmark",
            payload=result,
            renderable_output=_benchmark_renderable("Benchmark Dry Run", result),
        )
        return

    write_status(ctx, task="benchmark", phase="creating-benchmark", state="running")
    created = create_benchmark(
        name=str(resolved["name"]),
        model_version_id=str(selected_weight["model_version"]),
        weight_version=str(selected_weight["weight"]),
        model_params=resolved_model_params,
    )
    benchmark_id = str(created["id"])

    # Structures are keyed on the hub by their file stem (the query id).
    for structure_path in resolved_inputs["structures"]:
        query_id = structure_path.stem
        write_status(ctx, task="benchmark", phase=f"uploading-structure:{query_id}", state="running")
        upload_benchmark_structure(
            benchmark_id=benchmark_id,
            query_id=query_id,
            cif_path=structure_path,
        )

    for structure_path, msa_path in resolved_inputs["paired_msas"]:
        query_id = structure_path.stem
        write_status(ctx, task="benchmark", phase=f"uploading-msa:{query_id}", state="running")
        upload_benchmark_structure_msa(
            benchmark_id=benchmark_id,
            query_id=query_id,
            msa_path=msa_path,
        )

    write_status(ctx, task="benchmark", phase="starting", state="running")
    start_benchmark(benchmark_id)

    write_status(ctx, task="benchmark", phase="submitted", state="complete")
    result = {
        "schema_version": "1.0",
        "kind": "benchmark",
        "status": "submitted",
        "job_id": benchmark_id,
        "request": created,
        "submission": {
            "name": resolved["name"],
            "model": resolved["model"],
            "weight": resolved["weight"],
            **input_listing,
            "model_params": resolved_model_params,
        },
        "summary": summary,
        "result_path": str(result_path),
    }
    write_result(ctx, result)
    next_cmd = f"apherisfold workflow benchmark get --id {benchmark_id}"
    run_simple_listing(
        task_name="benchmark",
        payload=result,
        renderable_output=_benchmark_renderable("Benchmark Submitted", result, next_cmd=next_cmd),
    )


def _benchmark_input_listing(resolved_inputs: dict[str, list]) -> dict[str, list]:
    """Return the JSON-ready ``structures`` and ``paired_msas`` entries of a benchmark result.

    Shared by the dry-run ``request`` and the submitted ``submission`` payloads so
    the two result shapes stay identical.
    """
    return {
        "structures": [str(path) for path in resolved_inputs["structures"]],
        "paired_msas": [
            {"structure": str(structure), "msa": str(msa)} for structure, msa in resolved_inputs["paired_msas"]
        ],
    }


def _benchmark_summary(resolved_inputs: dict[str, list]) -> dict[str, int]:
    """Return the structure and paired-MSA counts reported in a benchmark result's ``summary``."""
    return {
        "queued_structures": len(resolved_inputs["structures"]),
        "paired_msa_count": len(resolved_inputs["paired_msas"]),
    }
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _resolve_benchmark_inputs(input_paths: list[Path]) -> dict[str, list]:
    """Expand benchmark ``--input`` paths into structure files and paired MSAs.

    Each path (``~`` is expanded) must be either a CIF/mmCIF file or a directory.
    For a directory, its direct children with a structure suffix are taken in
    sorted order; subdirectories are not searched. The same file given more than
    once is only queued once.

    Args:
        input_paths: Paths as supplied via ``--input`` or the config.

    Returns:
        ``{"structures": list[Path], "paired_msas": list[tuple[Path, Path]]}``,
        where each pair is ``(structure, sibling .a3m MSA)``.

    Raises:
        CLIError: With ``EXIT_USAGE`` if a path does not exist, is neither a
            regular file nor a directory, is a file without a CIF/mmCIF suffix,
            is a directory with no CIF files, or if two distinct structures
            share a file stem (their query ids would collide).
    """
    structures: list[Path] = []
    paired_msas: list[tuple[Path, Path]] = []
    seen_structures: set[Path] = set()

    for raw_path in input_paths:
        path = raw_path.expanduser()
        if not path.exists():
            raise CLIError(
                code=EXIT_USAGE,
                category="usage",
                message=f"Benchmark input path does not exist: {raw_path}",
                task="benchmark",
            )
        if path.is_dir():
            directory_structures = sorted(
                child for child in path.iterdir() if child.is_file() and child.suffix.lower() in _STRUCTURE_SUFFIXES
            )
            if not directory_structures:
                raise CLIError(
                    code=EXIT_USAGE,
                    category="usage",
                    message=f"Benchmark input directory does not contain any CIF files: {raw_path}",
                    task="benchmark",
                )
            for structure_path in directory_structures:
                _append_structure(structure_path, structures, paired_msas, seen_structures)
            continue
        if not path.is_file():
            raise CLIError(
                code=EXIT_USAGE,
                category="usage",
                message=f"Benchmark input is not a regular file or directory: {raw_path}",
                task="benchmark",
            )
        if path.suffix.lower() not in _STRUCTURE_SUFFIXES:
            raise CLIError(
                code=EXIT_USAGE,
                category="usage",
                message=f"Benchmark input files must be CIF or mmCIF: {raw_path}",
                task="benchmark",
            )
        _append_structure(path, structures, paired_msas, seen_structures)

    # The benchmark command uploads every structure (and its MSA) under the file
    # stem as query id. Two distinct files with the same stem (e.g. "a/x.cif" and
    # "b/x.cif", or "x.cif" and "x.mmcif") would collide on the hub, and only
    # after the benchmark had already been created. Reject them up front.
    stem_owners: dict[str, Path] = {}
    for structure_path in structures:
        query_id = structure_path.stem
        if query_id in stem_owners:
            raise CLIError(
                code=EXIT_USAGE,
                category="usage",
                message=(
                    f"Benchmark inputs must have unique file names; query id '{query_id}' is used by "
                    f"both {stem_owners[query_id]} and {structure_path}"
                ),
                task="benchmark",
            )
        stem_owners[query_id] = structure_path

    return {"structures": structures, "paired_msas": paired_msas}
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _append_structure(
    structure_path: Path,
    structures: list[Path],
    paired_msas: list[tuple[Path, Path]],
    seen_structures: set[Path],
) -> None:
    """Queue *structure_path* once and pair it with a sibling ``.a3m`` MSA if present.

    Duplicates are detected on ``Path.resolve()``, so the same file reached via
    different spellings (relative vs. absolute, through symlinks) is queued only
    once. The path appended to *structures* is the caller's, not the resolved one.

    Args:
        structure_path: Structure file to queue.
        structures: Accumulator of queued structure paths (mutated in place).
        paired_msas: Accumulator of ``(structure, msa)`` pairs (mutated in place).
        seen_structures: Resolved paths already queued (mutated in place).
    """
    resolved_path = structure_path.resolve()
    if resolved_path in seen_structures:
        return
    seen_structures.add(resolved_path)
    structures.append(structure_path)

    # Only a regular file can be uploaded as an MSA; a directory that happens to be
    # named "<stem>.a3m" must not be paired with the structure.
    candidate_msa = structure_path.with_suffix(".a3m")
    if candidate_msa.is_file():
        paired_msas.append((structure_path, candidate_msa))
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _benchmark_renderable(title: str, payload: dict[str, Any], *, next_cmd: str | None = None) -> Group:
    """Build the Rich renderable shown for a benchmark dry-run or submission.

    The main panel is a FIELD/VALUE table containing, in order: the job id (if
    truthy), the status, the name/model/weight taken from ``submission`` (or
    ``request`` as a fallback) when truthy, the structure and paired-MSA counts
    from ``summary`` when not ``None``, and the result path (if truthy). When
    *next_cmd* is given, a second "Next" panel displays that command.
    """
    details = payload.get("submission") or payload.get("request") or {}
    counts = payload.get("summary") or {}

    rows: list[tuple[str, str]] = []
    job_id = payload.get("job_id")
    if job_id:
        rows.append(("JOB ID", str(job_id)))
    rows.append(("STATUS", str(payload.get("status", ""))))

    if isinstance(details, dict):
        for key, label in (("name", "NAME"), ("model", "MODEL"), ("weight", "WEIGHT")):
            value = details.get(key)
            if value:
                rows.append((label, str(value)))

    if isinstance(counts, dict):
        for key, label in (("queued_structures", "STRUCTURES"), ("paired_msa_count", "PAIRED MSAs")):
            value = counts.get(key)
            if value is not None:
                rows.append((label, str(value)))

    result_location = payload.get("result_path")
    if result_location:
        rows.append(("RESULT", str(result_location)))

    field_table = Table(box=box.SIMPLE_HEAD, show_header=True, padding=(0, 1), expand=False)
    field_table.add_column("FIELD", style="bold", no_wrap=True)
    field_table.add_column("VALUE", overflow="fold")
    for label, value in rows:
        field_table.add_row(label, value)

    panels: list = [Panel(field_table, title=title, border_style="dim")]
    if next_cmd:
        follow_up = Table(box=box.SIMPLE_HEAD, show_header=False, padding=(0, 1), expand=False)
        follow_up.add_column("command", style="cyan", overflow="fold")
        follow_up.add_row(Text(next_cmd, style="cyan"))
        panels.append(Panel(follow_up, title="Next", border_style="dim"))
    return Group(*panels)
|