lemonade_sdk-7.0.0-py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

This release of lemonade-sdk has been flagged as potentially problematic.

Files changed (61)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +125 -0
  3. lemonade/cache.py +85 -0
  4. lemonade/cli.py +135 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/analyze_model.py +26 -0
  7. lemonade/common/build.py +223 -0
  8. lemonade/common/cli_helpers.py +139 -0
  9. lemonade/common/exceptions.py +98 -0
  10. lemonade/common/filesystem.py +368 -0
  11. lemonade/common/labels.py +61 -0
  12. lemonade/common/onnx_helpers.py +176 -0
  13. lemonade/common/plugins.py +10 -0
  14. lemonade/common/printing.py +110 -0
  15. lemonade/common/status.py +490 -0
  16. lemonade/common/system_info.py +390 -0
  17. lemonade/common/tensor_helpers.py +83 -0
  18. lemonade/common/test_helpers.py +28 -0
  19. lemonade/profilers/__init__.py +1 -0
  20. lemonade/profilers/memory_tracker.py +257 -0
  21. lemonade/profilers/profiler.py +55 -0
  22. lemonade/sequence.py +363 -0
  23. lemonade/state.py +159 -0
  24. lemonade/tools/__init__.py +1 -0
  25. lemonade/tools/adapter.py +104 -0
  26. lemonade/tools/bench.py +284 -0
  27. lemonade/tools/huggingface_bench.py +267 -0
  28. lemonade/tools/huggingface_load.py +520 -0
  29. lemonade/tools/humaneval.py +258 -0
  30. lemonade/tools/llamacpp.py +261 -0
  31. lemonade/tools/llamacpp_bench.py +154 -0
  32. lemonade/tools/management_tools.py +273 -0
  33. lemonade/tools/mmlu.py +327 -0
  34. lemonade/tools/ort_genai/__init__.py +0 -0
  35. lemonade/tools/ort_genai/oga.py +1129 -0
  36. lemonade/tools/ort_genai/oga_bench.py +142 -0
  37. lemonade/tools/perplexity.py +146 -0
  38. lemonade/tools/prompt.py +228 -0
  39. lemonade/tools/quark/__init__.py +0 -0
  40. lemonade/tools/quark/quark_load.py +172 -0
  41. lemonade/tools/quark/quark_quantize.py +439 -0
  42. lemonade/tools/report/__init__.py +0 -0
  43. lemonade/tools/report/llm_report.py +203 -0
  44. lemonade/tools/report/table.py +739 -0
  45. lemonade/tools/server/__init__.py +0 -0
  46. lemonade/tools/server/serve.py +1354 -0
  47. lemonade/tools/server/tool_calls.py +146 -0
  48. lemonade/tools/tool.py +374 -0
  49. lemonade/version.py +1 -0
  50. lemonade_install/__init__.py +1 -0
  51. lemonade_install/install.py +774 -0
  52. lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
  53. lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
  54. lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
  55. lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
  56. lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
  57. lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
  58. lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
  59. lemonade_server/cli.py +260 -0
  60. lemonade_server/model_manager.py +98 -0
  61. lemonade_server/server_models.json +142 -0
lemonade/common/onnx_helpers.py
@@ -0,0 +1,176 @@
+ """
+ Helper functions for dealing with ONNX files and ONNX models
+ """
+
+ import os
+ from typing import Tuple, Union
+ import re
+ import math
+ import numpy as np
+ import onnx
+ import onnxruntime as ort
+ import lemonade.common.exceptions as exp
+ from lemonade.state import State
+ import lemonade.common.build as build
+
+
+ def check_model(onnx_file, success_message, fail_message) -> bool:
+     if os.path.isfile(onnx_file):
+         print(success_message)
+     else:
+         print(fail_message)
+         return False
+     try:
+         onnx.checker.check_model(onnx_file)
+         print("\tSuccessfully checked onnx file")
+         return True
+     except onnx.checker.ValidationError as e:
+         print("\tError while checking generated ONNX file")
+         print(e)
+         return False
+
+
+ def original_inputs_file(cache_dir: str, build_name: str):
+     return os.path.join(build.output_dir(cache_dir, build_name), "inputs.npy")
+
+
+ def onnx_dir(state: State):
+     return os.path.join(build.output_dir(state.cache_dir, state.build_name), "onnx")
+
+
+ def get_output_names(
+     onnx_model: Union[str, onnx.ModelProto],
+ ):  # pylint: disable=no-member
+     # Get output names of ONNX file/model
+     if not isinstance(onnx_model, onnx.ModelProto):  # pylint: disable=no-member
+         onnx_model = onnx.load(onnx_model)
+     return [node.name for node in onnx_model.graph.output]  # pylint: disable=no-member
+
+
+ def parameter_count(model):
+     weights = model.graph.initializer
+     parameter_count = 0
+
+     for w in weights:
+         weight = onnx.numpy_helper.to_array(w)
+         parameter_count += np.prod(weight.shape)
+     return parameter_count
+
+
+ def io_bytes(onnx_path: str) -> Tuple[int, int]:
+     """Return the number of bytes of each of the inputs and outputs"""
+     # pylint: disable = no-member
+
+     def elem_type_to_bytes(elem_type) -> int:
+         """
+         Convert ONNX's elem_type to the number of bytes used by
+         hardware to send that specific datatype through PCIe
+         """
+         if (
+             elem_type == onnx.TensorProto.DataType.UINT8
+             or elem_type == onnx.TensorProto.DataType.INT8
+             or elem_type == onnx.TensorProto.DataType.BOOL
+         ):
+             # Each bool requires an entire byte
+             return 1
+         elif (
+             elem_type == onnx.TensorProto.DataType.UINT16
+             or elem_type == onnx.TensorProto.DataType.INT16
+             or elem_type == onnx.TensorProto.DataType.FLOAT16
+         ):
+             return 2
+         if (
+             elem_type == onnx.TensorProto.DataType.FLOAT
+             or elem_type == onnx.TensorProto.DataType.INT32
+             or elem_type == onnx.TensorProto.DataType.INT64
+             or elem_type == onnx.TensorProto.DataType.DOUBLE
+             or elem_type == onnx.TensorProto.DataType.UINT64
+         ):
+             # 64 bit ints are treated as 32 bits everywhere
+             # Doubles are treated as floats
+             return 4
+         elif (
+             elem_type == onnx.TensorProto.DataType.COMPLEX64
+             or elem_type == onnx.TensorProto.DataType.COMPLEX128
+             or elem_type == onnx.TensorProto.DataType.STRING
+             or elem_type == onnx.TensorProto.DataType.UNDEFINED
+         ):
+             raise exp.Error("Unsupported data type")
+         else:
+             raise exp.Error("Unsupported data type (unknown to ONNX)")
+
+     def get_nodes_bytes(nodes):
+         nodes_bytes = {}
+         for node in nodes:
+
+             # Get the number of the data type
+             dtype_bytes = elem_type_to_bytes(node.type.tensor_type.elem_type)
+
+             # Calculate the total number of elements based on the shape
+             shape = str(node.type.tensor_type.shape.dim)
+             num_elements = np.prod([int(s) for s in shape.split() if s.isdigit()])
+
+             # Assign a total number of bytes to each node
+             nodes_bytes[node.name] = num_elements * dtype_bytes
+
+         return nodes_bytes
+
+     # Get the number of bytes of each of the inputs and outputs
+     model = onnx.load(onnx_path)
+     onnx_input_bytes = get_nodes_bytes(model.graph.input)
+     onnx_output_bytes = get_nodes_bytes(model.graph.output)
+
+     return int(sum(onnx_input_bytes.values())), int(sum(onnx_output_bytes.values()))
+
+
+ def dtype_ort2str(dtype_str: str):
+     if dtype_str == "float16":
+         datatype = "float16"
+     elif dtype_str == "float":
+         datatype = "float32"
+     elif dtype_str == "double":
+         datatype = "float64"
+     elif dtype_str == "long":
+         datatype = "int64"
+     else:
+         datatype = dtype_str
+     return datatype
+
+
+ def dummy_inputs(onnx_file: str) -> dict:
+     # Generate dummy inputs of the expected shape and type for the input model
+     sess_options = ort.SessionOptions()
+     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+     onnx_session = ort.InferenceSession(onnx_file, sess_options)
+     sess_input = onnx_session.get_inputs()
+
+     input_stats = []
+     for _idx, input_ in enumerate(range(len(sess_input))):
+         input_name = sess_input[input_].name
+         input_shape = sess_input[input_].shape
+
+         # TODO: Use onnx update_inputs_outputs_dims to automatically freeze models
+         for dim in input_shape:
+             if isinstance(dim, str) is True or math.isnan(dim) is True:
+                 raise AssertionError(
+                     "Error: Model has dynamic inputs. Freeze the graph and try again"
+                 )
+
+         input_type = sess_input[input_].type
+         input_stats.append([input_name, input_shape, input_type])
+
+     input_feed = {}
+     for stat in input_stats:
+         dtype_str = re.search(r"\((.*)\)", stat[2])
+         assert dtype_str is not None
+         datatype = dtype_ort2str(dtype_str.group(1))
+         input_feed[stat[0]] = np.random.rand(*stat[1]).astype(datatype)
+     return input_feed
+
+
+ def get_opset(model: onnx.ModelProto) -> int:
+     return getattr(model.opset_import[0], "version", None)
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
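
For orientation, here is a minimal usage sketch (not part of the package) showing how the helpers above might be exercised against a local ONNX file; the "model.onnx" path is an assumed placeholder, and only functions defined in onnx_helpers.py are used:

    # Hypothetical usage sketch; "model.onnx" is a placeholder path.
    import lemonade.common.onnx_helpers as onnx_helpers

    onnx_file = "model.onnx"
    if onnx_helpers.check_model(onnx_file, "\tFound ONNX file", "\tONNX file missing"):
        # Total bytes transferred for inputs and outputs, per io_bytes above
        input_bytes, output_bytes = onnx_helpers.io_bytes(onnx_file)
        print(f"Inputs: {input_bytes} bytes, outputs: {output_bytes} bytes")
        # dummy_inputs raises AssertionError if the model has dynamic input shapes
        feed = onnx_helpers.dummy_inputs(onnx_file)
        print({name: array.shape for name, array in feed.items()})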
lemonade/common/plugins.py
@@ -0,0 +1,10 @@
+ import pkgutil
+ import importlib
+
+
+ def discover():
+     return {
+         name: importlib.import_module(name)
+         for _, name, _ in pkgutil.iter_modules()
+         if name.startswith("turnkeyml_plugin_")
+     }
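
A short hedged sketch of how the discover() hook could be called to enumerate installed plugins; the printed attributes are illustrative only:

    # Hypothetical usage sketch: import every installed turnkeyml_plugin_* module.
    from lemonade.common.plugins import discover

    plugins = discover()  # {module_name: imported module object}
    for name, module in plugins.items():
        print(f"Discovered plugin: {name} -> {getattr(module, '__file__', '?')}")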
lemonade/common/printing.py
@@ -0,0 +1,110 @@
+ import os
+ import re
+ import enum
+ import sys
+ import math
+
+
+ class Colors:
+     HEADER = "\033[95m"
+     OKBLUE = "\033[94m"
+     OKCYAN = "\033[96m"
+     OKGREEN = "\033[92m"
+     WARNING = "\033[93m"
+     FAIL = "\033[91m"
+     ENDC = "\033[0m"
+     BOLD = "\033[1m"
+     UNDERLINE = "\033[4m"
+
+
+ def log(txt, c=Colors.ENDC, end="", is_error=False):
+     logn(txt, c=c, end=end, is_error=is_error)
+
+
+ def logn(txt, c=Colors.ENDC, end="\n", is_error=False):
+     file = sys.stderr if is_error else sys.stdout
+     print(c + txt + Colors.ENDC, end=end, flush=True, file=file)
+
+
+ class LogType(enum.Enum):
+     ERROR = "Error:"
+     SUCCESS = "Woohoo!"
+     WARNING = "Warning:"
+     INFO = "Info:"
+
+
+ def clean_print(type: LogType, msg):
+     # Replace path to user's home directory by a tilde symbol (~)
+     home_directory = os.path.expanduser("~")
+     home_directory_escaped = re.escape(home_directory)
+     msg = re.sub(home_directory_escaped, "~", msg)
+
+     # Split message into list, remove leading spaces and line breaks
+     msg = msg.split("\n")
+     msg = [line.lstrip() for line in msg]
+     while msg[0] == "" and len(msg) > 1:
+         msg.pop(0)
+
+     # Print message
+     indentation = len(type.value) + 1
+     if type == LogType.ERROR:
+         log(f"\n{type.value} ".rjust(indentation), c=Colors.FAIL, is_error=True)
+     elif type == LogType.SUCCESS:
+         log(f"\n{type.value} ".rjust(indentation), c=Colors.OKGREEN)
+     elif type == LogType.WARNING:
+         log(f"\n{type.value} ".rjust(indentation), c=Colors.WARNING)
+     elif type == LogType.INFO:
+         log(f"\n{type.value} ".rjust(indentation), c=Colors.OKCYAN)
+
+     is_error = type == LogType.ERROR
+     for line_idx, line in enumerate(msg):
+         if line_idx != 0:
+             log(" " * indentation)
+         s_line = line.split("**")
+         for idx, l in enumerate(s_line):
+             c = Colors.ENDC if idx % 2 == 0 else Colors.BOLD
+             if idx != len(s_line) - 1:
+                 log(l, c=c, is_error=is_error)
+             else:
+                 logn(l, c=c, is_error=is_error)
+
+
+ def log_error(msg):
+     clean_print(LogType.ERROR, str(msg))
+     # ASCII art credit:
+     # https://textart4u.blogspot.com/2014/05/the-fail-whale-ascii-art-code.html
+     logn(
+         """\n▄██████████████▄▐█▄▄▄▄█▌
+         ██████▌▄▌▄▐▐▌███▌▀▀██▀▀
+         ████▄█▌▄▌▄▐▐▌▀███▄▄█▌
+         ▄▄▄▄▄██████████████\n\n""",
+         is_error=True,
+     )
+
+
+ def log_success(msg):
+     clean_print(LogType.SUCCESS, msg)
+
+
+ def log_warning(msg):
+     clean_print(LogType.WARNING, msg)
+
+
+ def log_info(msg):
+     clean_print(LogType.INFO, msg)
+
+
+ def list_table(list, padding=25, num_cols=4):
+     lines_per_column = int(math.ceil(len(list) / num_cols))
+     for i in range(lines_per_column):
+         for col in range(num_cols):
+             if i + col * lines_per_column < len(list):
+                 print(
+                     list[i + col * lines_per_column].ljust(padding),
+                     end="",
+                 )
+         print("\n\t", end="")
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
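
To close out this file, a small hedged sketch of the printing helpers as another module might call them; the messages and list contents are illustrative only:

    # Hypothetical usage sketch for lemonade.common.printing.
    from lemonade.common.printing import list_table, log_info, log_success, log_warning

    log_info("Loading **model.onnx** from the build cache")  # **text** renders bold
    log_warning("Dynamic input shapes are not supported")
    log_success("Build completed")
    list_table([f"option_{i}" for i in range(10)], padding=12, num_cols=4)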