lemonade-sdk 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lemonade/__init__.py +5 -0
- lemonade/api.py +180 -0
- lemonade/cache.py +92 -0
- lemonade/cli.py +173 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/build.py +176 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/inference_engines.py +408 -0
- lemonade/common/network.py +93 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +471 -0
- lemonade/common/system_info.py +1411 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/agt_power.py +437 -0
- lemonade/profilers/hwinfo_power.py +429 -0
- lemonade/profilers/memory_tracker.py +259 -0
- lemonade/profilers/profiler.py +58 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/accuracy.py +432 -0
- lemonade/tools/adapter.py +114 -0
- lemonade/tools/bench.py +302 -0
- lemonade/tools/flm/__init__.py +1 -0
- lemonade/tools/flm/utils.py +305 -0
- lemonade/tools/huggingface/bench.py +187 -0
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/huggingface/utils.py +359 -0
- lemonade/tools/humaneval.py +264 -0
- lemonade/tools/llamacpp/bench.py +255 -0
- lemonade/tools/llamacpp/load.py +222 -0
- lemonade/tools/llamacpp/utils.py +1260 -0
- lemonade/tools/management_tools.py +319 -0
- lemonade/tools/mmlu.py +319 -0
- lemonade/tools/oga/__init__.py +0 -0
- lemonade/tools/oga/bench.py +120 -0
- lemonade/tools/oga/load.py +804 -0
- lemonade/tools/oga/migration.py +403 -0
- lemonade/tools/oga/utils.py +462 -0
- lemonade/tools/perplexity.py +147 -0
- lemonade/tools/prompt.py +263 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +899 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/flm.py +133 -0
- lemonade/tools/server/llamacpp.py +320 -0
- lemonade/tools/server/serve.py +2123 -0
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/index.html +279 -0
- lemonade/tools/server/static/js/chat.js +1059 -0
- lemonade/tools/server/static/js/model-settings.js +183 -0
- lemonade/tools/server/static/js/models.js +1395 -0
- lemonade/tools/server/static/js/shared.js +556 -0
- lemonade/tools/server/static/logs.html +191 -0
- lemonade/tools/server/static/styles.css +2654 -0
- lemonade/tools/server/static/webapp.html +321 -0
- lemonade/tools/server/tool_calls.py +153 -0
- lemonade/tools/server/tray.py +664 -0
- lemonade/tools/server/utils/macos_tray.py +226 -0
- lemonade/tools/server/utils/port.py +77 -0
- lemonade/tools/server/utils/thread.py +85 -0
- lemonade/tools/server/utils/windows_tray.py +408 -0
- lemonade/tools/server/webapp.py +34 -0
- lemonade/tools/server/wrapped_server.py +559 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +239 -0
- lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
- lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
- lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
- lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
- lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
- lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +805 -0
- lemonade_server/model_manager.py +758 -0
- lemonade_server/pydantic_models.py +159 -0
- lemonade_server/server_models.json +643 -0
- lemonade_server/settings.py +39 -0
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import importlib.util
|
|
4
|
+
import importlib.metadata
|
|
5
|
+
import subprocess
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Dict, Optional
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class InferenceEngineDetector:
    """
    Aggregates the per-engine detectors and reports which inference engines
    are usable on a given device.
    """

    def __init__(self):
        # One detector instance per supported engine family.
        self.oga_detector = OGADetector()
        self.llamacpp_detector = LlamaCppDetector()
        self.transformers_detector = TransformersDetector()

    def detect_engines_for_device(
        self, device_type: str, device_name: str
    ) -> Dict[str, Dict]:
        """
        Collect availability information for every known inference engine.

        Args:
            device_type: "cpu", "amd_igpu", "amd_dgpu", "nvidia_dgpu", or "npu"
            device_name: Device name, forwarded to the llama.cpp detector
                (which uses it when checking ROCm support).

        Returns:
            dict: Maps an engine key ("oga", "llamacpp-vulkan",
            "llamacpp-rocm", "transformers") to that engine's availability
            info. Engines whose detector returned a falsy result are omitted.
        """
        found: Dict[str, Dict] = {}

        oga_result = self.oga_detector.detect_for_device(device_type)
        if oga_result:
            found["oga"] = oga_result

        # llama.cpp is probed once per backend: Vulkan first, then ROCm.
        for llama_backend in ("vulkan", "rocm"):
            llama_result = self.llamacpp_detector.detect_for_device(
                device_type, device_name, llama_backend
            )
            if llama_result:
                found[f"llamacpp-{llama_backend}"] = llama_result

        hf_result = self.transformers_detector.detect_for_device(device_type)
        if hf_result:
            found["transformers"] = hf_result

        return found
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class BaseEngineDetector(ABC):
    """
    Abstract interface shared by the engine-specific detectors.

    NOTE(review): not every subclass in this module keeps these exact
    signatures (the llama.cpp detector takes extra ``device_name`` and
    ``backend`` parameters), so treat each detector's own signature as
    authoritative.
    """

    @abstractmethod
    def is_installed(self) -> bool:
        """
        Report whether the engine's package or binary is present.
        """

    @abstractmethod
    def detect_for_device(self, device_type: str) -> Optional[Dict]:
        """
        Report engine availability for the given device type.
        """
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class OGADetector(BaseEngineDetector):
    """
    Detector for ONNX Runtime GenAI (OGA).
    """

    def detect_for_device(self, device_type: str) -> Optional[Dict]:
        """
        Detect OGA availability for a specific device type.

        Args:
            device_type: "cpu", "amd_igpu", "amd_dgpu", "nvidia_dgpu", or "npu".

        Returns:
            None when OGA is not installed (non-NPU devices) or when the
            device type has no OGA backend mapping. Otherwise a dict with
            ``"available"`` plus either ``"version"``, ``"backend"`` and
            ``"details"`` (success) or ``"error"`` (failure).
        """
        # Check package installation based on device type
        if device_type == "npu":
            if not self.is_npu_package_installed():
                return {
                    "available": False,
                    "error": "NPU packages not installed (need "
                    "onnxruntime-genai-directml-ryzenai or onnxruntime-vitisai)",
                }
        elif not self.is_installed():
            # For other devices, check general OGA installation
            return None

        # Map device types to OGA backends
        device_provider_map = {
            "cpu": "cpu",
            "amd_igpu": "dml",
            "amd_dgpu": "dml",
            "npu": "vitisai",
        }
        backend = device_provider_map.get(device_type)
        if backend is None:
            # Decided before importing onnxruntime, so a missing onnxruntime
            # is not reported as an error for a device OGA never targets.
            return None

        # Map backends to ORT execution-provider names
        provider_map = {
            "cpu": "CPUExecutionProvider",
            "dml": "DmlExecutionProvider",
            "vitisai": "VitisAIExecutionProvider",
        }
        required_provider = provider_map[backend]

        try:
            import onnxruntime as ort

            available_providers = ort.get_available_providers()

            if required_provider not in available_providers:
                if device_type == "npu":
                    error_msg = (
                        "VitisAI provider not available - "
                        "check AMD NPU driver installation"
                    )
                else:
                    error_msg = f"{backend.upper()} provider not available"
                return {
                    "available": False,
                    "error": error_msg,
                }

            return {
                "available": True,
                "version": self._get_oga_version(device_type),
                "backend": backend,
                # Add dependency versions in details
                "details": {"dependency_versions": {"onnxruntime": ort.__version__}},
            }

        except (ImportError, AttributeError) as e:
            return {"available": False, "error": f"OGA detection failed: {str(e)}"}

    def is_installed(self) -> bool:
        """
        Check if OGA is installed (the ``onnxruntime_genai`` module is importable).
        """
        return importlib.util.find_spec("onnxruntime_genai") is not None

    def is_npu_package_installed(self) -> bool:
        """
        Check if NPU-specific OGA packages are installed.

        Distribution names are compared after normalization (lower-case, with
        ``_`` and ``.`` treated as ``-``), because installed metadata may spell
        a project name differently from its PyPI name. Distributions whose
        metadata has no name are skipped rather than aborting the scan.
        """
        npu_packages = {
            self._normalize_dist_name(name)
            for name in ("onnxruntime-genai-directml-ryzenai", "onnxruntime-vitisai")
        }
        try:
            for dist in importlib.metadata.distributions():
                dist_name = dist.metadata.get("Name")
                if dist_name and self._normalize_dist_name(dist_name) in npu_packages:
                    return True
        except (ImportError, AttributeError):
            return False
        return False

    @staticmethod
    def _normalize_dist_name(name: str) -> str:
        """Approximate PEP 503 name normalization: lower-case, '_'/'.' -> '-'."""
        return name.lower().replace("_", "-").replace(".", "-")

    def _get_oga_version(self, device_type: str) -> str:
        """
        Get the OGA version, or "unknown" if it cannot be determined.

        For NPU devices the NPU-specific modules are tried first. Every device
        falls back to ``onnxruntime_genai``. A candidate that fails to import,
        or imports without a ``__version__``, moves on to the next candidate.
        """
        candidates = ["onnxruntime_genai"]
        if device_type == "npu":
            candidates = [
                "onnxruntime_genai_directml_ryzenai",
                "onnxruntime_vitisai",
            ] + candidates

        for module_name in candidates:
            try:
                return importlib.import_module(module_name).__version__
            except (ImportError, AttributeError):
                continue
        return "unknown"
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class LlamaCppDetector(BaseEngineDetector):
    """
    Detector for llama.cpp (lemonade-managed llama-server builds).
    """

    def detect_for_device(
        self, device_type: str, device_name: str, backend: str
    ) -> Optional[Dict]:
        """
        Detect llama.cpp availability for a specific device and backend.

        Args:
            device_type: "cpu", "amd_igpu", "amd_dgpu", "nvidia_dgpu", or "npu".
            device_name: Device name, used to identify a ROCm architecture.
                If it is None or empty, ROCm is reported as unsupported.
            backend: llama.cpp backend to check, e.g. "vulkan" or "rocm".

        Returns:
            None for device types llama.cpp does not target (e.g. "npu").
            Otherwise a dict with "available" plus "version"/"backend" on
            success, or "error" on failure.
        """
        if device_type not in ("cpu", "amd_igpu", "amd_dgpu", "nvidia_dgpu"):
            return None

        try:
            if not self._device_supports_backend(device_type, device_name, backend):
                return {"available": False, "error": f"{backend} not available"}

            if not self.is_installed(backend):
                return {
                    "available": False,
                    "error": f"{backend} binaries not installed",
                }

            return {
                "available": True,
                "version": self._get_llamacpp_version(backend),
                "backend": backend,
            }

        except (ImportError, OSError, subprocess.SubprocessError) as e:
            return {
                "available": False,
                "error": f"llama.cpp detection failed: {str(e)}",
            }

    def _device_supports_backend(
        self, device_type: str, device_name: Optional[str], backend: str
    ) -> bool:
        """Return True if ``backend`` can target ``device_type`` on this system."""
        if device_type == "cpu":
            # CPU is treated as supported by every backend.
            return True
        if backend == "vulkan":
            # Vulkan is checked for both AMD and NVIDIA GPUs.
            return self._check_vulkan_support()
        if backend == "rocm" and device_type in ("amd_igpu", "amd_dgpu"):
            # ROCm support is decided from the device name; no name, no ROCm.
            return bool(device_name) and self._check_rocm_support(device_name.lower())
        return False

    def is_installed(self, backend: str) -> bool:
        """
        Check if the llama-server binary for ``backend`` exists on disk.

        Raises:
            ImportError: if lemonade's llama.cpp utilities cannot be imported.
        """
        from lemonade.tools.llamacpp.utils import get_llama_server_exe_path

        try:
            server_exe_path = get_llama_server_exe_path(backend)
        except (ImportError, OSError, ValueError):
            return False
        return os.path.exists(server_exe_path)

    def _check_vulkan_support(self) -> bool:
        """
        Check if Vulkan is available for GPU acceleration.

        Runs ``vulkaninfo --summary`` and trusts its exit code. If the tool
        cannot be run at all, falls back to looking for the Vulkan loader
        library at common Windows and Linux locations.
        """
        try:
            # Only the exit code matters, so output is discarded (no decoding).
            result = subprocess.run(
                ["vulkaninfo", "--summary"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=10,
                check=False,
            )
            return result.returncode == 0
        except (subprocess.SubprocessError, OSError):
            # SubprocessError covers TimeoutExpired; OSError covers a missing
            # binary (FileNotFoundError) as well as e.g. PermissionError.
            pass

        loader_paths = (
            # Vulkan loader DLL on Windows
            "C:\\Windows\\System32\\vulkan-1.dll",
            "C:\\Windows\\SysWOW64\\vulkan-1.dll",
            # Vulkan loader library on Linux
            "/usr/lib/x86_64-linux-gnu/libvulkan.so.1",
            "/usr/lib/libvulkan.so.1",
            "/lib/x86_64-linux-gnu/libvulkan.so.1",
        )
        return any(os.path.exists(path) for path in loader_paths)

    def _check_rocm_support(self, device_name: str) -> bool:
        """
        Check if ROCm is available for GPU acceleration (a ROCm architecture
        can be identified from ``device_name``).
        """
        from lemonade.tools.llamacpp.utils import identify_rocm_arch_from_name

        return identify_rocm_arch_from_name(device_name) is not None

    def _get_llamacpp_version(self, backend: str) -> str:
        """
        Get llama.cpp version from lemonade's managed installation for specific backend.

        Reads ``<sys.prefix>/<backend>/llama_server/version.txt``. Returns
        "unknown" if the file is missing, unreadable, not valid UTF-8, or empty.
        """
        # Use backend-specific path - same logic as get_llama_folder_path in utils.py
        # Uses sys.prefix to get the environment root (works for both venv and conda)
        version_file = os.path.join(sys.prefix, backend, "llama_server", "version.txt")
        try:
            with open(version_file, "r", encoding="utf-8") as f:
                version = f.read().strip()
        except (OSError, UnicodeDecodeError):
            return "unknown"
        return version or "unknown"
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class TransformersDetector(BaseEngineDetector):
    """
    Detector for Hugging Face Transformers running on PyTorch.
    """

    def detect_for_device(self, device_type: str) -> Optional[Dict]:
        """
        Detect Transformers availability for a specific device type.

        Only the CPU backend is reported. Any other device type returns None
        without importing torch/transformers, since those imports are slow.

        Returns:
            None for non-CPU devices or when either package is not installed.
            Otherwise a dict with "available" plus "version"/"backend"/"details"
            on success, or "error" if importing or reading versions failed.
        """
        if device_type != "cpu":
            return None

        if not self.is_installed():
            return None

        try:
            import torch
            import transformers

            return {
                "available": True,
                "version": transformers.__version__,
                "backend": "cpu",
                # Add dependency versions in details
                "details": {"dependency_versions": {"torch": torch.__version__}},
            }

        except (ImportError, AttributeError) as e:
            return {
                "available": False,
                "error": f"Transformers detection failed: {str(e)}",
            }

    def is_installed(self) -> bool:
        """
        Check if Transformers and PyTorch are installed.
        """
        return (
            importlib.util.find_spec("transformers") is not None
            and importlib.util.find_spec("torch") is not None
        )
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def detect_inference_engines(device_type: str, device_name: str) -> Dict[str, Dict]:
    """
    Convenience wrapper that builds a fresh InferenceEngineDetector and
    returns its results for one device.

    Args:
        device_type: "cpu", "amd_igpu", "amd_dgpu", "nvidia_dgpu", or "npu"
        device_name: device name

    Returns:
        dict: Engine availability information, keyed by engine name.
    """
    return InferenceEngineDetector().detect_engines_for_device(
        device_type, device_name
    )
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
import socket
|
|
4
|
+
from huggingface_hub import model_info, snapshot_download
|
|
5
|
+
from huggingface_hub.errors import LocalEntryNotFoundError
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def is_offline():
    """
    Report whether Hugging Face appears unreachable.

    The LEMONADE_OFFLINE environment variable forces offline mode when set to
    any non-empty value. Otherwise the check only tries to resolve
    ``huggingface.co`` by name; no connection is opened.

    Returns:
        bool: True if offline mode is forced or the name lookup fails,
        False otherwise.
    """
    if os.environ.get("LEMONADE_OFFLINE"):
        return True

    try:
        socket.gethostbyname("huggingface.co")
    except socket.gaierror:
        # Name resolution failed; treat the machine as offline.
        return True
    return False
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_base_model(checkpoint: str) -> Optional[str]:
    """
    Look up the base model of ``checkpoint`` via the Hugging Face Hub.

    The Hub request is skipped entirely when ``is_offline()`` returns True.

    Args:
        checkpoint: The model checkpoint to query

    Returns:
        - The model card's ``base_model`` entry when it is present and not
          None (a derived model). NOTE(review): the value comes straight from
          the card metadata, so it may be a str or a list; confirm callers
          handle both.
        - ``[checkpoint]`` (a one-element list, not a str) when the card
          declares ``base_model`` as None (the model is itself a base model).
        - None when offline, when the card has no ``base_model`` key, or when
          the Hub lookup raises for any reason.
    """
    if is_offline():
        return None

    try:
        card_data = model_info(checkpoint).cardData
        if card_data and "base_model" in card_data:
            declared_base = card_data["base_model"]
            return declared_base if declared_base is not None else [checkpoint]
    except Exception:  # pylint: disable=broad-except
        # Best effort: any Hub/network failure just means "unknown".
        pass
    return None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _symlink_safe_snapshot_download(repo_id, **kwargs):
    """
    Call ``snapshot_download``, retrying once on Windows error 1314.

    Error 1314 is the "required privilege is not held" error raised when
    symlink creation is not permitted. NOTE(review): the retry uses identical
    arguments; presumably the second attempt succeeds because the download
    library adapts after the first failure. Confirm against huggingface_hub.
    Any other OSError, or a second 1314 failure, is re-raised.
    """
    max_attempts = 2
    attempts_made = 0
    while True:
        try:
            return snapshot_download(repo_id=repo_id, **kwargs)
        except OSError as err:
            attempts_made += 1
            privilege_error = getattr(err, "winerror", None) == 1314
            if not privilege_error or attempts_made >= max_attempts:
                raise
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def custom_snapshot_download(repo_id, do_not_upgrade=False, **kwargs):
    """
    Download (or locate) a Hugging Face snapshot of ``repo_id``.

    Every call goes through the Windows-symlink-safe retry wrapper.

    Args:
        repo_id: Repository to fetch.
        do_not_upgrade: When True, an already-cached local copy is returned
            as-is (``local_files_only=True``). The remote is only consulted
            if no local copy exists (``LocalEntryNotFoundError``).
        **kwargs: Forwarded to ``snapshot_download``.
    """
    if not do_not_upgrade:
        return _symlink_safe_snapshot_download(repo_id, **kwargs)

    try:
        # First preference: whatever is already in the local cache.
        return _symlink_safe_snapshot_download(
            repo_id, local_files_only=True, **kwargs
        )
    except LocalEntryNotFoundError:
        # Nothing cached locally, so accept a remote download instead.
        return _symlink_safe_snapshot_download(
            repo_id, local_files_only=False, **kwargs
        )
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import enum
|
|
4
|
+
import sys
|
|
5
|
+
import math
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Colors:
    """ANSI escape sequences used to style terminal output."""

    # Reset and text attributes
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Bright foreground colors
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def log(txt, c=Colors.ENDC, end="", is_error=False):
    """
    Print ``txt`` in color ``c``; unlike ``logn`` it ends with "" by default.
    """
    logn(txt, c, end, is_error)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def logn(txt, c=Colors.ENDC, end="\n", is_error=False):
    """
    Print ``txt`` prefixed with color code ``c`` and followed by a reset code.

    Output goes to stderr when ``is_error`` is true, otherwise to stdout, and
    is flushed immediately.
    """
    stream = sys.stdout
    if is_error:
        stream = sys.stderr
    styled = c + txt + Colors.ENDC
    print(styled, end=end, flush=True, file=stream)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LogType(enum.Enum):
    """Category of a message printed by clean_print().

    Each member's value is the label text printed in front of the message.
    Member order is the enum's iteration order, so it is kept as-is.
    """

    ERROR = "Error:"
    SUCCESS = "Woohoo!"
    WARNING = "Warning:"
    INFO = "Info:"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def clean_print(type: LogType, msg):
    """
    Print ``msg`` behind a colored, type-specific label.

    - Occurrences of the user's home directory are shown as ``~``.
    - Leading whitespace is stripped from every line and leading blank lines
      are dropped (at least one line is always kept).
    - Continuation lines are indented to line up under the first line's text.
    - Text between ``**`` markers is printed in bold.

    ERROR messages, including their continuation-line indentation, go to
    stderr. All other types go to stdout.

    Args:
        type: The LogType, which selects the label text and color.
        msg: The message text (must be a str).
    """
    # Replace path to user's home directory by a tilde symbol (~). Skip a
    # root ("/") or empty home, which would otherwise mangle every path.
    home_directory = os.path.expanduser("~")
    if home_directory not in ("", "/"):
        msg = msg.replace(home_directory, "~")

    # Split message into list, remove leading spaces and line breaks
    lines = [line.lstrip() for line in msg.split("\n")]
    while len(lines) > 1 and lines[0] == "":
        lines.pop(0)

    is_error = type == LogType.ERROR
    indentation = len(type.value) + 1

    # Print the colored label
    label_colors = {
        LogType.ERROR: Colors.FAIL,
        LogType.SUCCESS: Colors.OKGREEN,
        LogType.WARNING: Colors.WARNING,
        LogType.INFO: Colors.OKCYAN,
    }
    label_color = label_colors.get(type)
    if label_color is not None:
        log(f"\n{type.value} ", c=label_color, is_error=is_error)

    # Print message
    for line_idx, line in enumerate(lines):
        if line_idx != 0:
            # Indent on the same stream as the text so stderr output lines up.
            log(" " * indentation, is_error=is_error)
        segments = line.split("**")
        last_idx = len(segments) - 1
        for idx, segment in enumerate(segments):
            # Odd-numbered segments were wrapped in ** ** and are bolded.
            c = Colors.BOLD if idx % 2 else Colors.ENDC
            if idx == last_idx:
                logn(segment, c=c, is_error=is_error)
            else:
                log(segment, c=c, is_error=is_error)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def log_error(msg):
    """
    Print ``msg`` (converted to str) as an ERROR message, then a
    "fail whale" ASCII-art banner on stderr.
    """
    clean_print(LogType.ERROR, str(msg))
    # ASCII art credit:
    # https://textart4u.blogspot.com/2014/05/the-fail-whale-ascii-art-code.html
    fail_whale = """\n▄██████████████▄▐█▄▄▄▄█▌
██████▌▄▌▄▐▐▌███▌▀▀██▀▀
████▄█▌▄▌▄▐▐▌▀███▄▄█▌
▄▄▄▄▄██████████████\n\n"""
    logn(fail_whale, is_error=True)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def log_success(msg):
    """Print ``msg`` behind the "Woohoo!" SUCCESS label."""
    clean_print(type=LogType.SUCCESS, msg=msg)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def log_warning(msg):
    """Print ``msg`` behind the "Warning:" label."""
    clean_print(type=LogType.WARNING, msg=msg)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def log_info(msg):
    """Print ``msg`` behind the "Info:" label."""
    clean_print(type=LogType.INFO, msg=msg)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def list_table(list, padding=25, num_cols=4):
    """
    Print the items of ``list`` as a column-major table to stdout.

    Items fill the first column top to bottom before moving to the next one.
    Each cell is left-justified to ``padding`` characters, and every row is
    followed by a newline and a tab.
    """
    items = list
    rows = int(math.ceil(len(items) / num_cols))
    for row in range(rows):
        for col in range(num_cols):
            index = row + col * rows
            if index >= len(items):
                continue
            print(items[index].ljust(padding), end="")
        print("\n\t", end="")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
110
|
+
# Modifications Copyright (c) 2025 AMD
|