lattifai 0.1.5__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +12 -47
- lattifai/bin/align.py +26 -2
- lattifai/bin/cli_base.py +5 -0
- lattifai/client.py +26 -13
- lattifai/io/reader.py +1 -2
- lattifai/tokenizer/tokenizer.py +284 -0
- lattifai/workers/lattice1_alpha.py +33 -11
- lattifai-0.2.2.dist-info/METADATA +333 -0
- lattifai-0.2.2.dist-info/RECORD +22 -0
- lattifai/tokenizers/tokenizer.py +0 -147
- lattifai-0.1.5.dist-info/METADATA +0 -444
- lattifai-0.1.5.dist-info/RECORD +0 -24
- scripts/__init__.py +0 -1
- scripts/install_k2.py +0 -520
- /lattifai/{tokenizers → tokenizer}/__init__.py +0 -0
- /lattifai/{tokenizers → tokenizer}/phonemizer.py +0 -0
- {lattifai-0.1.5.dist-info → lattifai-0.2.2.dist-info}/WHEEL +0 -0
- {lattifai-0.1.5.dist-info → lattifai-0.2.2.dist-info}/entry_points.txt +0 -0
- {lattifai-0.1.5.dist-info → lattifai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.1.5.dist-info → lattifai-0.2.2.dist-info}/top_level.txt +0 -0
scripts/install_k2.py
DELETED
|
@@ -1,520 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
|
|
4
|
-
"""
|
|
5
|
-
Auto-install the latest k2 wheel that matches the current machine.
|
|
6
|
-
- Prints in English.
|
|
7
|
-
- Sources:
|
|
8
|
-
Linux CUDA wheels: https://k2-fsa.github.io/k2/installation/pre-compiled-cuda-wheels-linux/index.html
|
|
9
|
-
macOS CPU wheels: https://k2-fsa.github.io/k2/installation/pre-compiled-cpu-wheels-macos/index.html
|
|
10
|
-
Windows CPU wheels: https://k2-fsa.github.io/k2/installation/pre-compiled-cpu-wheels-windows/index.html
|
|
11
|
-
|
|
12
|
-
Usage:
|
|
13
|
-
  python install_k2.py            # install immediately
|
|
14
|
-
  python install_k2.py --dry-run  # only show what would be installed
|
|
15
|
-
"""
|
|
16
|
-
|
|
17
|
-
import argparse
|
|
18
|
-
import os
|
|
19
|
-
import platform
|
|
20
|
-
import re
|
|
21
|
-
import subprocess
|
|
22
|
-
import sys
|
|
23
|
-
import urllib.request
|
|
24
|
-
from html.parser import HTMLParser
|
|
25
|
-
from typing import List, Optional, Tuple
|
|
26
|
-
|
|
27
|
-
# Process-wide environment tweaks, applied at import time so they are in place
# before the detection helpers import torch.
# NOTE(review): KMP_DUPLICATE_LIB_OK presumably works around duplicate OpenMP
# runtimes being loaded; confirm it is still needed.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'TRUE', 'OMP_NUM_THREADS': '4'})

# Index pages listing pre-built k2 wheels, one per supported platform family.
CUDA_LINUX_URL = 'https://k2-fsa.github.io/k2/installation/pre-compiled-cuda-wheels-linux/index.html'
MAC_CPU_URL = 'https://k2-fsa.github.io/k2/installation/pre-compiled-cpu-wheels-macos/index.html'
WIN_CPU_URL = 'https://k2-fsa.github.io/k2/installation/pre-compiled-cpu-wheels-windows/index.html'
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class WheelLinkParser(HTMLParser):
    """
    Collect the ``href`` targets of ``<a>`` tags from an HTML page.

    parse_mode='wheels'   keeps only links ending in ``.whl``.
    parse_mode='versions' keeps only bare ``X.Y.Z.html`` links (per-torch-version pages).
    Any other mode keeps nothing. Matches accumulate in ``self.links`` in document order.
    """

    # Bare "2.8.0.html"-style link to a per-torch-version page.
    _VERSION_PAGE_RE = re.compile(r'^\d+\.\d+\.\d+\.html$')

    def __init__(self, parse_mode='wheels'):
        super().__init__()
        self.links: List[str] = []
        self.parse_mode = parse_mode  # 'wheels' or 'versions'

    def handle_starttag(self, tag, attrs):
        if tag.lower() != 'a':
            return
        target = dict(attrs).get('href')
        if not target:
            return

        if self.parse_mode == 'wheels':
            wanted = target.endswith('.whl')
        elif self.parse_mode == 'versions':
            wanted = self._VERSION_PAGE_RE.match(target) is not None
        else:
            wanted = False

        if wanted:
            self.links.append(target)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
# Bare per-torch-version page link such as "2.8.0.html".
_VERSION_PAGE_PATTERN = re.compile(r'^(\d+)\.(\d+)\.(\d+)\.html$')


def _version_page_key(link: str) -> Tuple[int, ...]:
    """Sort key turning '2.8.0.html' into (2, 8, 0); unparseable links map to (0, 0, 0)."""
    match = _VERSION_PAGE_PATTERN.match(link)
    if match:
        return tuple(int(part) for part in match.groups())
    return (0, 0, 0)


def _pick_version_page(version_links: List[str], target_torch_version: Optional[str]) -> Optional[str]:
    """Choose the per-torch-version page to scrape; None if the requested version is absent."""
    if target_torch_version:
        target_filename = f'{target_torch_version}.html'
        if target_filename not in version_links:
            print(f'[WARN] No page found for torch version {target_torch_version}')
            return None
        print(f'[INFO] Found torch version {target_torch_version}, fetching wheels from {target_filename}')
        return target_filename

    # No target given: use the highest torch version listed on the index.
    latest_version = max(version_links, key=_version_page_key)
    version_str = latest_version.replace('.html', '')
    print(f'[INFO] No target torch version specified, using latest version: {version_str}')
    return latest_version


def _select_wheel_link(
    wheel_links: List[str], py_tag: str, abi_tag: str, cuda_version: Optional[str], version_str: str
) -> Optional[str]:
    """Pick the newest wheel for this Python, preferring *cuda_version* when given; None if none fit."""
    matching_wheels = [link for link in wheel_links if py_tag in link and abi_tag in link]
    if not matching_wheels:
        print(f'[WARN] No wheel found for Python {py_tag} in torch {version_str}')
        return None

    def newest(wheels: List[str]) -> str:
        # Newest dev build wins; wheels without a dev date sort as 0.
        return max(wheels, key=lambda wheel: parse_devdate(wheel) or 0)

    if cuda_version:
        cuda_specific_wheels = [w for w in matching_wheels if parse_cuda_from_filename(w) == cuda_version]
        if cuda_specific_wheels:
            best_wheel = newest(cuda_specific_wheels)
            print(f'[INFO] Found matching wheel for Python {py_tag} and CUDA {cuda_version}: {best_wheel}')
            return best_wheel
        print(f'[WARN] No wheel found for CUDA {cuda_version}, falling back to latest version')

    best_wheel = newest(matching_wheels)
    wheel_cuda = parse_cuda_from_filename(best_wheel)
    cuda_info = f' (CUDA {wheel_cuda})' if wheel_cuda else ''
    print(f'[INFO] Found matching wheel for Python {py_tag}{cuda_info}: {best_wheel}')
    return best_wheel


def fetch_wheel_links(
    page_url: str, target_torch_version: Optional[str] = None, cuda_version: Optional[str] = None
) -> List[str]:
    """
    Find the k2 wheel matching the running Python on a k2 wheel index.

    The page structure is two-level:
      - the index page (*page_url*) links to per-torch-version pages (e.g. 2.8.0.html);
      - each version page links to the actual .whl files.

    Args:
        page_url: URL of the index page.
        target_torch_version: If given (e.g. "2.8.0"), only that torch version's page
            is searched; otherwise the highest listed torch version is used.
        cuda_version: If given (e.g. "12.1"), wheels built for exactly this CUDA
            version are preferred; otherwise the newest Python-matching wheel wins.

    Returns:
        A list with the single best wheel URL, resolved against the version page so it
        is absolute even if the page uses relative links, or an empty list when the
        requested torch page is missing, cannot be fetched, or has no wheel for this Python.

    Raises:
        ValueError: if the index page contains no version links.
        Errors from fetching the index page itself propagate to the caller.
    """
    from urllib.parse import urljoin

    with urllib.request.urlopen(page_url) as resp:
        html = resp.read().decode('utf-8', errors='ignore')

    version_parser = WheelLinkParser(parse_mode='versions')
    version_parser.feed(html)
    if not version_parser.links:
        raise ValueError('No version links found on the page; unexpected page structure.')

    version_link = _pick_version_page(version_parser.links, target_torch_version)
    if version_link is None:
        return []

    base_url = page_url.rsplit('/', 1)[0]
    version_url = f'{base_url}/{version_link}'
    try:
        with urllib.request.urlopen(version_url) as resp:
            version_html = resp.read().decode('utf-8', errors='ignore')
    except Exception as e:
        # Best effort: an unreachable version page means "no wheel here",
        # letting the caller fall back to other torch/CUDA combinations.
        print(f'[WARN] Failed to fetch {version_url}: {e}')
        return []

    wheel_parser = WheelLinkParser(parse_mode='wheels')
    wheel_parser.feed(version_html)

    py_tag, abi_tag = py_tags()
    best_wheel = _select_wheel_link(
        wheel_parser.links, py_tag, abi_tag, cuda_version, version_link.replace('.html', '')
    )
    if best_wheel is None:
        return []
    # Absolute hrefs are returned unchanged; relative ones are resolved against the page.
    return [urljoin(version_url, best_wheel)]
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
def py_tags() -> Tuple[str, str]:
    """
    Return the (python tag, ABI tag) pair that k2 wheel filenames use.

    Both are 'cp<major><minor>' for the running interpreter, e.g. ('cp310', 'cp310').
    The CPython-style tag is produced even on other implementations, because
    the published wheels are CPython builds.
    """
    running = sys.version_info
    cpython_tag = f'cp{running.major}{running.minor}'
    return cpython_tag, cpython_tag
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
def detect_torch_version() -> Optional[str]:
    """
    Return the installed PyTorch release as 'MAJOR.MINOR.PATCH' (e.g. '2.8.0').

    Local-version suffixes such as '+cu118' are dropped. Returns None when torch
    cannot be imported, has no usable __version__, or the version lacks an X.Y.Z prefix.
    """
    try:
        import importlib

        torch_module = importlib.import_module('torch')
        raw_version = getattr(torch_module, '__version__', None)
        if not raw_version:
            return None
        release = re.match(r'(\d+\.\d+\.\d+)', str(raw_version))
        return release.group(1) if release else None
    except Exception:
        # Any failure while importing/inspecting torch means "not detected".
        return None
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
def detect_cuda_version_linux() -> Optional[str]:
    """
    Detect the CUDA version PyTorch was built against, e.g. '12.1'.

    Only ``torch.version.cuda`` is consulted; the GPU driver (e.g. nvidia-smi) is
    not queried. Returns None when torch cannot be imported or reports no CUDA
    version (falsy ``torch.version.cuda``).
    """
    try:
        import importlib

        torch = importlib.import_module('torch')
    except Exception:
        # Importing torch can fail with more than ImportError (e.g. broken native
        # libraries); any failure simply means "CUDA not detected".
        return None

    cuda = getattr(getattr(torch, 'version', None), 'cuda', None)
    return str(cuda) if cuda else None
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
def parse_cuda_from_filename(name: str) -> Optional[str]:
    """
    Extract the CUDA version following 'cuda' in a wheel name.

    '...+cuda12.1.torch2.3.1-...' -> '12.1'; '...+cuda11-...' -> '11'.
    Returns None when no 'cuda<digits>' token is present.
    """
    found = re.search(r'cuda(\d+(?:\.\d+)?)', name)
    if found is None:
        return None
    return found.group(1)
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
def parse_devdate(name: str) -> Optional[int]:
    """Return the 8-digit dev build date as an int ('dev20240606' -> 20240606), or None."""
    found = re.search(r'dev(\d{8})', name)
    return None if found is None else int(found.group(1))
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
def parse_version_tuple(name: str) -> Tuple[int, ...]:
    """
    Return the numeric k2 version after 'k2-' as a tuple of ints.

    Only the first dotted run of digits counts: 'k2-1.24.4.dev2024...' -> (1, 24, 4).
    Returns an empty tuple when no 'k2-<digits>' token is present.
    """
    found = re.search(r'k2-([\d]+(?:\.[\d]+)*)', name)
    if found is None:
        return ()
    return tuple(map(int, found.group(1).split('.')))
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
def best_match_cuda(candidates: List[str], installed_cuda: Optional[str]) -> List[str]:
    """
    Narrow *candidates* to CUDA wheels sharing the single best CUDA version.

    Wheels without 'cuda' in their name are discarded. With *installed_cuda*
    (e.g. '12.1'), the CUDA version closest to it wins. Ties are broken by
    same major version, then by not exceeding the installed version. Without it,
    the highest CUDA version wins. The returned wheels keep their ranked order,
    and the result is empty when no CUDA wheel is present.
    """
    cuda_wheels = [w for w in candidates if 'cuda' in w.lower()]
    if not cuda_wheels:
        return []

    def as_major_minor(version: str) -> Tuple[int, int]:
        # '12.1' -> (12, 1); '12' -> (12, 0)
        pieces = version.split('.')
        return int(pieces[0]), (int(pieces[1]) if len(pieces) > 1 else 0)

    def wheel_cuda(wheel: str) -> Tuple[int, int]:
        return as_major_minor(parse_cuda_from_filename(wheel) or '0')

    if installed_cuda:
        target = as_major_minor(installed_cuda)

        def closeness(wheel: str) -> Tuple[int, int, int]:
            got = wheel_cuda(wheel)
            distance = abs(got[0] - target[0]) * 100 + abs(got[1] - target[1])
            other_major = 0 if got[0] == target[0] else 1
            exceeds_target = 0 if got <= target else 1
            return (distance, other_major, exceeds_target)

        ranked = sorted(cuda_wheels, key=closeness)
    else:
        ranked = sorted(cuda_wheels, key=wheel_cuda, reverse=True)

    winning_cuda = parse_cuda_from_filename(ranked[0])
    return [w for w in ranked if parse_cuda_from_filename(w) == winning_cuda]
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
def platform_tag_filters() -> List[str]:
    """
    Return substrings identifying wheel platform tags compatible with this machine.

    A wheel filename is considered compatible when it contains any of the returned
    substrings. Unknown operating systems get an empty list, so nothing matches.
    """
    system = platform.system().lower()
    machine = platform.machine().lower()

    if system == 'linux':
        arch = 'aarch64' if ('aarch64' in machine or 'arm64' in machine) else 'x86_64'
        # Real Linux wheels carry legacy ('manylinux2014_x86_64') or PEP 600
        # ('manylinux_2_17_x86_64') tags, which contain neither 'linux_<arch>'
        # nor 'manylinux_<arch>', so list them explicitly as well.
        return [
            f'linux_{arch}',
            f'manylinux_{arch}',
            f'manylinux2014_{arch}',
            f'manylinux_2_17_{arch}',
            f'manylinux_2_28_{arch}',
        ]

    if system == 'darwin':
        if 'arm64' in machine or 'aarch64' in machine:
            # Apple Silicon wheels start at macOS 11.
            return [f'macosx_{major}_0_arm64' for major in range(11, 16)]
        # Intel macs
        return ['macosx_10_9_x86_64', 'macosx_10_15_x86_64'] + [
            f'macosx_{major}_0_x86_64' for major in range(11, 16)
        ]

    if system == 'windows':
        if 'arm64' in machine:
            # If k2 provides win_arm64 wheels in future, this will catch them.
            return ['win_arm64']
        return ['win_amd64']

    return []
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
def choose_best_wheel(links: List[str], require_cuda: bool) -> Optional[str]:
    """
    Choose the most suitable wheel URL from *links* for this interpreter and platform.

    A wheel qualifies when its name contains this interpreter's python and ABI
    tags and one of the platform substrings. With *require_cuda*, only the
    wheels whose CUDA version best matches the detected one remain. Otherwise
    CUDA wheels are excluded. Among the survivors, the winner has the newest dev
    date, then the highest k2 version, then the lexicographically largest URL.
    Returns None when nothing qualifies.
    """
    py_tag, abi_tag = py_tags()
    plat_filters = platform_tag_filters()

    pool = [
        link
        for link in links
        if py_tag in link and abi_tag in link and any(tag in link for tag in plat_filters)
    ]
    if not pool:
        return None

    if require_cuda:
        pool = best_match_cuda(pool, detect_cuda_version_linux())
    else:
        # CPU install: drop anything built against CUDA.
        pool = [link for link in pool if 'cuda' not in link.lower()]

    if not pool:
        return None
    return max(pool, key=lambda link: (parse_devdate(link) or 0, parse_version_tuple(link), link))
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
def run_pip_install(wheel_url: str, dry_run: bool):
    """
    Install *wheel_url* with the current interpreter's pip, or only print the command.

    The pip command line is always printed. With *dry_run* nothing is executed.
    If pip fails, the process exits with pip's return code.
    """
    pip_cmd = [sys.executable, '-m', 'pip', 'install', '--upgrade', '--no-cache-dir']
    pip_cmd.append(wheel_url)
    print('[INFO] Pip command:', ' '.join(pip_cmd))

    if not dry_run:
        try:
            subprocess.check_call(pip_cmd)
        except subprocess.CalledProcessError as err:
            print('[ERROR] pip install failed with exit code:', err.returncode)
            sys.exit(err.returncode)
        else:
            print('[SUCCESS] k2 has been installed successfully.')
        return

    print('[DRY-RUN] Skipping actual installation.')
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
def _with_fallback(preferred: Optional[str]) -> List[Optional[str]]:
    """Return [preferred, None] so a lookup is retried without the preference, or [None] when unset."""
    return [preferred, None] if preferred else [None]


def _install_or_exit(wheel: Optional[str], description: str, dry_run: bool) -> None:
    """Install *wheel*, or print an error and exit with status 1 when no wheel was found."""
    if not wheel:
        print(f'[ERROR] Could not find a suitable {description} wheel for your Python/platform.')
        sys.exit(1)
    print(f'[INFO] Selected wheel:\n {wheel}')
    run_pip_install(wheel, dry_run)


def _find_linux_cuda_wheel(torch_version: Optional[str], cuda_version: Optional[str]) -> Optional[str]:
    """Search the Linux CUDA index, relaxing the CUDA and then the torch preference; None if nothing fits."""
    for _torch_version in _with_fallback(torch_version):
        for _cuda_version in _with_fallback(cuda_version):
            links = fetch_wheel_links(CUDA_LINUX_URL, _torch_version, cuda_version=_cuda_version)
            if _torch_version and links:
                # fetch_wheel_links already matched Python, torch and (preferably) CUDA.
                wheel = links[0]
            else:
                # Fallback to traditional selection on the latest torch page.
                if not links:
                    links = fetch_wheel_links(CUDA_LINUX_URL)
                wheel = choose_best_wheel(links, require_cuda=_cuda_version is not None)
                if not _torch_version and links and not wheel:
                    wheel = links[0]  # Pick first available as last resort
            if wheel:
                return wheel
            if _cuda_version:
                print(f'[WARN] No suitable wheel found for CUDA {_cuda_version}, trying without CUDA preference...')
        if _torch_version:
            print(
                f'[WARN] Tried torch version {_torch_version}, but not found wheel, trying without torch version...'
            )
    return None


def _find_cpu_wheel(page_url: str, torch_version: Optional[str]) -> Optional[str]:
    """Search a CPU wheel index, retrying without the torch version; None if nothing fits."""
    for _torch_version in _with_fallback(torch_version):
        links = fetch_wheel_links(page_url, _torch_version)
        if _torch_version and links:
            # Wheel built for the installed torch version.
            return links[0]
        # Fallback to traditional selection on the latest torch page.
        if not links:
            links = fetch_wheel_links(page_url)
        wheel = choose_best_wheel(links, require_cuda=False)
        if links and not wheel:
            wheel = links[0]  # Pick first available as last resort
        if wheel:
            # Stop at the first hit so a later torch-agnostic retry cannot
            # replace a better-matching wheel.
            return wheel
    return None


def install_k2_main(dry_run: bool = False):
    """
    Detect the platform, find the best pre-built k2 wheel and install it with pip.

    Suitable for programmatic use (no argparse). Linux uses the CUDA wheel index,
    preferring the installed torch and CUDA versions. macOS and Windows use the
    CPU indexes, preferring the installed torch version. With *dry_run* the
    chosen pip command is only printed.

    Exits with status 1 when no suitable wheel is found, 3 on an unsupported OS,
    or pip's return code if installation fails.
    """
    system = platform.system().lower()
    print(f'[INFO] Detected OS: {system}')
    print(f'[INFO] Python: {platform.python_version()} | Impl: {platform.python_implementation()}')

    # Check if torch is already installed
    torch_version = detect_torch_version()
    if torch_version:
        print(f'[INFO] Detected PyTorch version: {torch_version}')
    else:
        print('[INFO] PyTorch not detected, will search all available versions')

    if system == 'linux':
        print('[INFO] Target: Linux (CUDA wheels)')
        cuda_version = detect_cuda_version_linux()
        if cuda_version:
            print(f'[INFO] Detected CUDA version: {cuda_version}')
        else:
            print('[WARN] No CUDA detected on Linux.')
        _install_or_exit(_find_linux_cuda_wheel(torch_version, cuda_version), 'Linux CUDA', dry_run)
    elif system == 'darwin':
        print('[INFO] Target: macOS (CPU wheels)')
        _install_or_exit(_find_cpu_wheel(MAC_CPU_URL, torch_version), 'macOS CPU', dry_run)
    elif system == 'windows':
        print('[INFO] Target: Windows (CPU wheels)')
        _install_or_exit(_find_cpu_wheel(WIN_CPU_URL, torch_version), 'Windows CPU', dry_run)
    else:
        print(f'[ERROR] Unsupported OS: {system}')
        sys.exit(3)
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
def install_k2():
    """Command-line entry point: parse ``--dry-run`` and run the installer."""
    cli = argparse.ArgumentParser(description='Auto-install the latest k2 wheel for your environment.')
    cli.add_argument(
        '--dry-run',
        action='store_true',
        help='Show what would be installed without making changes.',
    )
    options = cli.parse_args()
    install_k2_main(dry_run=options.dry_run)


if __name__ == '__main__':
    install_k2()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|