imspy-predictors 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imspy_predictors/__init__.py +280 -0
- imspy_predictors/ccs/__init__.py +32 -0
- imspy_predictors/ccs/predictors.py +768 -0
- imspy_predictors/ccs/utility.py +84 -0
- imspy_predictors/data_utils.py +589 -0
- imspy_predictors/hashing.py +255 -0
- imspy_predictors/intensity/__init__.py +41 -0
- imspy_predictors/intensity/predictors.py +882 -0
- imspy_predictors/intensity/utility.py +458 -0
- imspy_predictors/ionization/__init__.py +25 -0
- imspy_predictors/ionization/predictors.py +518 -0
- imspy_predictors/koina_models/__init__.py +92 -0
- imspy_predictors/koina_models/access_models.py +371 -0
- imspy_predictors/koina_models/input_filters.py +488 -0
- imspy_predictors/lazy_imports.py +126 -0
- imspy_predictors/losses.py +419 -0
- imspy_predictors/mixture.py +350 -0
- imspy_predictors/models/__init__.py +57 -0
- imspy_predictors/models/heads.py +561 -0
- imspy_predictors/models/transformer.py +317 -0
- imspy_predictors/models/unified.py +608 -0
- imspy_predictors/pretrained/__init__.py +0 -0
- imspy_predictors/pretrained/ccs/test_metrics.json +7 -0
- imspy_predictors/pretrained/charge/test_metrics.json +5 -0
- imspy_predictors/pretrained/hub.py +161 -0
- imspy_predictors/pretrained/rt/test_metrics.json +7 -0
- imspy_predictors/pretrained/tokenizer-ptm.json +1 -0
- imspy_predictors/pretrained/unimod-vocab.json +1055 -0
- imspy_predictors/rt/__init__.py +21 -0
- imspy_predictors/rt/predictors.py +540 -0
- imspy_predictors/training.py +1271 -0
- imspy_predictors/utilities/__init__.py +29 -0
- imspy_predictors/utilities/hf_tokenizers.py +87 -0
- imspy_predictors/utilities/simple_tokenizer.py +312 -0
- imspy_predictors/utilities/tokenizers.py +232 -0
- imspy_predictors/utility.py +328 -0
- imspy_predictors-0.5.0.dist-info/METADATA +110 -0
- imspy_predictors-0.5.0.dist-info/RECORD +39 -0
- imspy_predictors-0.5.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""
Download and cache pretrained model files from GitHub Releases.

Models are downloaded on first use and cached locally in
``$IMSPY_CACHE_DIR`` when that variable is set (no version sub-directory is
appended), otherwise in ``~/.cache/imspy/models/v{MODEL_VERSION}/``.

Every file is checked against a pinned SHA-256 digest before it is used.
Downloads go through ``torch.hub.download_url_to_file()`` (progress bar
included), falling back to ``urllib`` if torch.hub is unavailable or its
download fails.
"""
|
|
10
|
+
|
|
11
|
+
import hashlib
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import shutil
|
|
15
|
+
import tempfile
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Version & release metadata
# ---------------------------------------------------------------------------

# Version of the pretrained weights. It also names the default cache
# sub-directory (see get_cache_dir).
MODEL_VERSION = "0.5.0"

# Derived from MODEL_VERSION so that bumping the version cannot leave the
# download URL pointing at an older GitHub release.
GITHUB_RELEASE_TAG = f"models-v{MODEL_VERSION}"

_RELEASE_BASE = (
    f"https://github.com/theGreatHerrLebert/rustims/releases/download/{GITHUB_RELEASE_TAG}"
)

# Maps the *package-relative* path (as used by callers of get_model_path, and
# as accepted by ensure_model) to the release asset's file name and its
# expected SHA-256 hex digest.
MODELS = {
    "ccs/best_model.pt": {
        "filename": "ccs-best_model.pt",
        "sha256": "f06282ffcc071bd046bfd054ac0db78c18a0ce405918046de494b51eb26f75e4",
    },
    "rt/best_model.pt": {
        "filename": "rt-best_model.pt",
        "sha256": "6f318fcfe2d37c1a24b1bea32db6c9fa8eb87d2c545a20e93ae8e422f901ad60",
    },
    "charge/best_model.pt": {
        "filename": "charge-best_model.pt",
        "sha256": "e0e2ab6d43f028718d7e09aa5c2946059d8f67cf6cd165785a6fc2038cf64ecf",
    },
    "intensity/best_model.pt": {
        "filename": "intensity-best_model.pt",
        "sha256": "6849d857bf03d1205716940206ad49f8c3ecbf36e3668429883b468274b9d3fc",
    },
    "pretrained_encoder.pt": {
        "filename": "pretrained_encoder.pt",
        "sha256": "43ccc2f836bf3d81943ddce353ade9628e7d036421ba5b5c182bf163e496385e",
    },
}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_cache_dir() -> Path:
    """Return the local cache directory for pretrained models.

    If the ``IMSPY_CACHE_DIR`` environment variable is set to a non-empty
    value, that directory is used directly. A leading ``~`` is expanded to
    the user's home directory, and no version sub-directory is appended.
    Otherwise the function falls back to
    ``~/.cache/imspy/models/v{MODEL_VERSION}``.

    The directory is only computed here; it is not created.
    """
    override = os.environ.get("IMSPY_CACHE_DIR")
    if override:
        # Expand "~" ourselves: the variable may be set by a .env file, a
        # container spec or a service manager that never runs a shell.
        return Path(override).expanduser()
    return Path.home() / ".cache" / "imspy" / "models" / f"v{MODEL_VERSION}"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _sha256(path: Path) -> str:
    """Return the hex-encoded SHA-256 digest of the file at *path*.

    The file is streamed in 1 MiB pieces, so large checkpoints are hashed
    without being read into memory in one go.
    """
    hasher = hashlib.sha256()
    piece_size = 1 << 20
    with open(path, "rb") as stream:
        while piece := stream.read(piece_size):
            hasher.update(piece)
    return hasher.hexdigest()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _download(url: str, dest: Path) -> None:
    """Download *url* to *dest*, preferring torch.hub for its progress bar.

    ``torch.hub.download_url_to_file`` is tried first. If torch cannot be
    imported, or its download raises, the file is fetched with
    ``urllib.request.urlretrieve`` instead (no progress bar, no extra
    dependencies). A torch-side download failure is logged as a warning
    rather than silently discarded, so the root cause stays visible if the
    fallback fails too. Exceptions from the urllib fallback propagate to the
    caller.

    Parameters
    ----------
    url : str
        Source URL.
    dest : Path
        Destination file path.
    """
    try:
        from torch.hub import download_url_to_file
    except Exception as exc:  # ImportError, or a torch install that fails to load
        logger.debug("torch.hub unavailable (%s); using urllib.", exc)
    else:
        try:
            download_url_to_file(url, str(dest), progress=True)
            return
        except Exception as exc:
            logger.warning(
                "torch.hub download of %s failed (%s: %s); retrying with urllib.",
                url,
                type(exc).__name__,
                exc,
            )

    # Fallback: plain urllib (no progress bar, but no extra deps)
    import urllib.request

    logger.info("Downloading %s (urllib fallback) ...", url)
    urllib.request.urlretrieve(url, str(dest))
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def ensure_model(model_name: str) -> Path:
    """Return the path to a cached model, downloading it first if necessary.

    The file is looked up under ``get_cache_dir()``. If it exists and its
    SHA-256 matches the digest recorded in ``MODELS``, it is returned
    immediately. Otherwise it is downloaded from the GitHub release into a
    temporary file in the same directory, verified, and moved into place
    with ``os.replace``. Readers therefore never see a partial or unverified
    file at the cached path.

    Parameters
    ----------
    model_name : str
        Package-relative name, e.g. ``"ccs/best_model.pt"``.

    Returns
    -------
    Path
        Path to the cached ``.pt`` file. It is absolute unless
        ``IMSPY_CACHE_DIR`` is set to a relative path.

    Raises
    ------
    ValueError
        If *model_name* is not a known model.
    RuntimeError
        If the download succeeds but the SHA-256 check fails.
    """
    if model_name not in MODELS:
        raise ValueError(
            f"Unknown model '{model_name}'. "
            f"Known models: {sorted(MODELS)}"
        )

    meta = MODELS[model_name]
    expected = meta["sha256"]
    cached_path = get_cache_dir() / model_name

    # Fast path: already cached and hash matches.
    if cached_path.exists():
        digest = _sha256(cached_path)
        if digest == expected:
            return cached_path
        logger.warning(
            "Cached model %s has SHA-256 %s (expected %s); re-downloading.",
            cached_path,
            digest,
            expected,
        )

    # Download into a temp file in the destination directory, so the final
    # rename stays on one filesystem and is atomic.
    cached_path.parent.mkdir(parents=True, exist_ok=True)
    url = f"{_RELEASE_BASE}/{meta['filename']}"
    logger.info("Downloading model '%s' from %s ...", model_name, url)

    fd, tmp_name = tempfile.mkstemp(dir=str(cached_path.parent), suffix=".download")
    os.close(fd)
    tmp_path = Path(tmp_name)

    try:
        _download(url, tmp_path)

        digest = _sha256(tmp_path)
        if digest != expected:
            raise RuntimeError(
                f"SHA-256 mismatch for {model_name}: "
                f"got {digest}, expected {expected}"
            )

        # os.replace atomically overwrites a stale cached file on both POSIX
        # and Windows. shutil.move would fall back to copy-then-delete on
        # Windows when the destination already exists.
        os.replace(tmp_path, cached_path)
    except BaseException:
        # Never leave a partial or unverified download behind, even on Ctrl-C.
        tmp_path.unlink(missing_ok=True)
        raise

    logger.info("Cached model at %s", cached_path)
    return cached_path
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"{\"class_name\": \"Tokenizer\", \"config\": {\"num_words\": null, \"filters\": \"!\\\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\", \"lower\": false, \"split\": \" \", \"char_level\": false, \"oov_token\": null, \"document_count\": 800315, \"word_counts\": \"{\\\"<START>\\\": 792568, \\\"E\\\": 1150285, \\\"T\\\": 708378, \\\"I\\\": 647675, \\\"D\\\": 814486, \\\"G\\\": 801249, \\\"L\\\": 1243541, \\\"W\\\": 68294, \\\"S\\\": 986226, \\\"A\\\": 970370, \\\"R\\\": 487728, \\\"<END>\\\": 800315, \\\"Q\\\": 624996, \\\"V\\\": 822187, \\\"P\\\": 777177, \\\"H\\\": 356941, \\\"K\\\": 635852, \\\"Y\\\": 323049, \\\"M\\\": 112508, \\\"N\\\": 537145, \\\"F\\\": 437148, \\\"C[UNIMOD:4]\\\": 144034, \\\"M[UNIMOD:35]\\\": 150500, \\\"K[UNIMOD:1363]\\\": 363, \\\"<START>[UNIMOD:1]\\\": 7747, \\\"S[UNIMOD:21]\\\": 23097, \\\"K[UNIMOD:747]\\\": 325, \\\"K[UNIMOD:1849]\\\": 429, \\\"Y[UNIMOD:21]\\\": 233, \\\"K[UNIMOD:64]\\\": 471, \\\"T[UNIMOD:43]\\\": 74, \\\"K[UNIMOD:121]\\\": 589, \\\"R[UNIMOD:34]\\\": 270, \\\"R[UNIMOD:7]\\\": 439, \\\"K[UNIMOD:34]\\\": 524, \\\"T[UNIMOD:21]\\\": 1855, \\\"R[UNIMOD:36]\\\": 344, \\\"K[UNIMOD:1289]\\\": 440, \\\"K[UNIMOD:1]\\\": 297, \\\"K[UNIMOD:1848]\\\": 482, \\\"K[UNIMOD:122]\\\": 348, \\\"C[UNIMOD:312]\\\": 136, \\\"K[UNIMOD:37]\\\": 457, \\\"K[UNIMOD:58]\\\": 180, \\\"S[UNIMOD:43]\\\": 103, \\\"K[UNIMOD:3]\\\": 236, \\\"K[UNIMOD:36]\\\": 441, \\\"Y[UNIMOD:354]\\\": 222}\", \"word_docs\": \"{\\\"E\\\": 571986, \\\"T\\\": 446549, \\\"I\\\": 433228, \\\"A\\\": 506744, \\\"D\\\": 494531, \\\"G\\\": 451662, \\\"S\\\": 516302, \\\"L\\\": 628195, \\\"R\\\": 427701, \\\"<END>\\\": 800315, \\\"W\\\": 64054, \\\"<START>\\\": 792568, \\\"P\\\": 438886, \\\"V\\\": 495048, \\\"Q\\\": 405053, \\\"K\\\": 532020, \\\"H\\\": 285269, \\\"Y\\\": 254978, \\\"N\\\": 371606, \\\"M\\\": 98241, \\\"F\\\": 331117, \\\"C[UNIMOD:4]\\\": 118333, \\\"M[UNIMOD:35]\\\": 130818, \\\"K[UNIMOD:1363]\\\": 362, \\\"<START>[UNIMOD:1]\\\": 7747, \\\"S[UNIMOD:21]\\\": 20402, 
\\\"K[UNIMOD:747]\\\": 322, \\\"K[UNIMOD:1849]\\\": 427, \\\"Y[UNIMOD:21]\\\": 232, \\\"K[UNIMOD:64]\\\": 461, \\\"T[UNIMOD:43]\\\": 74, \\\"K[UNIMOD:121]\\\": 587, \\\"R[UNIMOD:34]\\\": 270, \\\"R[UNIMOD:7]\\\": 383, \\\"K[UNIMOD:34]\\\": 512, \\\"T[UNIMOD:21]\\\": 1781, \\\"R[UNIMOD:36]\\\": 343, \\\"K[UNIMOD:1289]\\\": 437, \\\"K[UNIMOD:1]\\\": 297, \\\"K[UNIMOD:1848]\\\": 477, \\\"K[UNIMOD:122]\\\": 341, \\\"C[UNIMOD:312]\\\": 134, \\\"K[UNIMOD:37]\\\": 456, \\\"K[UNIMOD:58]\\\": 179, \\\"S[UNIMOD:43]\\\": 103, \\\"K[UNIMOD:3]\\\": 236, \\\"K[UNIMOD:36]\\\": 441, \\\"Y[UNIMOD:354]\\\": 214}\", \"index_docs\": \"{\\\"2\\\": 571986, \\\"11\\\": 446549, \\\"12\\\": 433228, \\\"4\\\": 506744, \\\"6\\\": 494531, \\\"7\\\": 451662, \\\"3\\\": 516302, \\\"1\\\": 628195, \\\"16\\\": 427701, \\\"8\\\": 800315, \\\"23\\\": 64054, \\\"9\\\": 792568, \\\"10\\\": 438886, \\\"5\\\": 495048, \\\"14\\\": 405053, \\\"13\\\": 532020, \\\"18\\\": 285269, \\\"19\\\": 254978, \\\"15\\\": 371606, \\\"22\\\": 98241, \\\"17\\\": 331117, \\\"21\\\": 118333, \\\"20\\\": 130818, \\\"36\\\": 362, \\\"25\\\": 7747, \\\"24\\\": 20402, \\\"39\\\": 322, \\\"35\\\": 427, \\\"43\\\": 232, \\\"30\\\": 461, \\\"48\\\": 74, \\\"27\\\": 587, \\\"41\\\": 270, \\\"34\\\": 383, \\\"28\\\": 512, \\\"26\\\": 1781, \\\"38\\\": 343, \\\"33\\\": 437, \\\"40\\\": 297, \\\"29\\\": 477, \\\"37\\\": 341, \\\"46\\\": 134, \\\"31\\\": 456, \\\"45\\\": 179, \\\"47\\\": 103, \\\"42\\\": 236, \\\"32\\\": 441, \\\"44\\\": 214}\", \"index_word\": \"{\\\"1\\\": \\\"L\\\", \\\"2\\\": \\\"E\\\", \\\"3\\\": \\\"S\\\", \\\"4\\\": \\\"A\\\", \\\"5\\\": \\\"V\\\", \\\"6\\\": \\\"D\\\", \\\"7\\\": \\\"G\\\", \\\"8\\\": \\\"<END>\\\", \\\"9\\\": \\\"<START>\\\", \\\"10\\\": \\\"P\\\", \\\"11\\\": \\\"T\\\", \\\"12\\\": \\\"I\\\", \\\"13\\\": \\\"K\\\", \\\"14\\\": \\\"Q\\\", \\\"15\\\": \\\"N\\\", \\\"16\\\": \\\"R\\\", \\\"17\\\": \\\"F\\\", \\\"18\\\": \\\"H\\\", \\\"19\\\": \\\"Y\\\", \\\"20\\\": \\\"M[UNIMOD:35]\\\", 
\\\"21\\\": \\\"C[UNIMOD:4]\\\", \\\"22\\\": \\\"M\\\", \\\"23\\\": \\\"W\\\", \\\"24\\\": \\\"S[UNIMOD:21]\\\", \\\"25\\\": \\\"<START>[UNIMOD:1]\\\", \\\"26\\\": \\\"T[UNIMOD:21]\\\", \\\"27\\\": \\\"K[UNIMOD:121]\\\", \\\"28\\\": \\\"K[UNIMOD:34]\\\", \\\"29\\\": \\\"K[UNIMOD:1848]\\\", \\\"30\\\": \\\"K[UNIMOD:64]\\\", \\\"31\\\": \\\"K[UNIMOD:37]\\\", \\\"32\\\": \\\"K[UNIMOD:36]\\\", \\\"33\\\": \\\"K[UNIMOD:1289]\\\", \\\"34\\\": \\\"R[UNIMOD:7]\\\", \\\"35\\\": \\\"K[UNIMOD:1849]\\\", \\\"36\\\": \\\"K[UNIMOD:1363]\\\", \\\"37\\\": \\\"K[UNIMOD:122]\\\", \\\"38\\\": \\\"R[UNIMOD:36]\\\", \\\"39\\\": \\\"K[UNIMOD:747]\\\", \\\"40\\\": \\\"K[UNIMOD:1]\\\", \\\"41\\\": \\\"R[UNIMOD:34]\\\", \\\"42\\\": \\\"K[UNIMOD:3]\\\", \\\"43\\\": \\\"Y[UNIMOD:21]\\\", \\\"44\\\": \\\"Y[UNIMOD:354]\\\", \\\"45\\\": \\\"K[UNIMOD:58]\\\", \\\"46\\\": \\\"C[UNIMOD:312]\\\", \\\"47\\\": \\\"S[UNIMOD:43]\\\", \\\"48\\\": \\\"T[UNIMOD:43]\\\"}\", \"word_index\": \"{\\\"L\\\": 1, \\\"E\\\": 2, \\\"S\\\": 3, \\\"A\\\": 4, \\\"V\\\": 5, \\\"D\\\": 6, \\\"G\\\": 7, \\\"<END>\\\": 8, \\\"<START>\\\": 9, \\\"P\\\": 10, \\\"T\\\": 11, \\\"I\\\": 12, \\\"K\\\": 13, \\\"Q\\\": 14, \\\"N\\\": 15, \\\"R\\\": 16, \\\"F\\\": 17, \\\"H\\\": 18, \\\"Y\\\": 19, \\\"M[UNIMOD:35]\\\": 20, \\\"C[UNIMOD:4]\\\": 21, \\\"M\\\": 22, \\\"W\\\": 23, \\\"S[UNIMOD:21]\\\": 24, \\\"<START>[UNIMOD:1]\\\": 25, \\\"T[UNIMOD:21]\\\": 26, \\\"K[UNIMOD:121]\\\": 27, \\\"K[UNIMOD:34]\\\": 28, \\\"K[UNIMOD:1848]\\\": 29, \\\"K[UNIMOD:64]\\\": 30, \\\"K[UNIMOD:37]\\\": 31, \\\"K[UNIMOD:36]\\\": 32, \\\"K[UNIMOD:1289]\\\": 33, \\\"R[UNIMOD:7]\\\": 34, \\\"K[UNIMOD:1849]\\\": 35, \\\"K[UNIMOD:1363]\\\": 36, \\\"K[UNIMOD:122]\\\": 37, \\\"R[UNIMOD:36]\\\": 38, \\\"K[UNIMOD:747]\\\": 39, \\\"K[UNIMOD:1]\\\": 40, \\\"R[UNIMOD:34]\\\": 41, \\\"K[UNIMOD:3]\\\": 42, \\\"Y[UNIMOD:21]\\\": 43, \\\"Y[UNIMOD:354]\\\": 44, \\\"K[UNIMOD:58]\\\": 45, \\\"C[UNIMOD:312]\\\": 46, \\\"S[UNIMOD:43]\\\": 47, \\\"T[UNIMOD:43]\\\": 
48}\"}}"
|