ezscreen 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezscreen/__init__.py +1 -0
- ezscreen/admet/__init__.py +0 -0
- ezscreen/admet/filter.py +175 -0
- ezscreen/auth.py +238 -0
- ezscreen/backends/__init__.py +0 -0
- ezscreen/backends/kaggle/__init__.py +0 -0
- ezscreen/backends/kaggle/dataset.py +124 -0
- ezscreen/backends/kaggle/kernel.py +91 -0
- ezscreen/backends/kaggle/poller.py +126 -0
- ezscreen/backends/kaggle/runner.py +265 -0
- ezscreen/backends/kaggle/templates/vina_shard.ipynb.j2 +395 -0
- ezscreen/checkpoint.py +187 -0
- ezscreen/cli.py +173 -0
- ezscreen/commands/admet.py +63 -0
- ezscreen/commands/auth.py +8 -0
- ezscreen/commands/run.py +563 -0
- ezscreen/commands/status.py +97 -0
- ezscreen/commands/validate.py +67 -0
- ezscreen/commands/view.py +119 -0
- ezscreen/config.py +77 -0
- ezscreen/errors.py +133 -0
- ezscreen/pocket/__init__.py +0 -0
- ezscreen/pocket/detect.py +226 -0
- ezscreen/prep/__init__.py +0 -0
- ezscreen/prep/ligands.py +248 -0
- ezscreen/prep/receptor.py +305 -0
- ezscreen/report.py +179 -0
- ezscreen/results/__init__.py +0 -0
- ezscreen/results/merger.py +103 -0
- ezscreen/state.py +9 -0
- ezscreen/vendor/__init__.py +1 -0
- ezscreen/vendor/scrubber/__init__.py +38 -0
- ezscreen/version_check.py +77 -0
- ezscreen-0.1.0.dist-info/METADATA +121 -0
- ezscreen-0.1.0.dist-info/RECORD +37 -0
- ezscreen-0.1.0.dist-info/WHEEL +4 -0
- ezscreen-0.1.0.dist-info/entry_points.txt +2 -0
ezscreen/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Must match the distribution version recorded in the package metadata
# (ezscreen-0.1.0.dist-info); keep the two in sync when releasing.
__version__ = "0.1.0"
|
|
File without changes
|
ezscreen/admet/filter.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
# Caveat attached to every ADMET summary produced by filter_library: the v1
# filters are simple physicochemical / substructure rules, not models.
V1_DISCLAIMER = "".join(
    (
        "v1 ADMET is rule-based only — not predictive. ",
        "Results reflect simple physicochemical filters, not biological activity.",
    )
)
|
|
10
|
+
|
|
11
|
+
# ---------------------------------------------------------------------------
|
|
12
|
+
# Filter definitions
|
|
13
|
+
# ---------------------------------------------------------------------------
|
|
14
|
+
|
|
15
|
+
@dataclass
class FilterConfig:
    """Toggles for the rule-based ADMET filters run by ``filter_mol``.

    Each flag enables one family of checks; ``filter_mol`` marks a molecule
    as failed when any enabled check reports a violation.
    """

    # Lipinski Rule of Five (MW, LogP, HBD, HBA).
    lipinski: bool = field(default=True)
    # PAINS substructure alerts.
    pains: bool = field(default=True)
    # Basic toxicophore patterns (RDKit BRENK catalog).
    toxicophores: bool = field(default=True)
    # Veber oral-bioavailability rules (TPSA, rotatable bonds).
    veber: bool = field(default=True)
    # Egan BBB window; off by default because most VS targets aren't CNS.
    egan_bbb: bool = field(default=False)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class FilterResult:
    """Outcome of running the enabled ADMET filters on one molecule."""

    # True when the molecule passed every enabled filter.
    passed: bool
    # Human-readable violation messages, one per reported failure.
    failures: list[str] = field(default_factory=list)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# ---------------------------------------------------------------------------
|
|
31
|
+
# Individual filters
|
|
32
|
+
# ---------------------------------------------------------------------------
|
|
33
|
+
|
|
34
|
+
def _check_lipinski(mol) -> list[str]:
    """Apply the Lipinski Rule of Five: MW <= 500, LogP <= 5, HBD <= 5, HBA <= 10.

    Every violated criterion is reported separately.
    NOTE: any single violation makes filter_mol reject the molecule, which is
    stricter than the classic Ro5 (that tolerates one violation).

    Returns:
        One message per violated criterion; an empty list when all pass.
    """
    from rdkit.Chem.Descriptors import MolLogP, MolWt
    from rdkit.Chem.rdMolDescriptors import CalcNumHBA, CalcNumHBD

    failures: list[str] = []
    mw = MolWt(mol)
    lp = MolLogP(mol)
    hbd = CalcNumHBD(mol)
    hba = CalcNumHBA(mol)
    if mw > 500:
        failures.append(f"MW {mw:.1f} > 500")
    if lp > 5:
        failures.append(f"LogP {lp:.2f} > 5")
    if hbd > 5:
        failures.append(f"HBD {hbd} > 5")
    if hba > 10:
        failures.append(f"HBA {hba} > 10")
    return failures
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Lazily-built PAINS FilterCatalog shared by every _check_pains call.
_PAINS_CATALOG = None


def _pains_catalog():
    """Return the RDKit PAINS FilterCatalog, building it on first use.

    Constructing a FilterCatalog loads the whole PAINS pattern set, so doing
    it once per process (instead of once per molecule) avoids repeating that
    work for every molecule in a library.
    """
    global _PAINS_CATALOG
    if _PAINS_CATALOG is None:
        from rdkit.Chem import FilterCatalog

        params = FilterCatalog.FilterCatalogParams()
        params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS)
        _PAINS_CATALOG = FilterCatalog.FilterCatalog(params)
    return _PAINS_CATALOG


def _check_pains(mol) -> list[str]:
    """Report the first PAINS alert matched by *mol*.

    Returns:
        A single-element list describing the first matching alert, or an
        empty list when no PAINS pattern matches.
    """
    entry = _pains_catalog().GetFirstMatch(mol)
    if entry:
        return [f"PAINS alert: {entry.GetDescription()}"]
    return []
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Lazily-built BRENK FilterCatalog shared by every _check_toxicophores call.
_BRENK_CATALOG = None


def _brenk_catalog():
    """Return the RDKit BRENK FilterCatalog, building it on first use.

    Constructing a FilterCatalog loads the whole BRENK pattern set, so it is
    built once per process rather than once per molecule.
    """
    global _BRENK_CATALOG
    if _BRENK_CATALOG is None:
        from rdkit.Chem import FilterCatalog

        params = FilterCatalog.FilterCatalogParams()
        params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.BRENK)
        _BRENK_CATALOG = FilterCatalog.FilterCatalog(params)
    return _BRENK_CATALOG


def _check_toxicophores(mol) -> list[str]:
    """Report the first toxicophore (BRENK catalog) pattern matched by *mol*.

    Returns:
        A single-element list describing the first match, or an empty list
        when no BRENK pattern matches.
    """
    entry = _brenk_catalog().GetFirstMatch(mol)
    if entry:
        return [f"Toxicophore: {entry.GetDescription()}"]
    return []
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _check_veber(mol) -> list[str]:
    """Apply Veber's oral-bioavailability rules: TPSA <= 140 Ų, rotatable bonds <= 10.

    Returns:
        One message per violated rule; an empty list when both hold.
    """
    from rdkit.Chem.rdMolDescriptors import CalcNumRotatableBonds, CalcTPSA

    polar_area = CalcTPSA(mol)
    rotatable = CalcNumRotatableBonds(mol)
    problems: list[str] = []
    if polar_area > 140:
        problems.append(f"TPSA {polar_area:.1f} > 140 Ų")
    if rotatable > 10:
        problems.append(f"RotBonds {rotatable} > 10")
    return problems
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _check_egan_bbb(mol) -> list[str]:
    """Check the Egan BBB window: LogP within [-1, 6] and TPSA within [0, 131].

    Returns:
        One message per out-of-range descriptor; an empty list when both fit.
    """
    from rdkit.Chem.Descriptors import MolLogP
    from rdkit.Chem.rdMolDescriptors import CalcTPSA

    logp = MolLogP(mol)
    psa = CalcTPSA(mol)
    out_of_range: list[str] = []
    if not -1 <= logp <= 6:
        out_of_range.append(f"Egan BBB: LogP {logp:.2f} out of [-1, 6]")
    if not 0 <= psa <= 131:
        out_of_range.append(f"Egan BBB: TPSA {psa:.1f} out of [0, 131]")
    return out_of_range
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# ---------------------------------------------------------------------------
|
|
93
|
+
# Main filter function
|
|
94
|
+
# ---------------------------------------------------------------------------
|
|
95
|
+
|
|
96
|
+
def filter_mol(mol, cfg: FilterConfig) -> FilterResult:
    """Run every filter enabled in *cfg* on *mol* and collect the violations.

    Checks run in a fixed order (Lipinski, PAINS, toxicophores, Veber,
    Egan BBB). All enabled checks run even after one fails, so ``failures``
    lists every reported violation; ``passed`` is True only when it is empty.
    """
    pipeline = (
        (cfg.lipinski, _check_lipinski),
        (cfg.pains, _check_pains),
        (cfg.toxicophores, _check_toxicophores),
        (cfg.veber, _check_veber),
        (cfg.egan_bbb, _check_egan_bbb),
    )
    collected: list[str] = [
        message
        for enabled, check in pipeline
        if enabled
        for message in check(mol)
    ]
    return FilterResult(passed=not collected, failures=collected)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _breakdown_key(failure: str) -> str | None:
    """Map one filter_mol failure message to its summary bucket (None if unknown).

    Classification uses the message prefix written by each _check_* helper.
    Substring matching would misfile messages: the Egan BBB messages contain
    "LogP"/"TPSA", and catalog descriptions are arbitrary text.
    """
    fl = failure.lower()
    if fl.startswith("egan bbb"):
        return "egan_bbb"
    if fl.startswith("pains"):
        return "pains_alerts"
    if fl.startswith(("toxicophore", "brenk")):
        return "toxicophores"
    if fl.startswith(("tpsa", "rotbond")):
        return "veber_violations"
    if fl.startswith(("mw", "logp", "hbd", "hba")):
        return "ro5_violations"
    return None


def filter_library(
    input_path: str,
    output_path: str,
    cfg: FilterConfig | None = None,
) -> dict[str, Any]:
    """
    Filter an SDF or SMILES file. Returns a summary dict for the prep report.

    Molecules that pass every enabled filter are written to *output_path* as
    SDF. For failing molecules, each reported failure increments its rule's
    counter in ``admet_breakdown``; a molecule violating several rules counts
    once per failure message. Records RDKit cannot parse are skipped and not
    counted in ``total_input``.

    Args:
        input_path: ``.smi``/``.smiles`` (tab/space delimited, no header) or
            any other suffix, which is read as SDF.
        output_path: destination SDF for passing molecules.
        cfg: filters to apply; defaults to ``FilterConfig()``.

    Returns:
        ``total_input``, ``admet_removed``, ``admet_breakdown`` and
        ``disclaimer`` keys.
    """
    from pathlib import Path as _Path
    from rdkit.Chem import SDMolSupplier, SDWriter, SmilesMolSupplier

    if cfg is None:
        cfg = FilterConfig()

    suffix = _Path(input_path).suffix.lower()
    if suffix in (".smi", ".smiles"):
        supplier = SmilesMolSupplier(str(input_path), delimiter="\t ", titleLine=False)
    else:
        supplier = SDMolSupplier(str(input_path), removeHs=False, sanitize=True)
    writer = SDWriter(str(output_path))

    total = passed = 0
    breakdown: dict[str, int] = {
        "ro5_violations": 0,
        "pains_alerts": 0,
        "toxicophores": 0,
        "veber_violations": 0,
        "egan_bbb": 0,
    }

    # try/finally so the output SDF is flushed and closed even if RDKit
    # raises part-way through the library.
    try:
        for mol in supplier:
            if mol is None:
                continue
            total += 1
            result = filter_mol(mol, cfg)
            if result.passed:
                writer.write(mol)
                passed += 1
                continue
            for failure in result.failures:
                key = _breakdown_key(failure)
                if key is not None:
                    breakdown[key] += 1
    finally:
        writer.close()

    return {
        "total_input": total,
        "admet_removed": total - passed,
        "admet_breakdown": breakdown,
        "disclaimer": V1_DISCLAIMER,
    }
|
ezscreen/auth.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import stat
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import questionary
|
|
10
|
+
import requests
|
|
11
|
+
import tomli_w
|
|
12
|
+
import tomllib
|
|
13
|
+
from rich.console import Console
|
|
14
|
+
from rich.panel import Panel
|
|
15
|
+
|
|
16
|
+
from ezscreen.errors import (
|
|
17
|
+
CredentialPermissionError,
|
|
18
|
+
KaggleAuthError,
|
|
19
|
+
NetworkTimeoutError,
|
|
20
|
+
NIMAuthError,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Per-user ezscreen directory; the credentials file lives here.
CREDS_DIR: Path = Path.home() / ".ezscreen"
# TOML file written by save_credentials() and read by load_credentials().
CREDS_PATH: Path = CREDS_DIR / "credentials"
# Endpoint that validate_nim_key() POSTs to in order to test a NIM API key.
NIM_HEALTH_ENDPOINT = "https://health.api.nvidia.com/v1/biology/mit/diffdock"

# Shared Rich console for all wizard output in this module.
console = Console()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# ---------------------------------------------------------------------------
|
|
31
|
+
# Credential I/O
|
|
32
|
+
# ---------------------------------------------------------------------------
|
|
33
|
+
|
|
34
|
+
def load_credentials() -> dict[str, Any]:
    """Return the saved credentials parsed from the TOML file at CREDS_PATH.

    Returns an empty dict when no credentials file exists yet.
    """
    if CREDS_PATH.exists():
        with CREDS_PATH.open("rb") as fh:
            return tomllib.load(fh)
    return {}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def save_credentials(creds: dict[str, Any]) -> None:
    """Write *creds* to CREDS_PATH as TOML, readable only by the owner.

    A new file is created with mode 0o600 from the start, rather than being
    chmod-ed after the write, so an API key is never briefly readable by
    other users. The trailing chmod also tightens a pre-existing file with
    looser permissions. It is best-effort because some platforms
    (e.g. Windows) don't honour POSIX modes.
    """
    CREDS_DIR.mkdir(parents=True, exist_ok=True)
    # O_BINARY (Windows only) prevents newline translation on the raw fd.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(CREDS_PATH, flags, 0o600)
    with os.fdopen(fd, "wb") as f:
        tomli_w.dump(creds, f)
    try:
        os.chmod(CREDS_PATH, 0o600)
    except OSError:
        pass
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_kaggle_json_path(creds: dict[str, Any] | None = None) -> Path | None:
    """Return the configured kaggle.json path (with ``~`` expanded), or None if unset.

    Args:
        creds: credentials mapping to read. Only when it is None (not merely
            empty) are the saved credentials loaded from disk.
    """
    if creds is None:
        creds = load_credentials()
    raw = creds.get("kaggle_json_path")
    return Path(raw).expanduser() if raw else None
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_nim_key(creds: dict[str, Any] | None = None) -> str | None:
    """Return the stored NIM API key, or None when it is missing or empty.

    Args:
        creds: credentials mapping to read. Only when it is None (not merely
            empty) are the saved credentials loaded from disk.
    """
    if creds is None:
        creds = load_credentials()
    return creds.get("nim_api_key") or None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def has_kaggle_credentials() -> bool:
    """True when a kaggle.json path is configured and that file exists on disk."""
    kaggle_json = get_kaggle_json_path()
    if kaggle_json is None:
        return False
    return kaggle_json.exists()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def has_nim_key() -> bool:
    """True when a non-empty NIM API key is stored in the saved credentials."""
    stored = get_nim_key()
    return stored is not None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ---------------------------------------------------------------------------
|
|
72
|
+
# Validation helpers
|
|
73
|
+
# ---------------------------------------------------------------------------
|
|
74
|
+
|
|
75
|
+
def _warn_env_overrides() -> None:
    """Print a warning when KAGGLE_KEY or KAGGLE_USERNAME is set (non-empty) in the environment."""
    overriding = any(os.environ.get(var) for var in ("KAGGLE_KEY", "KAGGLE_USERNAME"))
    if not overriding:
        return
    console.print(
        "[yellow]⚠ Found KAGGLE_KEY / KAGGLE_USERNAME env vars "
        "— these override your kaggle.json[/yellow]"
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _check_json_permissions(path: Path) -> None:
|
|
84
|
+
if os.name == "nt":
|
|
85
|
+
return
|
|
86
|
+
try:
|
|
87
|
+
mode = stat.S_IMODE(os.stat(path).st_mode)
|
|
88
|
+
if mode & 0o177:
|
|
89
|
+
raise CredentialPermissionError(
|
|
90
|
+
f"{path} has insecure permissions ({oct(mode)}). "
|
|
91
|
+
"Fix with: chmod 600 ~/.kaggle/kaggle.json"
|
|
92
|
+
)
|
|
93
|
+
except (OSError, NotImplementedError):
|
|
94
|
+
pass
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def validate_kaggle_json(path: Path) -> dict[str, str]:
    """Load a kaggle.json file and check it has the fields ezscreen needs.

    Returns:
        The parsed JSON object, guaranteed to contain "username" and "key".

    Raises:
        KaggleAuthError: the file is missing, unreadable, not valid JSON, not
            a JSON object, or lacks a required field.
        CredentialPermissionError: (POSIX only) the file is accessible to
            group/others, as reported by _check_json_permissions.
    """
    if not path.exists():
        raise KaggleAuthError(f"kaggle.json not found at {path}")

    _check_json_permissions(path)

    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        raise KaggleAuthError(f"Could not read kaggle.json: {exc}") from exc
    try:
        data = json.loads(text)
    except json.JSONDecodeError as exc:
        raise KaggleAuthError(f"kaggle.json is not valid JSON: {exc}") from exc

    # A top-level list/string would otherwise slip through (or crash) the
    # membership checks below.
    if not isinstance(data, dict):
        raise KaggleAuthError(
            "kaggle.json must contain a JSON object with 'username' and 'key'"
        )

    for required in ("username", "key"):
        if required not in data:
            raise KaggleAuthError(f"kaggle.json is missing the '{required}' field")

    return data
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _live_kaggle_check(kaggle_data: dict[str, str]) -> None:
    """Verify Kaggle credentials against the live Kaggle API.

    The kaggle.json username/key are exported as KAGGLE_USERNAME / KAGGLE_KEY.
    Values already in the environment win, because setdefault is used. This
    happens *before* the kaggle package is imported, because importing it
    authenticates immediately; that would fail for a kaggle.json stored
    outside the package's default location. ``authenticate()`` only loads
    credentials, so a minimal authenticated request then confirms the key is
    accepted.

    Raises:
        ImportError: the kaggle package is not installed.
        KaggleAuthError: a targeted hint for 401/403 responses, otherwise the
            raw error text.
    """
    os.environ.setdefault("KAGGLE_USERNAME", kaggle_data["username"])
    os.environ.setdefault("KAGGLE_KEY", kaggle_data["key"])
    try:
        import kaggle as kaggle_pkg  # lazy import — kaggle is slow to load

        kaggle_pkg.api.authenticate()
        # Cheapest call that requires a valid key: list at most one own kernel.
        kaggle_pkg.api.kernels_list(mine=True, page_size=1)
    except ImportError:
        raise
    except Exception as exc:
        msg = str(exc).lower()
        if "401" in msg or "unauthorized" in msg:
            raise KaggleAuthError(
                "API key rejected — go to kaggle.com/settings/account "
                "→ API → Create New Token"
            ) from exc
        if "403" in msg or "forbidden" in msg:
            raise KaggleAuthError(
                "Account needs phone verification — "
                "complete at kaggle.com/settings"
            ) from exc
        raise KaggleAuthError(str(exc)) from exc
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def validate_nim_key(key: str) -> None:
    """Check *key* against the NVIDIA NIM API.

    POSTs an empty JSON body to NIM_HEALTH_ENDPOINT with the key as a Bearer
    token (10 s timeout). Only an authentication/authorization failure
    (HTTP 401 or 403) is treated as a bad key. Any other status is not
    considered a key problem, including errors caused by the empty payload.

    Raises:
        NetworkTimeoutError: the request timed out or the host was unreachable.
        NIMAuthError: the API rejected the key.
    """
    try:
        resp = requests.post(
            NIM_HEALTH_ENDPOINT,
            headers={"Authorization": f"Bearer {key}"},
            json={},
            timeout=10,
        )
    except requests.Timeout as exc:
        raise NetworkTimeoutError("NIM endpoint timed out") from exc
    except requests.ConnectionError as exc:
        raise NetworkTimeoutError(f"Could not reach NIM API: {exc}") from exc

    if resp.status_code in (401, 403):
        raise NIMAuthError(
            "NIM key rejected — get a free key at build.nvidia.com"
        )
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# ---------------------------------------------------------------------------
|
|
157
|
+
# Wizard steps
|
|
158
|
+
# ---------------------------------------------------------------------------
|
|
159
|
+
|
|
160
|
+
def _step_kaggle(creds: dict[str, Any]) -> dict[str, Any]:
    """Wizard step: choose, validate and live-check a kaggle.json file.

    On success, the chosen path is stored in ``creds["kaggle_json_path"]``
    and the updated dict is returned.

    Raises:
        KeyboardInterrupt: the user aborted a prompt.
        CredentialPermissionError: the file has insecure permissions and the
            user declined the auto-fix.
        KaggleAuthError: the file is invalid or the API rejected the key.
    """
    _warn_env_overrides()

    default = get_kaggle_json_path(creds) or Path("~/.kaggle/kaggle.json").expanduser()
    raw = questionary.text("Path to kaggle.json:", default=str(default)).ask()
    if raw is None:
        raise KeyboardInterrupt

    path = Path(raw).expanduser()

    try:
        kaggle_data = validate_kaggle_json(path)
    except CredentialPermissionError as exc:
        fix = questionary.confirm(f"\n {exc}\n Auto-fix permissions?", default=True).ask()
        if fix is None:
            # Prompt aborted (Ctrl-C): same handling as the other prompts.
            raise KeyboardInterrupt from exc
        if not fix:
            # Declined: re-raise instead of falling through with kaggle_data
            # unbound (previously an UnboundLocalError).
            raise
        os.chmod(path, 0o600)
        kaggle_data = validate_kaggle_json(path)

    console.print(" [dim]Checking Kaggle API...[/dim]")
    _live_kaggle_check(kaggle_data)
    console.print(f" [green]Kaggle ✓[/green] [dim]{kaggle_data['username']}[/dim]")

    creds["kaggle_json_path"] = str(path)
    return creds
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _step_nim(creds: dict[str, Any]) -> dict[str, Any]:
    """Wizard step: optionally collect and validate a NIM API key.

    A blank answer skips the step and leaves *creds* unchanged. Otherwise the
    stripped key is checked with validate_nim_key() and stored under
    ``nim_api_key``.

    Raises:
        KeyboardInterrupt: the user aborted the prompt.
    """
    console.print(" [dim]optional — only needed for ezscreen validate[/dim]")
    answer = questionary.password("NIM API key (Enter to skip):", default="").ask()
    if answer is None:
        raise KeyboardInterrupt

    key = answer.strip()
    if key:
        console.print(" [dim]Checking NIM API...[/dim]")
        validate_nim_key(key)
        console.print(" [green]NIM ✓[/green]")
        creds["nim_api_key"] = key
        return creds

    console.print(" [dim]NIM — skipped[/dim]")
    return creds
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
# ---------------------------------------------------------------------------
|
|
205
|
+
# Public wizard entry point
|
|
206
|
+
# ---------------------------------------------------------------------------
|
|
207
|
+
|
|
208
|
+
def run_wizard(update: str | None = None) -> None:
    """Run the interactive credential setup and save the result to CREDS_PATH.

    Args:
        update: which step(s) to run: "Kaggle credentials", "NIM API key" or
            "Both". When None and credentials already exist, the user is asked
            (and may cancel). When None and nothing is saved yet, both steps
            run.
    """
    existing = load_credentials()

    if update is None and existing:
        picked = questionary.select(
            "Credentials already set. What would you like to update?",
            choices=["Kaggle credentials", "NIM API key", "Both", "← Cancel"],
        ).ask()
        if picked in (None, "← Cancel"):
            return
        update = picked

    do_kaggle = update in (None, "Kaggle credentials", "Both")
    do_nim = update in (None, "NIM API key", "Both")
    creds = dict(existing)

    if do_kaggle:
        console.print("\n[bold]Step 1 — Kaggle[/bold]")
        creds = _step_kaggle(creds)

    if do_nim:
        console.print("\n[bold]Step 2 — NIM[/bold] [dim](optional)[/dim]")
        creds = _step_nim(creds)

    save_credentials(creds)
    summary = Panel(
        f" [green]Credentials saved[/green] [dim]{CREDS_PATH}[/dim]",
        title="[bold]Done[/bold]",
    )
    console.print(summary)
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import shutil
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from ezscreen.errors import (
|
|
9
|
+
KaggleBadRequestError,
|
|
10
|
+
KaggleForbiddenError,
|
|
11
|
+
KaggleNotFoundError,
|
|
12
|
+
KaggleRateLimitError,
|
|
13
|
+
KaggleServerError,
|
|
14
|
+
KaggleUnauthorizedError,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# JSON map of resolved receptor path -> SHA-256 hex digest. It is updated
# by upload_run_dataset and persisted via _save_manifest.
MANIFEST_PATH = Path.home() / ".ezscreen" / "manifest.json"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# ---------------------------------------------------------------------------
|
|
21
|
+
# Helpers
|
|
22
|
+
# ---------------------------------------------------------------------------
|
|
23
|
+
|
|
24
|
+
def _api():
    """Lazily import kaggle, authenticate its shared ``api`` client and return it."""
    import kaggle

    client = kaggle.api
    client.authenticate()
    return client
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _handle_error(exc: Exception) -> None:
    """Translate a raw Kaggle API exception into an ezscreen error and raise it.

    Never returns normally. The lower-cased exception text is checked, in
    order, for 401/unauthorized, 403/forbidden, 404/not found,
    429/rate limit, then 500/502/503/504. Anything else becomes
    KaggleBadRequestError. The original exception is chained as
    ``__cause__``.
    """
    text = str(exc).lower()

    def mentions(*needles: str) -> bool:
        return any(needle in text for needle in needles)

    if mentions("401", "unauthorized"):
        raise KaggleUnauthorizedError(
            "API key rejected — go to kaggle.com/settings → API → Create New Token"
        ) from exc
    if mentions("403", "forbidden"):
        raise KaggleForbiddenError(
            "Account needs phone verification — complete at kaggle.com/settings"
        ) from exc

    detail = str(exc)
    if mentions("404", "not found"):
        raise KaggleNotFoundError(detail) from exc
    if mentions("429", "rate limit"):
        raise KaggleRateLimitError(detail) from exc
    if mentions("500", "502", "503", "504"):
        raise KaggleServerError(detail) from exc
    raise KaggleBadRequestError(detail) from exc
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*, streamed in 64 KiB chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while chunk := stream.read(65536):
            digest.update(chunk)
    return digest.hexdigest()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _load_manifest() -> dict[str, str]:
    """Load the receptor-hash manifest from MANIFEST_PATH.

    Returns an empty dict when the file is absent, unreadable, not valid JSON
    or not a JSON object. The manifest is only a cache, so a truncated file
    (e.g. from an interrupted write) must not abort an upload.
    upload_run_dataset rewrites it after the next successful upload.
    """
    try:
        data = json.loads(MANIFEST_PATH.read_text())
    except (OSError, ValueError):
        # FileNotFoundError, permission problems, JSONDecodeError and
        # UnicodeDecodeError all land here.
        return {}
    return data if isinstance(data, dict) else {}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _save_manifest(m: dict[str, str]) -> None:
    """Write the manifest to MANIFEST_PATH as indented JSON, creating its directory if needed."""
    target = MANIFEST_PATH
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(m, indent=2)
    target.write_text(payload)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# ---------------------------------------------------------------------------
|
|
69
|
+
# Public API
|
|
70
|
+
# ---------------------------------------------------------------------------
|
|
71
|
+
|
|
72
|
+
def upload_run_dataset(
    run_id: str,
    receptor_pdbqt: Path,
    shard_paths: list[Path],
    username: str,
    work_dir: Path,
) -> str:
    """
    Stage and upload a private Kaggle dataset holding one run's inputs.

    The dataset directory ``work_dir/dataset_{run_id}`` receives the receptor,
    always copied as ``receptor.pdbqt`` so the notebook template can use a
    fixed filename, plus every ligand shard under its own name. Each run gets
    its own dataset, so the receptor is included every time. Its SHA-256 is
    recorded in the manifest keyed by its resolved path, and the manifest is
    saved only after a successful upload.

    Returns dataset ref: 'username/ezscreen-{run_id}'.

    Raises:
        The ezscreen Kaggle error chosen by _handle_error when the API call
        raises or reports an error in its response.
    """
    api = _api()
    manifest = _load_manifest()

    dataset_dir = work_dir / f"dataset_{run_id}"
    dataset_dir.mkdir(parents=True, exist_ok=True)

    receptor_hash = sha256(receptor_pdbqt)
    cache_key = str(receptor_pdbqt.resolve())
    manifest[cache_key] = receptor_hash
    shutil.copy2(receptor_pdbqt, dataset_dir / "receptor.pdbqt")

    for sp in shard_paths:
        shutil.copy2(sp, dataset_dir / sp.name)

    slug = f"ezscreen-{run_id}"
    meta = {
        "title": f"ezscreen {run_id}",
        "id": f"{username}/{slug}",
        "licenses": [{"name": "other"}],
    }
    (dataset_dir / "dataset-metadata.json").write_text(json.dumps(meta, indent=2))

    try:
        result = api.dataset_create_new(str(dataset_dir), public=False, quiet=True)
        # The kaggle client reports some failures (e.g. a slug already in use)
        # in the response's ``error`` field instead of raising; surface them.
        error = getattr(result, "error", None)
        if error:
            raise RuntimeError(f"Dataset creation failed: {error}")
    except Exception as exc:
        _handle_error(exc)

    _save_manifest(manifest)
    return f"{username}/{slug}"
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def delete_run_dataset(run_id: str, username: str) -> None:
    """Delete the ``{username}/ezscreen-{run_id}`` Kaggle dataset. Used by ezscreen clean.

    API failures are translated and raised by _handle_error.
    """
    client = _api()
    try:
        client.dataset_delete(username, f"ezscreen-{run_id}")
    except Exception as exc:
        _handle_error(exc)
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import shutil
|
|
5
|
+
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
|
|
10
|
+
from ezscreen.errors import KaggleForbiddenError, KaggleUnauthorizedError
|
|
11
|
+
|
|
12
|
+
# Shared Rich console for retry notices.
console = Console()
# Maximum number of calls _with_backoff makes, including the first one.
_MAX_RETRIES = 5
# Wait before retry k (1-based) is _BACKOFF_BASE ** k seconds: 2, 4, 8, 16.
_BACKOFF_BASE = 2
# Substrings that mark an error as transient (retryable); they are matched
# against the lower-cased error text.
# 409 = kernel version currently queued/saving — transient lock, safe to retry
_TRANSIENT_CODES = ("409", "429", "500", "502", "503", "504", "rate")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _api():
    """Return the kaggle package's shared ``api`` client after authenticating it.

    kaggle is imported lazily so merely importing this module stays cheap.
    """
    import kaggle as _kaggle

    _kaggle.api.authenticate()
    return _kaggle.api
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _with_backoff(fn, *args, **kwargs):
|
|
26
|
+
"""Retry transient errors with exponential backoff. Never retries 401/403."""
|
|
27
|
+
for attempt in range(_MAX_RETRIES):
|
|
28
|
+
try:
|
|
29
|
+
return fn(*args, **kwargs)
|
|
30
|
+
except (KaggleUnauthorizedError, KaggleForbiddenError):
|
|
31
|
+
raise
|
|
32
|
+
except Exception as exc:
|
|
33
|
+
msg = str(exc).lower()
|
|
34
|
+
is_transient = any(c in msg for c in _TRANSIENT_CODES)
|
|
35
|
+
if not is_transient or attempt == _MAX_RETRIES - 1:
|
|
36
|
+
raise
|
|
37
|
+
wait = _BACKOFF_BASE ** (attempt + 1)
|
|
38
|
+
console.print(f" [dim]Kaggle error — retrying in {wait}s ({attempt + 1}/{_MAX_RETRIES})[/dim]")
|
|
39
|
+
time.sleep(wait)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def push_kernel(
    run_id: str,
    notebook_path: Path,
    dataset_ref: str,
    username: str,
    work_dir: Path,
) -> str:
    """Push an already-rendered notebook to Kaggle as a private GPU kernel.

    Copies *notebook_path* to ``work_dir/kernel_{run_id}/notebook.ipynb``.
    It then writes ``kernel-metadata.json`` (T4 GPU, internet enabled,
    *dataset_ref* attached) and pushes, retrying transient errors via
    _with_backoff.

    Returns:
        The kernel ref ``'{username}/{run_id}'``.

    Raises:
        The last failed push attempt's error. Failures the kaggle client
        reports in the push response (rather than raising) are raised as
        RuntimeError.
    """
    api = _api()

    kernel_dir = work_dir / f"kernel_{run_id}"
    kernel_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(notebook_path, kernel_dir / "notebook.ipynb")

    # run_id already carries the "ezs-" prefix — use it directly as the slug
    slug = run_id
    # title must slugify to exactly the slug — replace hyphens with spaces so
    # Kaggle's slug derivation round-trips back to the same value
    title = slug.replace("-", " ")
    meta = {
        "id": f"{username}/{slug}",
        "title": title,
        "code_file": "notebook.ipynb",
        "language": "python",
        "kernel_type": "notebook",
        "is_private": True,
        "enable_gpu": True,
        "accelerator": "nvidiaTeslaT4",
        "enable_internet": True,
        "dataset_sources": [dataset_ref],
        "competition_sources": [],
        "kernel_sources": [],
    }
    (kernel_dir / "kernel-metadata.json").write_text(json.dumps(meta, indent=2))

    def _push():
        result = api.kernels_push(str(kernel_dir))
        # kernels_push can report a failure in the response's ``error`` field
        # instead of raising. Raise it so _with_backoff can retry transient
        # cases and callers see the rest.
        error = getattr(result, "error", None)
        if error:
            raise RuntimeError(f"Kernel push failed: {error}")
        return result

    _with_backoff(_push)
    return f"{username}/{slug}"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def delete_kernel(run_id: str, username: str) -> None:
    """Best-effort deletion of a run's Kaggle kernel. Used by ezscreen clean.

    push_kernel uses *run_id* itself as the kernel slug, since run ids already
    carry the ``ezs-`` prefix. The prefix is therefore added only when it is
    missing. Always prepending it targeted ``ezs-ezs-…``, a kernel that
    doesn't exist, so the real one was never deleted.
    Any API error is ignored.
    """
    api = _api()
    slug = run_id if run_id.startswith("ezs-") else f"ezs-{run_id}"
    try:
        api.kernel_delete(username, slug)
    except Exception:
        pass  # best-effort: cleanup must not fail because of the remote side
|