lemonade-sdk 7.0.4__py3-none-any.whl → 8.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of lemonade-sdk might be problematic.

Files changed (56)
  1. lemonade/api.py +3 -3
  2. lemonade/cli.py +11 -17
  3. lemonade/common/build.py +0 -47
  4. lemonade/common/network.py +50 -0
  5. lemonade/common/status.py +2 -21
  6. lemonade/common/system_info.py +19 -4
  7. lemonade/profilers/memory_tracker.py +3 -1
  8. lemonade/tools/accuracy.py +3 -4
  9. lemonade/tools/adapter.py +1 -2
  10. lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
  11. lemonade/tools/huggingface/load.py +235 -0
  12. lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
  13. lemonade/tools/humaneval.py +9 -3
  14. lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
  15. lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
  16. lemonade/tools/mmlu.py +7 -15
  17. lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
  18. lemonade/tools/oga/utils.py +423 -0
  19. lemonade/tools/perplexity.py +4 -3
  20. lemonade/tools/prompt.py +2 -1
  21. lemonade/tools/quark/quark_load.py +2 -1
  22. lemonade/tools/quark/quark_quantize.py +5 -5
  23. lemonade/tools/report/table.py +3 -3
  24. lemonade/tools/server/llamacpp.py +188 -45
  25. lemonade/tools/server/serve.py +184 -146
  26. lemonade/tools/server/static/favicon.ico +0 -0
  27. lemonade/tools/server/static/styles.css +568 -0
  28. lemonade/tools/server/static/webapp.html +439 -0
  29. lemonade/tools/server/tray.py +458 -0
  30. lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
  31. lemonade/tools/server/utils/system_tray.py +395 -0
  32. lemonade/tools/server/{instructions.py → webapp.py} +4 -10
  33. lemonade/version.py +1 -1
  34. lemonade_install/install.py +46 -28
  35. lemonade_sdk-8.0.1.dist-info/METADATA +179 -0
  36. lemonade_sdk-8.0.1.dist-info/RECORD +70 -0
  37. lemonade_server/cli.py +182 -27
  38. lemonade_server/model_manager.py +192 -20
  39. lemonade_server/pydantic_models.py +9 -4
  40. lemonade_server/server_models.json +5 -3
  41. lemonade/common/analyze_model.py +0 -26
  42. lemonade/common/labels.py +0 -61
  43. lemonade/common/onnx_helpers.py +0 -176
  44. lemonade/common/plugins.py +0 -10
  45. lemonade/common/tensor_helpers.py +0 -83
  46. lemonade/tools/server/static/instructions.html +0 -262
  47. lemonade_sdk-7.0.4.dist-info/METADATA +0 -113
  48. lemonade_sdk-7.0.4.dist-info/RECORD +0 -69
  49. /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
  50. /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
  51. /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
  52. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/WHEEL +0 -0
  53. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/entry_points.txt +0 -0
  54. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/LICENSE +0 -0
  55. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/NOTICE.md +0 -0
  56. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/top_level.txt +0 -0
lemonade_server/model_manager.py
@@ -1,24 +1,60 @@
 import json
 import os
+from typing import Optional
+import shutil
 import huggingface_hub
 from importlib.metadata import distributions
-from lemonade_server.pydantic_models import LoadConfig
+from lemonade_server.pydantic_models import PullConfig
+from lemonade.cache import DEFAULT_CACHE_DIR
+
+USER_MODELS_FILE = os.path.join(DEFAULT_CACHE_DIR, "user_models.json")
 
 
 class ModelManager:
 
+    @staticmethod
+    def parse_checkpoint(checkpoint: str) -> tuple[str, str | None]:
+        """
+        Parse a checkpoint string that may contain a variant separated by a colon.
+
+        For GGUF models, the format is "repository:variant" (e.g., "unsloth/Qwen3-0.6B-GGUF:Q4_0").
+        For other models, there is no variant.
+
+        Args:
+            checkpoint: The checkpoint string, potentially with variant
+
+        Returns:
+            tuple: (base_checkpoint, variant) where variant is None if no colon is present
+        """
+        if ":" in checkpoint:
+            base_checkpoint, variant = checkpoint.split(":", 1)
+            return base_checkpoint, variant
+        return checkpoint, None
+
     @property
     def supported_models(self) -> dict:
         """
         Returns a dictionary of supported models.
         Note: Models must be downloaded before they are locally available.
         """
-        # Load the models dictionary from the JSON file
+        # Load the models dictionary from the built-in JSON file
         server_models_file = os.path.join(
             os.path.dirname(__file__), "server_models.json"
         )
         with open(server_models_file, "r", encoding="utf-8") as file:
-            models = json.load(file)
+            models: dict = json.load(file)
+
+        # Load the user's JSON file, if it exists, and merge into the models dict
+        if os.path.exists(USER_MODELS_FILE):
+            with open(USER_MODELS_FILE, "r", encoding="utf-8") as file:
+                user_models: dict = json.load(file)
+            # Prepend the user namespace to the model names
+            user_models = {
+                f"user.{model_name}": model_info
+                for model_name, model_info in user_models.items()
+            }
+
+            models.update(user_models)
 
         # Add the model name as a key in each entry, to make it easier
         # to access later
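The parse_checkpoint helper above centralizes the checkpoint:variant convention that the rest of this release builds on. A minimal usage sketch of that behavior, assuming only what the docstring states (the non-GGUF checkpoint name is illustrative):

```python
from lemonade_server.model_manager import ModelManager

# GGUF checkpoints carry a quantization variant after the colon
assert ModelManager.parse_checkpoint("unsloth/Qwen3-0.6B-GGUF:Q4_0") == (
    "unsloth/Qwen3-0.6B-GGUF",
    "Q4_0",
)

# Checkpoints without a colon have no variant
assert ModelManager.parse_checkpoint("org/some-model") == ("org/some-model", None)
```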
@@ -49,11 +85,12 @@ class ModelManager:
         Returns a dictionary of locally available models.
         """
         downloaded_models = {}
+        downloaded_checkpoints = self.downloaded_hf_checkpoints
         for model in self.supported_models:
-            if (
-                self.supported_models[model]["checkpoint"].split(":")[0]
-                in self.downloaded_hf_checkpoints
-            ):
+            base_checkpoint = self.parse_checkpoint(
+                self.supported_models[model]["checkpoint"]
+            )[0]
+            if base_checkpoint in downloaded_checkpoints:
                 downloaded_models[model] = self.supported_models[model]
         return downloaded_models
 
@@ -65,7 +102,7 @@
         """
         return self.filter_models_by_backend(self.downloaded_models)
 
-    def download_gguf(self, model_config: LoadConfig) -> dict:
+    def download_gguf(self, model_config: PullConfig) -> dict:
         """
         Downloads the GGUF file for the given model configuration.
         """
@@ -74,7 +111,7 @@
         # 1. A full GGUF filename (e.g. "model-Q4_0.gguf")
         # 2. A quantization variant (e.g. "Q4_0")
         # This code handles both cases by constructing the appropriate filename
-        checkpoint, variant = model_config.checkpoint.split(":")
+        checkpoint, variant = self.parse_checkpoint(model_config.checkpoint)
         hf_base_name = checkpoint.split("/")[-1].replace("-GGUF", "")
         variant_name = (
             variant if variant.endswith(".gguf") else f"{hf_base_name}-{variant}.gguf"
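As the comments above note, the variant may be either a bare quantization name or a full .gguf filename. A standalone sketch of the filename construction this implies (the helper name here is hypothetical, written to mirror the two lines above):

```python
def variant_filename(checkpoint: str, variant: str) -> str:
    # Mirrors download_gguf: a bare variant such as "Q4_0" expands to
    # "<repo-base>-<variant>.gguf"; a full .gguf filename passes through unchanged
    hf_base_name = checkpoint.split("/")[-1].replace("-GGUF", "")
    return variant if variant.endswith(".gguf") else f"{hf_base_name}-{variant}.gguf"


# -> "Qwen3-0.6B-Q4_0.gguf"
print(variant_filename("unsloth/Qwen3-0.6B-GGUF", "Q4_0"))
```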
@@ -91,11 +128,24 @@
             allow_patterns=list(expected_files.values()),
         )
 
+        # Make sure we downloaded something
+        # If we didn't that can indicate that no patterns from allow_patterns match
+        # any files in the HF repo
+        if not os.path.exists(snapshot_folder):
+            raise ValueError(
+                "No patterns matched the variant parameter (CHECKPOINT:VARIANT). "
+                "Try again, providing the full filename of your target .gguf file as the variant."
+                " For example: Qwen/Qwen2.5-Coder-3B-Instruct-GGUF:"
+                "qwen2.5-coder-3b-instruct-q4_0.gguf"
+            )
+
         # Ensure we downloaded all expected files while creating a dict of the downloaded files
         snapshot_files = {}
         for file in expected_files:
             snapshot_files[file] = os.path.join(snapshot_folder, expected_files[file])
-            if expected_files[file] not in os.listdir(snapshot_folder):
+            if expected_files[file].lower() not in [
+                name.lower() for name in os.listdir(snapshot_folder)
+            ]:
                 raise ValueError(
                     f"Hugging Face snapshot download for {model_config.checkpoint} "
                     f"expected file {expected_files[file]} not found in {snapshot_folder}"
@@ -104,24 +154,103 @@
         # Return a dict that points to the snapshot path of the downloaded GGUF files
         return snapshot_files
 
-    def download_models(self, models: list[str]):
+    def download_models(
+        self,
+        models: list[str],
+        checkpoint: Optional[str] = None,
+        recipe: Optional[str] = None,
+        reasoning: bool = False,
+        mmproj: str = "",
+    ):
         """
         Downloads the specified models from Hugging Face.
         """
         for model in models:
             if model not in self.supported_models:
-                raise ValueError(
-                    f"Model {model} is not supported. Please choose from the following: "
-                    f"{list(self.supported_models.keys())}"
+                # Register the model as a user model if the model name
+                # is not already registered
+
+                # Ensure the model name includes the `user` namespace
+                model_parsed = model.split(".", 1)
+                if len(model_parsed) != 2 or model_parsed[0] != "user":
+                    raise ValueError(
+                        f"When registering a new model, the model name must "
+                        "include the `user` namespace, for example "
+                        f"`user.Phi-4-Mini-GGUF`. Received: {model}"
+                    )
+
+                model_name = model_parsed[1]
+
+                # Check that required arguments are provided
+                if not recipe or not checkpoint:
+                    raise ValueError(
+                        f"Model {model} is not registered with Lemonade Server. "
+                        "To register and install it, provide the `checkpoint` "
+                        "and `recipe` arguments, as well as the optional "
+                        "`reasoning` and `mmproj` arguments as appropriate. "
+                    )
+
+                # JSON content that will be used for registration if the download succeeds
+                new_user_model = {
+                    "checkpoint": checkpoint,
+                    "recipe": recipe,
+                    "reasoning": reasoning,
+                    "suggested": True,
+                    "labels": ["custom"],
+                }
+
+                if mmproj:
+                    new_user_model["mmproj"] = mmproj
+
+                # Make sure that a variant is provided for GGUF models before registering the model
+                if "gguf" in checkpoint.lower() and ":" not in checkpoint.lower():
+                    raise ValueError(
+                        "You are required to provide a 'variant' in the checkpoint field when "
+                        "registering a GGUF model. The variant is provided "
+                        "as CHECKPOINT:VARIANT. For example: "
+                        "Qwen/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_0 or "
+                        "Qwen/Qwen2.5-Coder-3B-Instruct-GGUF:"
+                        "qwen2.5-coder-3b-instruct-q4_0.gguf"
+                    )
+
+                # Create a PullConfig we will use to download the model
+                new_registration_model_config = PullConfig(
+                    model_name=model_name,
+                    checkpoint=checkpoint,
+                    recipe=recipe,
+                    reasoning=reasoning,
                 )
-            checkpoint = self.supported_models[model]["checkpoint"]
-            print(f"Downloading {model} ({checkpoint})")
+            else:
+                new_registration_model_config = None
 
-            if "gguf" in checkpoint.lower():
-                model_config = LoadConfig(**self.supported_models[model])
-                self.download_gguf(model_config)
+            # Download the model
+            if new_registration_model_config:
+                checkpoint_to_download = checkpoint
+                gguf_model_config = new_registration_model_config
             else:
-                huggingface_hub.snapshot_download(repo_id=checkpoint)
+                checkpoint_to_download = self.supported_models[model]["checkpoint"]
+                gguf_model_config = PullConfig(**self.supported_models[model])
+            print(f"Downloading {model} ({checkpoint_to_download})")
+
+            if "gguf" in checkpoint_to_download.lower():
+                self.download_gguf(gguf_model_config)
+            else:
+                huggingface_hub.snapshot_download(repo_id=checkpoint_to_download)
+
+            # Register the model in user_models.json, creating that file if needed
+            # We do this registration after the download so that we don't register
+            # any incorrectly configured models where the download would fail
+            if new_registration_model_config:
+                if os.path.exists(USER_MODELS_FILE):
+                    with open(USER_MODELS_FILE, "r", encoding="utf-8") as file:
+                        user_models: dict = json.load(file)
+                else:
+                    user_models = {}
+
+                user_models[model_name] = new_user_model
+
+                with open(USER_MODELS_FILE, mode="w", encoding="utf-8") as file:
+                    json.dump(user_models, fp=file)
 
     def filter_models_by_backend(self, models: dict) -> dict:
         """
@@ -143,6 +272,49 @@
             filtered[model] = value
         return filtered
 
+    def delete_model(self, model_name: str):
+        """
+        Deletes the specified model from local storage.
+        """
+        if model_name not in self.supported_models:
+            raise ValueError(
+                f"Model {model_name} is not supported. Please choose from the following: "
+                f"{list(self.supported_models.keys())}"
+            )
+
+        checkpoint = self.supported_models[model_name]["checkpoint"]
+        print(f"Deleting {model_name} ({checkpoint})")
+
+        # Handle GGUF models that have the format "checkpoint:variant"
+        base_checkpoint = self.parse_checkpoint(checkpoint)[0]
+
+        try:
+            # Get the local path using snapshot_download with local_files_only=True
+            snapshot_path = huggingface_hub.snapshot_download(
+                repo_id=base_checkpoint, local_files_only=True
+            )
+
+            # Navigate up to the model directory (parent of snapshots directory)
+            model_path = os.path.dirname(os.path.dirname(snapshot_path))
+
+            # Delete the entire model directory (including all snapshots)
+            if os.path.exists(model_path):
+                shutil.rmtree(model_path)
+                print(f"Successfully deleted model {model_name} from {model_path}")
+            else:
+                raise ValueError(
+                    f"Model {model_name} not found locally at {model_path}"
+                )
+
+        except Exception as e:
+            if (
+                "not found in cache" in str(e).lower()
+                or "no such file" in str(e).lower()
+            ):
+                raise ValueError(f"Model {model_name} is not installed locally")
+            else:
+                raise ValueError(f"Failed to delete model {model_name}: {str(e)}")
+
 
 # This file was originally licensed under Apache 2.0. It has been modified.
 # Modifications Copyright (c) 2025 AMD
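The new delete_model method resolves the base checkpoint with parse_checkpoint and removes the model's entire Hugging Face cache directory. A usage sketch (the model name is one that appears in server_models.json below):

```python
from lemonade_server.model_manager import ModelManager

manager = ModelManager()

try:
    # Deletes every cached snapshot under the model's base checkpoint
    manager.delete_model("Qwen2.5-VL-7B-Instruct-GGUF")
except ValueError as err:
    # Raised when the model is unknown or not installed locally
    print(err)
```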
lemonade_server/pydantic_models.py
@@ -15,9 +15,8 @@ class LoadConfig(BaseModel):
     and hardware/framework configuration (recipe) for model loading.
     """
 
-    model_name: Optional[str] = None
+    model_name: str
     checkpoint: Optional[str] = None
-    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
     recipe: Optional[str] = None
     # Indicates the maximum prompt length allowed for that specific
     # checkpoint + recipe combination
@@ -77,9 +76,15 @@ class ResponsesRequest(BaseModel):
     stream: bool = False
 
 
-class PullConfig(BaseModel):
+class PullConfig(LoadConfig):
+    """
+    Pull and load have the same fields.
+    """
+
+
+class DeleteConfig(BaseModel):
     """
-    Configurating for installing a supported LLM.
+    Configuration for deleting a supported LLM.
     """
 
     model_name: str
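With PullConfig now subclassing LoadConfig, pull requests and load requests validate against the same fields, and model_name is required rather than optional. A brief sketch, assuming standard pydantic validation semantics (the field values are illustrative):

```python
from pydantic import ValidationError

from lemonade_server.pydantic_models import LoadConfig, PullConfig

# PullConfig inherits every LoadConfig field, including the required model_name
config = PullConfig(
    model_name="user.Example-GGUF",
    checkpoint="org/Example-GGUF:Q4_0",
    recipe="llamacpp",
)

try:
    LoadConfig()  # missing model_name now fails validation
except ValidationError as err:
    print(err)
```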
lemonade_server/server_models.json
@@ -193,13 +193,15 @@
         "mmproj": "mmproj-model-f16.gguf",
         "recipe": "llamacpp",
         "reasoning": false,
-        "suggested": true
+        "suggested": true,
+        "labels": ["vision"]
     },
-    "Qwen2.5-VL-7B-Instruct": {
+    "Qwen2.5-VL-7B-Instruct-GGUF": {
         "checkpoint": "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M",
         "mmproj": "mmproj-Qwen2.5-VL-7B-Instruct-f16.gguf",
        "recipe": "llamacpp",
        "reasoning": false,
-        "suggested": true
+        "suggested": true,
+        "labels": ["vision"]
     }
 }
lemonade/common/analyze_model.py DELETED
@@ -1,26 +0,0 @@
-import numpy as np
-import torch
-import onnx
-
-
-def count_parameters(model: torch.nn.Module) -> int:
-    """
-    Returns the number of parameters of a given model
-    """
-    if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
-        return sum([parameter.numel() for _, parameter in model.named_parameters()])
-    elif isinstance(model, str) and model.endswith(".onnx"):
-        onnx_model = onnx.load(model)
-        return int(
-            sum(
-                np.prod(tensor.dims, dtype=np.int64)
-                for tensor in onnx_model.graph.initializer
-                if tensor.name not in onnx_model.graph.input
-            )
-        )
-    else:
-        return None
-
-
-# This file was originally licensed under Apache 2.0. It has been modified.
-# Modifications Copyright (c) 2025 AMD
lemonade/common/labels.py DELETED
@@ -1,61 +0,0 @@
-from typing import Dict, List
-import turnkeyml.common.printing as printing
-
-
-def to_dict(label_list: List[str]) -> Dict[str, List[str]]:
-    """
-    Convert label list into a dictionary of labels
-    """
-    label_dict = {}
-    for item in label_list:
-        try:
-            label_key, label_value = item.split("::")
-            label_value = label_value.split(",")
-            label_dict[label_key] = label_value
-        except ValueError:
-            printing.log_warning(
-                (
-                    f"Malformed label {item} found. "
-                    "Each label must have the format key::value1,value2,... "
-                )
-            )
-    return label_dict
-
-
-def load_from_file(file_path: str) -> Dict[str, List[str]]:
-    """
-    This function extracts labels from a Python file.
-    Labels must be in the first line of a Python file and start with "# labels: "
-    Each label must have the format "key::value1,value2,..."
-
-    Example:
-        "# labels: author::google test_group::daily,monthly"
-    """
-    # Open file
-    with open(file_path, encoding="utf-8") as f:
-        first_line = f.readline()
-
-    # Return label dict
-    if "# labels:" in first_line:
-        label_list = first_line.replace("\n", "").split(" ")[2:]
-        return to_dict(label_list)
-    else:
-        return {}
-
-
-def is_subset(label_dict_a: Dict[str, List[str]], label_dict_b: Dict[str, List[str]]):
-    """
-    This function returns True if label_dict_a is a subset of label_dict_b.
-    More specifically, we return True if:
-    * All keys of label_dict_a are also keys of label_dict_b AND,
-    * All values of label_dict_a[key] are values of label_dict_b[key]
-    """
-    for key in label_dict_a:
-        # Skip benchmarking if the label_dict_a key is not a key of label_dict_b
-        if key not in label_dict_b:
-            return False
-        # A label key may point to multiple label values
-        # Skip if not all values of label_dict_a[key] are in label_dict_b[key]
-        elif not all(elem in label_dict_a[key] for elem in label_dict_b[key]):
-            return False
-    return True
lemonade/common/onnx_helpers.py DELETED
@@ -1,176 +0,0 @@
-"""
-Helper functions for dealing with ONNX files and ONNX models
-"""
-
-import os
-from typing import Tuple, Union
-import re
-import math
-import numpy as np
-import onnx
-import onnxruntime as ort
-import lemonade.common.exceptions as exp
-from lemonade.state import State
-import lemonade.common.build as build
-
-
-def check_model(onnx_file, success_message, fail_message) -> bool:
-    if os.path.isfile(onnx_file):
-        print(success_message)
-    else:
-        print(fail_message)
-        return False
-    try:
-        onnx.checker.check_model(onnx_file)
-        print("\tSuccessfully checked onnx file")
-        return True
-    except onnx.checker.ValidationError as e:
-        print("\tError while checking generated ONNX file")
-        print(e)
-        return False
-
-
-def original_inputs_file(cache_dir: str, build_name: str):
-    return os.path.join(build.output_dir(cache_dir, build_name), "inputs.npy")
-
-
-def onnx_dir(state: State):
-    return os.path.join(build.output_dir(state.cache_dir, state.build_name), "onnx")
-
-
-def get_output_names(
-    onnx_model: Union[str, onnx.ModelProto],
-):  # pylint: disable=no-member
-    # Get output names of ONNX file/model
-    if not isinstance(onnx_model, onnx.ModelProto):  # pylint: disable=no-member
-        onnx_model = onnx.load(onnx_model)
-    return [node.name for node in onnx_model.graph.output]  # pylint: disable=no-member
-
-
-def parameter_count(model):
-    weights = model.graph.initializer
-    parameter_count = 0
-
-    for w in weights:
-        weight = onnx.numpy_helper.to_array(w)
-        parameter_count += np.prod(weight.shape)
-    return parameter_count
-
-
-def io_bytes(onnx_path: str) -> Tuple[int, int]:
-    """Return the number of bytes of each of the inputs and outputs"""
-    # pylint: disable = no-member
-
-    def elem_type_to_bytes(elem_type) -> int:
-        """
-        Convert ONNX's elem_type to the number of bytes used by
-        hardware to send that specific datatype through PCIe
-        """
-        if (
-            elem_type == onnx.TensorProto.DataType.UINT8
-            or elem_type == onnx.TensorProto.DataType.INT8
-            or elem_type == onnx.TensorProto.DataType.BOOL
-        ):
-            # Each bool requires an entire byte
-            return 1
-        elif (
-            elem_type == onnx.TensorProto.DataType.UINT16
-            or elem_type == onnx.TensorProto.DataType.INT16
-            or elem_type == onnx.TensorProto.DataType.FLOAT16
-        ):
-            return 2
-        if (
-            elem_type == onnx.TensorProto.DataType.FLOAT
-            or elem_type == onnx.TensorProto.DataType.INT32
-            or elem_type == onnx.TensorProto.DataType.INT64
-            or elem_type == onnx.TensorProto.DataType.DOUBLE
-            or elem_type == onnx.TensorProto.DataType.UINT64
-        ):
-            # 64 bit ints are treated as 32 bits everywhere
-            # Doubles are treated as floats
-            return 4
-        elif (
-            elem_type == onnx.TensorProto.DataType.COMPLEX64
-            or elem_type == onnx.TensorProto.DataType.COMPLEX128
-            or elem_type == onnx.TensorProto.DataType.STRING
-            or elem_type == onnx.TensorProto.DataType.UNDEFINED
-        ):
-            raise exp.Error("Unsupported data type")
-        else:
-            raise exp.Error("Unsupported data type (unknown to ONNX)")
-
-    def get_nodes_bytes(nodes):
-        nodes_bytes = {}
-        for node in nodes:
-
-            # Get the number of the data type
-            dtype_bytes = elem_type_to_bytes(node.type.tensor_type.elem_type)
-
-            # Calculate the total number of elements based on the shape
-            shape = str(node.type.tensor_type.shape.dim)
-            num_elements = np.prod([int(s) for s in shape.split() if s.isdigit()])
-
-            # Assign a total number of bytes to each node
-            nodes_bytes[node.name] = num_elements * dtype_bytes
-
-        return nodes_bytes
-
-    # Get the number of bytes of each of the inputs and outputs
-    model = onnx.load(onnx_path)
-    onnx_input_bytes = get_nodes_bytes(model.graph.input)
-    onnx_output_bytes = get_nodes_bytes(model.graph.output)
-
-    return int(sum(onnx_input_bytes.values())), int(sum(onnx_output_bytes.values()))
-
-
-def dtype_ort2str(dtype_str: str):
-    if dtype_str == "float16":
-        datatype = "float16"
-    elif dtype_str == "float":
-        datatype = "float32"
-    elif dtype_str == "double":
-        datatype = "float64"
-    elif dtype_str == "long":
-        datatype = "int64"
-    else:
-        datatype = dtype_str
-    return datatype
-
-
-def dummy_inputs(onnx_file: str) -> dict:
-    # Generate dummy inputs of the expected shape and type for the input model
-    sess_options = ort.SessionOptions()
-    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-    onnx_session = ort.InferenceSession(onnx_file, sess_options)
-    sess_input = onnx_session.get_inputs()
-
-    input_stats = []
-    for _idx, input_ in enumerate(range(len(sess_input))):
-        input_name = sess_input[input_].name
-        input_shape = sess_input[input_].shape
-
-        # TODO: Use onnx update_inputs_outputs_dims to automatically freeze models
-        for dim in input_shape:
-            if isinstance(dim, str) is True or math.isnan(dim) is True:
-                raise AssertionError(
-                    "Error: Model has dynamic inputs. Freeze the graph and try again"
-                )
-
-        input_type = sess_input[input_].type
-        input_stats.append([input_name, input_shape, input_type])
-
-    input_feed = {}
-    for stat in input_stats:
-        dtype_str = re.search(r"\((.*)\)", stat[2])
-        assert dtype_str is not None
-        datatype = dtype_ort2str(dtype_str.group(1))
-        input_feed[stat[0]] = np.random.rand(*stat[1]).astype(datatype)
-    return input_feed
-
-
-def get_opset(model: onnx.ModelProto) -> int:
-    return getattr(model.opset_import[0], "version", None)
-
-
-# This file was originally licensed under Apache 2.0. It has been modified.
-# Modifications Copyright (c) 2025 AMD
lemonade/common/plugins.py DELETED
@@ -1,10 +0,0 @@
-import pkgutil
-import importlib
-
-
-def discover():
-    return {
-        name: importlib.import_module(name)
-        for _, name, _ in pkgutil.iter_modules()
-        if name.startswith("turnkeyml_plugin_")
-    }