lemonade-sdk 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic. Click here for more details.
- lemonade/__init__.py +5 -0
- lemonade/api.py +125 -0
- lemonade/cache.py +85 -0
- lemonade/cli.py +135 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/analyze_model.py +26 -0
- lemonade/common/build.py +223 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/labels.py +61 -0
- lemonade/common/onnx_helpers.py +176 -0
- lemonade/common/plugins.py +10 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +490 -0
- lemonade/common/system_info.py +390 -0
- lemonade/common/tensor_helpers.py +83 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/memory_tracker.py +257 -0
- lemonade/profilers/profiler.py +55 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/adapter.py +104 -0
- lemonade/tools/bench.py +284 -0
- lemonade/tools/huggingface_bench.py +267 -0
- lemonade/tools/huggingface_load.py +520 -0
- lemonade/tools/humaneval.py +258 -0
- lemonade/tools/llamacpp.py +261 -0
- lemonade/tools/llamacpp_bench.py +154 -0
- lemonade/tools/management_tools.py +273 -0
- lemonade/tools/mmlu.py +327 -0
- lemonade/tools/ort_genai/__init__.py +0 -0
- lemonade/tools/ort_genai/oga.py +1129 -0
- lemonade/tools/ort_genai/oga_bench.py +142 -0
- lemonade/tools/perplexity.py +146 -0
- lemonade/tools/prompt.py +228 -0
- lemonade/tools/quark/__init__.py +0 -0
- lemonade/tools/quark/quark_load.py +172 -0
- lemonade/tools/quark/quark_quantize.py +439 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +739 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/serve.py +1354 -0
- lemonade/tools/server/tool_calls.py +146 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +774 -0
- lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
- lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
- lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
- lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
- lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
- lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +260 -0
- lemonade_server/model_manager.py +98 -0
- lemonade_server/server_models.json +142 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Helper functions for dealing with ONNX files and ONNX models
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Tuple, Union
|
|
7
|
+
import re
|
|
8
|
+
import math
|
|
9
|
+
import numpy as np
|
|
10
|
+
import onnx
|
|
11
|
+
import onnxruntime as ort
|
|
12
|
+
import lemonade.common.exceptions as exp
|
|
13
|
+
from lemonade.state import State
|
|
14
|
+
import lemonade.common.build as build
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def check_model(onnx_file, success_message, fail_message) -> bool:
    """
    Report whether `onnx_file` exists and passes the ONNX checker.

    Prints `fail_message` and returns False when the file does not exist.
    Otherwise prints `success_message`, runs `onnx.checker.check_model` on
    the file and returns True when it is accepted, or prints the validation
    error and returns False when it is rejected.
    """
    if not os.path.isfile(onnx_file):
        print(fail_message)
        return False

    print(success_message)

    try:
        onnx.checker.check_model(onnx_file)
    except onnx.checker.ValidationError as validation_error:
        print("\tError while checking generated ONNX file")
        print(validation_error)
        return False

    print("\tSuccessfully checked onnx file")
    return True
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def original_inputs_file(cache_dir: str, build_name: str):
    """Return the path of `inputs.npy` inside the output directory of a build."""
    build_output_dir = build.output_dir(cache_dir, build_name)
    return os.path.join(build_output_dir, "inputs.npy")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def onnx_dir(state: State):
    """Return the path of the `onnx` folder inside the output directory of `state`'s build."""
    build_output_dir = build.output_dir(state.cache_dir, state.build_name)
    return os.path.join(build_output_dir, "onnx")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_output_names(
    onnx_model: Union[str, onnx.ModelProto],
):  # pylint: disable=no-member
    """
    List the names of the graph outputs of an ONNX model.

    `onnx_model` is either an already-loaded `onnx.ModelProto` or a value
    that `onnx.load` accepts, in which case the model is loaded first.
    """
    if isinstance(onnx_model, onnx.ModelProto):  # pylint: disable=no-member
        loaded_model = onnx_model
    else:
        loaded_model = onnx.load(onnx_model)

    output_names = []
    for graph_output in loaded_model.graph.output:  # pylint: disable=no-member
        output_names.append(graph_output.name)
    return output_names
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def parameter_count(model):
    """
    Count the total number of elements held by the initializers of `model`.

    The count is derived from each initializer's declared `dims`, so no
    tensor data is converted to numpy arrays. An initializer with no dims
    (a scalar) counts as one element.

    Returns:
        A plain Python `int`; 0 when the model has no initializers.
    """
    # math.prod of an empty dims list is 1, which matches a scalar tensor
    return sum(math.prod(weight.dims) for weight in model.graph.initializer)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def io_bytes(onnx_path: str) -> Tuple[int, int]:
    """
    Return the total number of bytes of the model's inputs and of its outputs.

    Args:
        onnx_path: path of the ONNX file to load.

    Returns:
        `(total_input_bytes, total_output_bytes)`, each summed over all graph
        inputs / outputs.

    Raises:
        exp.Error: if an input or output uses an unsupported element type.
    """
    # pylint: disable = no-member
    data_type = onnx.TensorProto.DataType

    # Number of bytes used by hardware to send each datatype through PCIe
    bytes_per_element = {
        # Each bool requires an entire byte
        data_type.UINT8: 1,
        data_type.INT8: 1,
        data_type.BOOL: 1,
        data_type.UINT16: 2,
        data_type.INT16: 2,
        data_type.FLOAT16: 2,
        data_type.BFLOAT16: 2,
        # 64 bit ints are treated as 32 bits everywhere
        # Doubles are treated as floats
        data_type.FLOAT: 4,
        data_type.INT32: 4,
        data_type.UINT32: 4,
        data_type.INT64: 4,
        data_type.UINT64: 4,
        data_type.DOUBLE: 4,
    }
    unsupported_types = (
        data_type.COMPLEX64,
        data_type.COMPLEX128,
        data_type.STRING,
        data_type.UNDEFINED,
    )

    def elem_type_to_bytes(elem_type) -> int:
        """
        Convert ONNX's elem_type to the number of bytes used by
        hardware to send that specific datatype through PCIe
        """
        if elem_type in unsupported_types:
            raise exp.Error("Unsupported data type")
        size = bytes_per_element.get(elem_type)
        if size is None:
            raise exp.Error("Unsupported data type (unknown to ONNX)")
        return size

    def get_nodes_bytes(nodes):
        """Map each node name to the number of bytes of its tensor."""
        nodes_bytes = {}
        for node in nodes:
            tensor_type = node.type.tensor_type

            # Get the number of bytes of the data type
            dtype_bytes = elem_type_to_bytes(tensor_type.elem_type)

            # Multiply the concrete dimensions only; symbolic (dim_param) or
            # unset dimensions are skipped, i.e. they count as 1
            num_elements = math.prod(
                dim.dim_value
                for dim in tensor_type.shape.dim
                if dim.HasField("dim_value") and dim.dim_value >= 0
            )

            # Assign a total number of bytes to each node
            nodes_bytes[node.name] = num_elements * dtype_bytes

        return nodes_bytes

    # Get the number of bytes of each of the inputs and outputs
    model = onnx.load(onnx_path)
    onnx_input_bytes = get_nodes_bytes(model.graph.input)
    onnx_output_bytes = get_nodes_bytes(model.graph.output)

    return int(sum(onnx_input_bytes.values())), int(sum(onnx_output_bytes.values()))
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def dtype_ort2str(dtype_str: str):
    """
    Translate a datatype name into the name used for the numpy array dtype.

    "float", "double" and "long" are mapped to "float32", "float64" and
    "int64"; every other name (including "float16") is returned unchanged.
    """
    renamed = {
        "float16": "float16",
        "float": "float32",
        "double": "float64",
        "long": "int64",
    }
    return renamed.get(dtype_str, dtype_str)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def dummy_inputs(onnx_file: str) -> dict:
    """
    Generate dummy inputs of the expected shape and type for the input model.

    Args:
        onnx_file: path of the ONNX model to open with onnxruntime.

    Returns:
        A dict mapping each session input name to a random numpy array of the
        input's shape, cast to the input's datatype.

    Raises:
        AssertionError: if any input has a dynamic dimension (a name, None or
            NaN instead of a size), or if an input's type string does not
            have the `...(dtype)` form.
    """
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    onnx_session = ort.InferenceSession(onnx_file, sess_options)

    def is_dynamic(dim) -> bool:
        # A dynamic dimension shows up as a name, as None, or as NaN;
        # None must be tested before isnan, which rejects non-numbers
        if dim is None or isinstance(dim, str):
            return True
        return isinstance(dim, float) and math.isnan(dim)

    # Validate every input before generating any data
    input_stats = []
    for session_input in onnx_session.get_inputs():
        input_shape = session_input.shape

        # TODO: Use onnx update_inputs_outputs_dims to automatically freeze models
        if any(is_dynamic(dim) for dim in input_shape):
            raise AssertionError(
                "Error: Model has dynamic inputs. Freeze the graph and try again"
            )

        input_stats.append((session_input.name, input_shape, session_input.type))

    input_feed = {}
    for input_name, input_shape, input_type in input_stats:
        # The datatype is the text between the parentheses of the type string
        dtype_match = re.search(r"\((.*)\)", input_type)
        if dtype_match is None:
            raise AssertionError(
                f"Could not parse datatype of input {input_name}: {input_type}"
            )
        datatype = dtype_ort2str(dtype_match.group(1))

        # np.random.rand() returns a bare float for an empty shape, so wrap
        # the result in an array before casting
        random_values = np.asarray(np.random.rand(*input_shape))
        input_feed[input_name] = random_values.astype(datatype)
    return input_feed
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def get_opset(model: onnx.ModelProto) -> Union[int, None]:
    """
    Return the opset version imported by `model`.

    The entry for the default ONNX operator domain ("" or "ai.onnx") is
    preferred, so that a custom-domain entry listed first does not hide the
    ONNX opset. When there is no default-domain entry, the version of the
    first entry is returned.

    Returns:
        The opset version, or None when the model imports no opset.
    """
    opset_imports = model.opset_import
    if len(opset_imports) == 0:
        return None

    for opset in opset_imports:
        if opset.domain in ("", "ai.onnx"):
            return opset.version

    return getattr(opset_imports[0], "version", None)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
176
|
+
# Modifications Copyright (c) 2025 AMD
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import enum
|
|
4
|
+
import sys
|
|
5
|
+
import math
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Colors:
    """Escape sequences used to color and style text printed to the terminal."""

    # Foreground colors
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Appended after every message by logn()
    ENDC = "\033[0m"

    # Text styles
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def log(txt, c=Colors.ENDC, end="", is_error=False):
    """Print `txt` like `logn`, but without a trailing line break by default."""
    logn(txt, c, end, is_error)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def logn(txt, c=Colors.ENDC, end="\n", is_error=False):
    """
    Print `txt` wrapped between the escape sequence `c` and `Colors.ENDC`.

    The text goes to stderr when `is_error` is set and to stdout otherwise,
    is terminated by `end`, and the stream is flushed right away.
    """
    if is_error:
        stream = sys.stderr
    else:
        stream = sys.stdout
    colored_text = c + txt + Colors.ENDC
    print(colored_text, end=end, file=stream, flush=True)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LogType(enum.Enum):
    """Kinds of messages; each value is the prefix printed before the message."""

    ERROR = "Error:"
    SUCCESS = "Woohoo!"
    WARNING = "Warning:"
    INFO = "Info:"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def clean_print(type: LogType, msg):
    """
    Print `msg` after a colored prefix chosen by `type`.

    The user's home directory is replaced by `~`, every line loses its
    leading whitespace, and leading blank lines are dropped. Lines after the
    first are indented to line up with the text that follows the prefix.
    Text enclosed in `**` markers is printed in bold. Error messages are
    written entirely to stderr; all other types go to stdout.

    Args:
        type: selects the prefix text and its color.
        msg: the message; must be a string.
    """
    # Replace path to user's home directory by a tilde symbol (~)
    msg = msg.replace(os.path.expanduser("~"), "~")

    # Split message into lines and remove leading spaces
    lines = [line.lstrip() for line in msg.split("\n")]

    # Drop leading blank lines, always keeping at least one line
    first_kept = next(
        (idx for idx, line in enumerate(lines) if line != ""), len(lines) - 1
    )
    lines = lines[first_kept:]

    is_error = type == LogType.ERROR
    prefix_colors = {
        LogType.ERROR: Colors.FAIL,
        LogType.SUCCESS: Colors.OKGREEN,
        LogType.WARNING: Colors.WARNING,
        LogType.INFO: Colors.OKCYAN,
    }

    # Print the prefix
    indentation = len(type.value) + 1
    if type in prefix_colors:
        log(f"\n{type.value} ", c=prefix_colors[type], is_error=is_error)

    # Print message
    for line_idx, line in enumerate(lines):
        if line_idx != 0:
            # The padding must go to the same stream as the text it aligns
            log(" " * indentation, is_error=is_error)

        # Segments at odd positions sit between a pair of "**" markers
        segments = line.split("**")
        last_idx = len(segments) - 1
        for idx, segment in enumerate(segments):
            c = Colors.ENDC if idx % 2 == 0 else Colors.BOLD
            if idx != last_idx:
                log(segment, c=c, is_error=is_error)
            else:
                # Only the final segment ends the line
                logn(segment, c=c, is_error=is_error)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def log_error(msg):
    """Print `msg` (converted with `str`) as an error, followed by an ASCII-art banner."""
    error_text = str(msg)
    clean_print(LogType.ERROR, error_text)

    # ASCII art credit:
    # https://textart4u.blogspot.com/2014/05/the-fail-whale-ascii-art-code.html
    fail_whale = """\n▄██████████████▄▐█▄▄▄▄█▌
██████▌▄▌▄▐▐▌███▌▀▀██▀▀
████▄█▌▄▌▄▐▐▌▀███▄▄█▌
▄▄▄▄▄██████████████\n\n"""
    logn(fail_whale, is_error=True)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def log_success(msg):
    """
    Print `msg` with the success prefix.

    `msg` is converted with `str` first, as `log_error` does, so non-string
    objects can be passed directly.
    """
    clean_print(LogType.SUCCESS, str(msg))
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def log_warning(msg):
    """
    Print `msg` with the warning prefix.

    `msg` is converted with `str` first, as `log_error` does, so non-string
    objects can be passed directly.
    """
    clean_print(LogType.WARNING, str(msg))
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def log_info(msg):
    """
    Print `msg` with the info prefix.

    `msg` is converted with `str` first, as `log_error` does, so non-string
    objects can be passed directly.
    """
    clean_print(LogType.INFO, str(msg))
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def list_table(list, padding=25, num_cols=4):
    """
    Print the strings of `list` as a table that is filled column by column.

    Each entry is left-justified to `padding` characters. Every row is
    followed by a line break and a tab.
    """
    total = len(list)
    rows = int(math.ceil(total / num_cols))
    for row in range(rows):
        # The entries of one row are `rows` positions apart in the list
        for position in range(row, row + num_cols * rows, rows):
            if position >= total:
                # All remaining positions of this row are out of range too
                break
            print(list[position].ljust(padding), end="")
        print("\n\t", end="")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
110
|
+
# Modifications Copyright (c) 2025 AMD
|