lemonade-sdk 7.0.4__py3-none-any.whl → 8.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of lemonade-sdk might be problematic.
- lemonade/api.py +3 -3
- lemonade/cli.py +11 -17
- lemonade/common/build.py +0 -47
- lemonade/common/network.py +50 -0
- lemonade/common/status.py +2 -21
- lemonade/common/system_info.py +19 -4
- lemonade/profilers/memory_tracker.py +3 -1
- lemonade/tools/accuracy.py +3 -4
- lemonade/tools/adapter.py +1 -2
- lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
- lemonade/tools/humaneval.py +9 -3
- lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
- lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
- lemonade/tools/mmlu.py +7 -15
- lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
- lemonade/tools/oga/utils.py +423 -0
- lemonade/tools/perplexity.py +4 -3
- lemonade/tools/prompt.py +2 -1
- lemonade/tools/quark/quark_load.py +2 -1
- lemonade/tools/quark/quark_quantize.py +5 -5
- lemonade/tools/report/table.py +3 -3
- lemonade/tools/server/llamacpp.py +188 -45
- lemonade/tools/server/serve.py +184 -146
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/styles.css +568 -0
- lemonade/tools/server/static/webapp.html +439 -0
- lemonade/tools/server/tray.py +458 -0
- lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
- lemonade/tools/server/utils/system_tray.py +395 -0
- lemonade/tools/server/{instructions.py → webapp.py} +4 -10
- lemonade/version.py +1 -1
- lemonade_install/install.py +46 -28
- lemonade_sdk-8.0.1.dist-info/METADATA +179 -0
- lemonade_sdk-8.0.1.dist-info/RECORD +70 -0
- lemonade_server/cli.py +182 -27
- lemonade_server/model_manager.py +192 -20
- lemonade_server/pydantic_models.py +9 -4
- lemonade_server/server_models.json +5 -3
- lemonade/common/analyze_model.py +0 -26
- lemonade/common/labels.py +0 -61
- lemonade/common/onnx_helpers.py +0 -176
- lemonade/common/plugins.py +0 -10
- lemonade/common/tensor_helpers.py +0 -83
- lemonade/tools/server/static/instructions.html +0 -262
- lemonade_sdk-7.0.4.dist-info/METADATA +0 -113
- lemonade_sdk-7.0.4.dist-info/RECORD +0 -69
- /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
- /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
- /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/WHEEL +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/entry_points.txt +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/LICENSE +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/NOTICE.md +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/top_level.txt +0 -0
lemonade_server/model_manager.py
CHANGED
@@ -1,24 +1,60 @@
 import json
 import os
+from typing import Optional
+import shutil
 import huggingface_hub
 from importlib.metadata import distributions
-from lemonade_server.pydantic_models import
+from lemonade_server.pydantic_models import PullConfig
+from lemonade.cache import DEFAULT_CACHE_DIR
+
+USER_MODELS_FILE = os.path.join(DEFAULT_CACHE_DIR, "user_models.json")


 class ModelManager:

+    @staticmethod
+    def parse_checkpoint(checkpoint: str) -> tuple[str, str | None]:
+        """
+        Parse a checkpoint string that may contain a variant separated by a colon.
+
+        For GGUF models, the format is "repository:variant" (e.g., "unsloth/Qwen3-0.6B-GGUF:Q4_0").
+        For other models, there is no variant.
+
+        Args:
+            checkpoint: The checkpoint string, potentially with variant
+
+        Returns:
+            tuple: (base_checkpoint, variant) where variant is None if no colon is present
+        """
+        if ":" in checkpoint:
+            base_checkpoint, variant = checkpoint.split(":", 1)
+            return base_checkpoint, variant
+        return checkpoint, None
+
     @property
     def supported_models(self) -> dict:
         """
         Returns a dictionary of supported models.
         Note: Models must be downloaded before they are locally available.
         """
-        # Load the models dictionary from the JSON file
+        # Load the models dictionary from the built-in JSON file
         server_models_file = os.path.join(
             os.path.dirname(__file__), "server_models.json"
         )
         with open(server_models_file, "r", encoding="utf-8") as file:
-            models = json.load(file)
+            models: dict = json.load(file)
+
+        # Load the user's JSON file, if it exists, and merge into the models dict
+        if os.path.exists(USER_MODELS_FILE):
+            with open(USER_MODELS_FILE, "r", encoding="utf-8") as file:
+                user_models: dict = json.load(file)
+                # Prepend the user namespace to the model names
+                user_models = {
+                    f"user.{model_name}": model_info
+                    for model_name, model_info in user_models.items()
+                }
+
+            models.update(user_models)

         # Add the model name as a key in each entry, to make it easier
         # to access later
@@ -49,11 +85,12 @@ class ModelManager:
         Returns a dictionary of locally available models.
         """
         downloaded_models = {}
+        downloaded_checkpoints = self.downloaded_hf_checkpoints
         for model in self.supported_models:
-
-                self.supported_models[model]["checkpoint"]
-
-
+            base_checkpoint = self.parse_checkpoint(
+                self.supported_models[model]["checkpoint"]
+            )[0]
+            if base_checkpoint in downloaded_checkpoints:
                 downloaded_models[model] = self.supported_models[model]
         return downloaded_models

@@ -65,7 +102,7 @@ class ModelManager:
         """
         return self.filter_models_by_backend(self.downloaded_models)

-    def download_gguf(self, model_config:
+    def download_gguf(self, model_config: PullConfig) -> dict:
         """
         Downloads the GGUF file for the given model configuration.
         """
@@ -74,7 +111,7 @@ class ModelManager:
         # 1. A full GGUF filename (e.g. "model-Q4_0.gguf")
         # 2. A quantization variant (e.g. "Q4_0")
         # This code handles both cases by constructing the appropriate filename
-        checkpoint, variant = model_config.checkpoint
+        checkpoint, variant = self.parse_checkpoint(model_config.checkpoint)
         hf_base_name = checkpoint.split("/")[-1].replace("-GGUF", "")
         variant_name = (
             variant if variant.endswith(".gguf") else f"{hf_base_name}-{variant}.gguf"
@@ -91,11 +128,24 @@
             allow_patterns=list(expected_files.values()),
         )

+        # Make sure we downloaded something
+        # If we didn't that can indicate that no patterns from allow_patterns match
+        # any files in the HF repo
+        if not os.path.exists(snapshot_folder):
+            raise ValueError(
+                "No patterns matched the variant parameter (CHECKPOINT:VARIANT). "
+                "Try again, providing the full filename of your target .gguf file as the variant."
+                " For example: Qwen/Qwen2.5-Coder-3B-Instruct-GGUF:"
+                "qwen2.5-coder-3b-instruct-q4_0.gguf"
+            )
+
         # Ensure we downloaded all expected files while creating a dict of the downloaded files
         snapshot_files = {}
         for file in expected_files:
             snapshot_files[file] = os.path.join(snapshot_folder, expected_files[file])
-            if expected_files[file] not in
+            if expected_files[file].lower() not in [
+                name.lower() for name in os.listdir(snapshot_folder)
+            ]:
                 raise ValueError(
                     f"Hugging Face snapshot download for {model_config.checkpoint} "
                     f"expected file {expected_files[file]} not found in {snapshot_folder}"
@@ -104,24 +154,103 @@
         # Return a dict that points to the snapshot path of the downloaded GGUF files
         return snapshot_files

-    def download_models(
+    def download_models(
+        self,
+        models: list[str],
+        checkpoint: Optional[str] = None,
+        recipe: Optional[str] = None,
+        reasoning: bool = False,
+        mmproj: str = "",
+    ):
         """
         Downloads the specified models from Hugging Face.
         """
         for model in models:
             if model not in self.supported_models:
-
-
-
+                # Register the model as a user model if the model name
+                # is not already registered
+
+                # Ensure the model name includes the `user` namespace
+                model_parsed = model.split(".", 1)
+                if len(model_parsed) != 2 or model_parsed[0] != "user":
+                    raise ValueError(
+                        f"When registering a new model, the model name must "
+                        "include the `user` namespace, for example "
+                        f"`user.Phi-4-Mini-GGUF`. Received: {model}"
+                    )
+
+                model_name = model_parsed[1]
+
+                # Check that required arguments are provided
+                if not recipe or not checkpoint:
+                    raise ValueError(
+                        f"Model {model} is not registered with Lemonade Server. "
+                        "To register and install it, provide the `checkpoint` "
+                        "and `recipe` arguments, as well as the optional "
+                        "`reasoning` and `mmproj` arguments as appropriate. "
+                    )
+
+                # JSON content that will be used for registration if the download succeeds
+                new_user_model = {
+                    "checkpoint": checkpoint,
+                    "recipe": recipe,
+                    "reasoning": reasoning,
+                    "suggested": True,
+                    "labels": ["custom"],
+                }
+
+                if mmproj:
+                    new_user_model["mmproj"] = mmproj
+
+                # Make sure that a variant is provided for GGUF models before registering the model
+                if "gguf" in checkpoint.lower() and ":" not in checkpoint.lower():
+                    raise ValueError(
+                        "You are required to provide a 'variant' in the checkpoint field when "
+                        "registering a GGUF model. The variant is provided "
+                        "as CHECKPOINT:VARIANT. For example: "
+                        "Qwen/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_0 or "
+                        "Qwen/Qwen2.5-Coder-3B-Instruct-GGUF:"
+                        "qwen2.5-coder-3b-instruct-q4_0.gguf"
+                    )
+
+                # Create a PullConfig we will use to download the model
+                new_registration_model_config = PullConfig(
+                    model_name=model_name,
+                    checkpoint=checkpoint,
+                    recipe=recipe,
+                    reasoning=reasoning,
                 )
-
-
+            else:
+                new_registration_model_config = None

-
-
-
+            # Download the model
+            if new_registration_model_config:
+                checkpoint_to_download = checkpoint
+                gguf_model_config = new_registration_model_config
             else:
-
+                checkpoint_to_download = self.supported_models[model]["checkpoint"]
+                gguf_model_config = PullConfig(**self.supported_models[model])
+            print(f"Downloading {model} ({checkpoint_to_download})")
+
+            if "gguf" in checkpoint_to_download.lower():
+                self.download_gguf(gguf_model_config)
+            else:
+                huggingface_hub.snapshot_download(repo_id=checkpoint_to_download)
+
+            # Register the model in user_models.json, creating that file if needed
+            # We do this registration after the download so that we don't register
+            # any incorrectly configured models where the download would fail
+            if new_registration_model_config:
+                if os.path.exists(USER_MODELS_FILE):
+                    with open(USER_MODELS_FILE, "r", encoding="utf-8") as file:
+                        user_models: dict = json.load(file)
+                else:
+                    user_models = {}
+
+                user_models[model_name] = new_user_model
+
+                with open(USER_MODELS_FILE, mode="w", encoding="utf-8") as file:
+                    json.dump(user_models, fp=file)

     def filter_models_by_backend(self, models: dict) -> dict:
         """
@@ -143,6 +272,49 @@
                 filtered[model] = value
         return filtered

+    def delete_model(self, model_name: str):
+        """
+        Deletes the specified model from local storage.
+        """
+        if model_name not in self.supported_models:
+            raise ValueError(
+                f"Model {model_name} is not supported. Please choose from the following: "
+                f"{list(self.supported_models.keys())}"
+            )
+
+        checkpoint = self.supported_models[model_name]["checkpoint"]
+        print(f"Deleting {model_name} ({checkpoint})")
+
+        # Handle GGUF models that have the format "checkpoint:variant"
+        base_checkpoint = self.parse_checkpoint(checkpoint)[0]
+
+        try:
+            # Get the local path using snapshot_download with local_files_only=True
+            snapshot_path = huggingface_hub.snapshot_download(
+                repo_id=base_checkpoint, local_files_only=True
+            )
+
+            # Navigate up to the model directory (parent of snapshots directory)
+            model_path = os.path.dirname(os.path.dirname(snapshot_path))
+
+            # Delete the entire model directory (including all snapshots)
+            if os.path.exists(model_path):
+                shutil.rmtree(model_path)
+                print(f"Successfully deleted model {model_name} from {model_path}")
+            else:
+                raise ValueError(
+                    f"Model {model_name} not found locally at {model_path}"
+                )
+
+        except Exception as e:
+            if (
+                "not found in cache" in str(e).lower()
+                or "no such file" in str(e).lower()
+            ):
+                raise ValueError(f"Model {model_name} is not installed locally")
+            else:
+                raise ValueError(f"Failed to delete model {model_name}: {str(e)}")
+

 # This file was originally licensed under Apache 2.0. It has been modified.
 # Modifications Copyright (c) 2025 AMD
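For reference, a minimal sketch of how the new `parse_checkpoint` helper splits a checkpoint string, based on the docstring above; the second checkpoint name is an illustrative placeholder:

from lemonade_server.model_manager import ModelManager

# GGUF checkpoints carry a "repository:variant" suffix
base, variant = ModelManager.parse_checkpoint("unsloth/Qwen3-0.6B-GGUF:Q4_0")
assert (base, variant) == ("unsloth/Qwen3-0.6B-GGUF", "Q4_0")

# Non-GGUF checkpoints have no variant (this repo name is hypothetical)
base, variant = ModelManager.parse_checkpoint("example-org/example-model")
assert variant is None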
lemonade_server/pydantic_models.py
CHANGED
@@ -15,9 +15,8 @@ class LoadConfig(BaseModel):
     and hardware/framework configuration (recipe) for model loading.
     """

-    model_name:
+    model_name: str
     checkpoint: Optional[str] = None
-    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
     recipe: Optional[str] = None
     # Indicates the maximum prompt length allowed for that specific
     # checkpoint + recipe combination
@@ -77,9 +76,15 @@ class ResponsesRequest(BaseModel):
     stream: bool = False


-class PullConfig(
+class PullConfig(LoadConfig):
+    """
+    Pull and load have the same fields.
+    """
+
+
+class DeleteConfig(BaseModel):
     """
-
+    Configuration for deleting a supported LLM.
     """

     model_name: str
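Since `PullConfig` now simply subclasses `LoadConfig`, a pull request body validates with the same fields; a minimal sketch using the same keyword arguments that `download_models` passes above (the field values are illustrative):

from lemonade_server.pydantic_models import PullConfig

config = PullConfig(
    model_name="user.Phi-4-Mini-GGUF",
    checkpoint="Qwen/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_0",
    recipe="llamacpp",
    reasoning=False,
)
print(config.model_name, config.checkpoint)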
lemonade_server/server_models.json
CHANGED
@@ -193,13 +193,15 @@
         "mmproj": "mmproj-model-f16.gguf",
         "recipe": "llamacpp",
         "reasoning": false,
-        "suggested": true
+        "suggested": true,
+        "labels": ["vision"]
     },
-    "Qwen2.5-VL-7B-Instruct": {
+    "Qwen2.5-VL-7B-Instruct-GGUF": {
         "checkpoint": "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M",
         "mmproj": "mmproj-Qwen2.5-VL-7B-Instruct-f16.gguf",
         "recipe": "llamacpp",
         "reasoning": false,
-        "suggested": true
+        "suggested": true,
+        "labels": ["vision"]
     }
 }
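The `labels` field added above is the same mechanism `download_models` uses when it registers a custom model with `"labels": ["custom"]`. A minimal sketch of how entries from `user_models.json` are surfaced under the `user` namespace by `supported_models` (the model name and entry values are hypothetical):

# Hypothetical user_models.json content
user_models = {
    "My-Coder-GGUF": {
        "checkpoint": "Qwen/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_0",
        "recipe": "llamacpp",
        "reasoning": False,
        "suggested": True,
        "labels": ["custom"],
    }
}

# supported_models prepends the `user` namespace before merging
namespaced = {f"user.{name}": info for name, info in user_models.items()}
assert "user.My-Coder-GGUF" in namespaced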
lemonade/common/analyze_model.py
DELETED
@@ -1,26 +0,0 @@
-import numpy as np
-import torch
-import onnx
-
-
-def count_parameters(model: torch.nn.Module) -> int:
-    """
-    Returns the number of parameters of a given model
-    """
-    if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
-        return sum([parameter.numel() for _, parameter in model.named_parameters()])
-    elif isinstance(model, str) and model.endswith(".onnx"):
-        onnx_model = onnx.load(model)
-        return int(
-            sum(
-                np.prod(tensor.dims, dtype=np.int64)
-                for tensor in onnx_model.graph.initializer
-                if tensor.name not in onnx_model.graph.input
-            )
-        )
-    else:
-        return None
-
-
-# This file was originally licensed under Apache 2.0. It has been modified.
-# Modifications Copyright (c) 2025 AMD
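For context, a minimal sketch of how the removed `count_parameters` helper behaved in 7.0.4 (the toy module is illustrative; the import no longer exists in 8.0.1):

import torch

# Worked in 7.0.4; the module is removed in 8.0.1
from lemonade.common.analyze_model import count_parameters

model = torch.nn.Linear(16, 4)
assert count_parameters(model) == 16 * 4 + 4  # weights plus biases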
lemonade/common/labels.py
DELETED
@@ -1,61 +0,0 @@
-from typing import Dict, List
-import turnkeyml.common.printing as printing
-
-
-def to_dict(label_list: List[str]) -> Dict[str, List[str]]:
-    """
-    Convert label list into a dictionary of labels
-    """
-    label_dict = {}
-    for item in label_list:
-        try:
-            label_key, label_value = item.split("::")
-            label_value = label_value.split(",")
-            label_dict[label_key] = label_value
-        except ValueError:
-            printing.log_warning(
-                (
-                    f"Malformed label {item} found. "
-                    "Each label must have the format key::value1,value2,... "
-                )
-            )
-    return label_dict
-
-
-def load_from_file(file_path: str) -> Dict[str, List[str]]:
-    """
-    This function extracts labels from a Python file.
-    Labels must be in the first line of a Python file and start with "# labels: "
-    Each label must have the format "key::value1,value2,..."
-
-    Example:
-        "# labels: author::google test_group::daily,monthly"
-    """
-    # Open file
-    with open(file_path, encoding="utf-8") as f:
-        first_line = f.readline()
-
-    # Return label dict
-    if "# labels:" in first_line:
-        label_list = first_line.replace("\n", "").split(" ")[2:]
-        return to_dict(label_list)
-    else:
-        return {}
-
-
-def is_subset(label_dict_a: Dict[str, List[str]], label_dict_b: Dict[str, List[str]]):
-    """
-    This function returns True if label_dict_a is a subset of label_dict_b.
-    More specifically, we return True if:
-      * All keys of label_dict_a are also keys of label_dict_b AND,
-      * All values of label_dict_a[key] are values of label_dict_b[key]
-    """
-    for key in label_dict_a:
-        # Skip benchmarking if the label_dict_a key is not a key of label_dict_b
-        if key not in label_dict_b:
-            return False
-        # A label key may point to multiple label values
-        # Skip if not all values of label_dict_a[key] are in label_dict_b[key]
-        elif not all(elem in label_dict_a[key] for elem in label_dict_b[key]):
-            return False
-    return True
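For context, a minimal sketch of the label format this removed module parsed, taken directly from the `load_from_file` docstring (worked in 7.0.4; removed in 8.0.1):

from lemonade.common.labels import to_dict  # removed in 8.0.1

labels = to_dict(["author::google", "test_group::daily,monthly"])
assert labels == {"author": ["google"], "test_group": ["daily", "monthly"]}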
lemonade/common/onnx_helpers.py
DELETED
@@ -1,176 +0,0 @@
-"""
-Helper functions for dealing with ONNX files and ONNX models
-"""
-
-import os
-from typing import Tuple, Union
-import re
-import math
-import numpy as np
-import onnx
-import onnxruntime as ort
-import lemonade.common.exceptions as exp
-from lemonade.state import State
-import lemonade.common.build as build
-
-
-def check_model(onnx_file, success_message, fail_message) -> bool:
-    if os.path.isfile(onnx_file):
-        print(success_message)
-    else:
-        print(fail_message)
-        return False
-    try:
-        onnx.checker.check_model(onnx_file)
-        print("\tSuccessfully checked onnx file")
-        return True
-    except onnx.checker.ValidationError as e:
-        print("\tError while checking generated ONNX file")
-        print(e)
-        return False
-
-
-def original_inputs_file(cache_dir: str, build_name: str):
-    return os.path.join(build.output_dir(cache_dir, build_name), "inputs.npy")
-
-
-def onnx_dir(state: State):
-    return os.path.join(build.output_dir(state.cache_dir, state.build_name), "onnx")
-
-
-def get_output_names(
-    onnx_model: Union[str, onnx.ModelProto],
-):  # pylint: disable=no-member
-    # Get output names of ONNX file/model
-    if not isinstance(onnx_model, onnx.ModelProto):  # pylint: disable=no-member
-        onnx_model = onnx.load(onnx_model)
-    return [node.name for node in onnx_model.graph.output]  # pylint: disable=no-member
-
-
-def parameter_count(model):
-    weights = model.graph.initializer
-    parameter_count = 0
-
-    for w in weights:
-        weight = onnx.numpy_helper.to_array(w)
-        parameter_count += np.prod(weight.shape)
-    return parameter_count
-
-
-def io_bytes(onnx_path: str) -> Tuple[int, int]:
-    """Return the number of bytes of each of the inputs and outputs"""
-    # pylint: disable = no-member
-
-    def elem_type_to_bytes(elem_type) -> int:
-        """
-        Convert ONNX's elem_type to the number of bytes used by
-        hardware to send that specific datatype through PCIe
-        """
-        if (
-            elem_type == onnx.TensorProto.DataType.UINT8
-            or elem_type == onnx.TensorProto.DataType.INT8
-            or elem_type == onnx.TensorProto.DataType.BOOL
-        ):
-            # Each bool requires an entire byte
-            return 1
-        elif (
-            elem_type == onnx.TensorProto.DataType.UINT16
-            or elem_type == onnx.TensorProto.DataType.INT16
-            or elem_type == onnx.TensorProto.DataType.FLOAT16
-        ):
-            return 2
-        if (
-            elem_type == onnx.TensorProto.DataType.FLOAT
-            or elem_type == onnx.TensorProto.DataType.INT32
-            or elem_type == onnx.TensorProto.DataType.INT64
-            or elem_type == onnx.TensorProto.DataType.DOUBLE
-            or elem_type == onnx.TensorProto.DataType.UINT64
-        ):
-            # 64 bit ints are treated as 32 bits everywhere
-            # Doubles are treated as floats
-            return 4
-        elif (
-            elem_type == onnx.TensorProto.DataType.COMPLEX64
-            or elem_type == onnx.TensorProto.DataType.COMPLEX128
-            or elem_type == onnx.TensorProto.DataType.STRING
-            or elem_type == onnx.TensorProto.DataType.UNDEFINED
-        ):
-            raise exp.Error("Unsupported data type")
-        else:
-            raise exp.Error("Unsupported data type (unknown to ONNX)")
-
-    def get_nodes_bytes(nodes):
-        nodes_bytes = {}
-        for node in nodes:
-
-            # Get the number of the data type
-            dtype_bytes = elem_type_to_bytes(node.type.tensor_type.elem_type)
-
-            # Calculate the total number of elements based on the shape
-            shape = str(node.type.tensor_type.shape.dim)
-            num_elements = np.prod([int(s) for s in shape.split() if s.isdigit()])
-
-            # Assign a total number of bytes to each node
-            nodes_bytes[node.name] = num_elements * dtype_bytes
-
-        return nodes_bytes
-
-    # Get the number of bytes of each of the inputs and outputs
-    model = onnx.load(onnx_path)
-    onnx_input_bytes = get_nodes_bytes(model.graph.input)
-    onnx_output_bytes = get_nodes_bytes(model.graph.output)
-
-    return int(sum(onnx_input_bytes.values())), int(sum(onnx_output_bytes.values()))
-
-
-def dtype_ort2str(dtype_str: str):
-    if dtype_str == "float16":
-        datatype = "float16"
-    elif dtype_str == "float":
-        datatype = "float32"
-    elif dtype_str == "double":
-        datatype = "float64"
-    elif dtype_str == "long":
-        datatype = "int64"
-    else:
-        datatype = dtype_str
-    return datatype
-
-
-def dummy_inputs(onnx_file: str) -> dict:
-    # Generate dummy inputs of the expected shape and type for the input model
-    sess_options = ort.SessionOptions()
-    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-    onnx_session = ort.InferenceSession(onnx_file, sess_options)
-    sess_input = onnx_session.get_inputs()
-
-    input_stats = []
-    for _idx, input_ in enumerate(range(len(sess_input))):
-        input_name = sess_input[input_].name
-        input_shape = sess_input[input_].shape
-
-        # TODO: Use onnx update_inputs_outputs_dims to automatically freeze models
-        for dim in input_shape:
-            if isinstance(dim, str) is True or math.isnan(dim) is True:
-                raise AssertionError(
-                    "Error: Model has dynamic inputs. Freeze the graph and try again"
-                )
-
-        input_type = sess_input[input_].type
-        input_stats.append([input_name, input_shape, input_type])
-
-    input_feed = {}
-    for stat in input_stats:
-        dtype_str = re.search(r"\((.*)\)", stat[2])
-        assert dtype_str is not None
-        datatype = dtype_ort2str(dtype_str.group(1))
-        input_feed[stat[0]] = np.random.rand(*stat[1]).astype(datatype)
-    return input_feed
-
-
-def get_opset(model: onnx.ModelProto) -> int:
-    return getattr(model.opset_import[0], "version", None)
-
-
-# This file was originally licensed under Apache 2.0. It has been modified.
-# Modifications Copyright (c) 2025 AMD
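For context, a minimal sketch of how these removed helpers were typically used in 7.0.4 (the ONNX file path is a hypothetical placeholder; the import no longer exists in 8.0.1):

import onnx

# Worked in 7.0.4; the module is removed in 8.0.1
from lemonade.common.onnx_helpers import get_opset, parameter_count

model = onnx.load("my_model.onnx")  # hypothetical path
print("opset:", get_opset(model))
print("parameters:", parameter_count(model))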