lemonade-sdk 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic. Click here for more details.
- lemonade/__init__.py +5 -0
- lemonade/api.py +125 -0
- lemonade/cache.py +85 -0
- lemonade/cli.py +135 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/analyze_model.py +26 -0
- lemonade/common/build.py +223 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/labels.py +61 -0
- lemonade/common/onnx_helpers.py +176 -0
- lemonade/common/plugins.py +10 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +490 -0
- lemonade/common/system_info.py +390 -0
- lemonade/common/tensor_helpers.py +83 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/memory_tracker.py +257 -0
- lemonade/profilers/profiler.py +55 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/adapter.py +104 -0
- lemonade/tools/bench.py +284 -0
- lemonade/tools/huggingface_bench.py +267 -0
- lemonade/tools/huggingface_load.py +520 -0
- lemonade/tools/humaneval.py +258 -0
- lemonade/tools/llamacpp.py +261 -0
- lemonade/tools/llamacpp_bench.py +154 -0
- lemonade/tools/management_tools.py +273 -0
- lemonade/tools/mmlu.py +327 -0
- lemonade/tools/ort_genai/__init__.py +0 -0
- lemonade/tools/ort_genai/oga.py +1129 -0
- lemonade/tools/ort_genai/oga_bench.py +142 -0
- lemonade/tools/perplexity.py +146 -0
- lemonade/tools/prompt.py +228 -0
- lemonade/tools/quark/__init__.py +0 -0
- lemonade/tools/quark/quark_load.py +172 -0
- lemonade/tools/quark/quark_quantize.py +439 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +739 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/serve.py +1354 -0
- lemonade/tools/server/tool_calls.py +146 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +774 -0
- lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
- lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
- lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
- lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
- lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
- lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +260 -0
- lemonade_server/model_manager.py +98 -0
- lemonade_server/server_models.json +142 -0
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
import importlib.metadata
|
|
3
|
+
import platform
|
|
4
|
+
import re
|
|
5
|
+
import subprocess
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SystemInfo(ABC):
    """
    Base class for the OS-specific system-information collectors.

    Provides the entries that are available on every OS (OS version and
    installed Python packages); subclasses extend ``get_dict`` with
    platform-specific information.
    """

    def __init__(self):
        # Nothing to set up at the base level.
        pass

    def get_dict(self):
        """
        Collects the OS-independent system information.

        Returns:
            dict: Mapping with the "OS Version" and "Python Packages" entries.
        """
        return {
            "OS Version": self.get_os_version(),
            "Python Packages": self.get_python_packages(),
        }

    @staticmethod
    def get_os_version() -> str:
        """
        Describes the running OS using ``platform.platform()``.

        Returns:
            str: The platform string, or "ERROR - <exception>" if it cannot
            be determined.
        """
        try:
            os_version = platform.platform()
        except Exception as err:  # pylint: disable=broad-except
            os_version = f"ERROR - {err}"
        return os_version

    @staticmethod
    def get_python_packages() -> list:
        """
        Lists every installed Python distribution with its version.

        Returns:
            list: Strings of the form "package-name==package-version".
        """
        packages = []
        for distribution in importlib.metadata.distributions():
            meta = distribution.metadata
            packages.append(f"{meta['name']}=={meta['version']}")
        return packages
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class WindowsSystemInfo(SystemInfo):
    """Class used to access system information in Windows (via WMI and CLI tools)"""

    def __init__(self):
        super().__init__()
        # Imported here rather than at module level so that wmi is only
        # required when this class is actually instantiated.
        import wmi

        self.connection = wmi.WMI()

    def get_processor_name(self) -> str:
        """
        Retrieves the name of the processor.

        Returns:
            str: Name of the first processor reported by WMI together with its
            core and logical-processor counts, or a "not found" message.
        """
        processors = self.connection.Win32_Processor()
        if processors:
            cpu = processors[0]
            return (
                f"{cpu.Name.strip()} "
                f"({cpu.NumberOfCores} cores, "
                f"{cpu.NumberOfLogicalProcessors} logical processors)"
            )
        return "Processor information not found."

    def get_system_model(self) -> str:
        """
        Retrieves the model of the computer system.

        Returns:
            str: Model of the computer system, or a "not found" message.
        """
        systems = self.connection.Win32_ComputerSystem()
        if systems:
            return systems[0].Model
        return "System model information not found."

    def get_physical_memory(self) -> str:
        """
        Retrieves the physical memory of the computer system.

        Returns:
            str: Total capacity in GB followed by a per-module breakdown
            (manufacturer, capacity, speed), or a "not found" message.
        """
        memory = self.connection.Win32_PhysicalMemory()
        if not memory:
            return "Physical memory information not found."

        bytes_per_gb = 1024**3
        total_capacity = sum(int(m.Capacity) for m in memory)
        # NOTE(review): confirm that Win32_PhysicalMemory.Speed is really
        # reported in ns on target systems; the label is kept unchanged.
        details_str = " + ".join(
            f"{m.Manufacturer} {int(m.Capacity)/bytes_per_gb} GB {m.Speed} ns"
            for m in memory
        )
        return f"{total_capacity/bytes_per_gb} GB ({details_str})"

    def get_bios_version(self) -> str:
        """
        Retrieves the BIOS Version of the computer system.

        Returns:
            str: BIOS Version, or a "not found" message.
        """
        bios = self.connection.Win32_BIOS()
        if bios:
            return bios[0].Name
        return "BIOS Version not found."

    def get_max_clock_speed(self) -> str:
        """
        Retrieves the max clock speed of the CPU of the system.

        Returns:
            str: Max CPU clock speed in MHz, or a "not found" message.
        """
        processor = self.connection.Win32_Processor()
        if processor:
            return f"{processor[0].MaxClockSpeed} MHz"
        return "Max CPU clock speed not found."

    def get_driver_version(self, device_name) -> str:
        """
        Retrieves the driver version for the specified device name.

        Args:
            device_name: Device name to look up via Win32_PnPSignedDriver.

        Returns:
            str: Driver version, or "" if the device driver is not found or
            reports no version.
        """
        drivers = self.connection.Win32_PnPSignedDriver(DeviceName=device_name)
        if drivers:
            # DriverVersion could be None; normalize to "" so callers can treat
            # every empty result as "not found".
            return drivers[0].DriverVersion or ""
        return ""

    @staticmethod
    def get_npu_power_mode() -> str:
        """
        Retrieves the NPU power mode by querying xrt-smi.

        Returns:
            str: The last token of the first output line that contains "Mode",
            or "NPU power mode not found." if xrt-smi is missing, fails, or
            prints no such line.
        """
        try:
            out = subprocess.check_output(
                [
                    r"C:\Windows\System32\AMD\xrt-smi.exe",
                    "examine",
                    "-r",
                    "platform",
                ],
                stderr=subprocess.STDOUT,
            ).decode(errors="replace")
        except (OSError, subprocess.CalledProcessError):
            # xrt-smi not present / not executable, or it failed to query the NPU
            return "NPU power mode not found."

        modes = [line.split()[-1] for line in out.splitlines() if "Mode" in line]
        if modes:
            return modes[0]
        return "NPU power mode not found."

    @staticmethod
    def get_windows_power_setting() -> str:
        """
        Retrieves the Windows power setting.

        Returns:
            str: Name of the active power scheme (the text inside the first pair
            of parentheses printed by ``powercfg /getactivescheme``), or
            "Windows power setting not found" if it cannot be determined.
        """
        try:
            out = subprocess.check_output(["powercfg", "/getactivescheme"]).decode(
                errors="replace"
            )
        except (OSError, subprocess.CalledProcessError):
            return "Windows power setting not found"

        # Guard against output without a parenthesized scheme name, which
        # previously raised AttributeError on the None match.
        match = re.search(r"\((.*?)\)", out)
        if match:
            return match.group(1)
        return "Windows power setting not found"

    def get_dict(self) -> dict:
        """
        Retrieves all the system information into a dictionary

        Returns:
            dict: System information. On AMD processors this also includes
            driver versions for the NPU/OpenCL devices and the NPU power mode.
        """
        info_dict = super().get_dict()
        info_dict["Processor"] = self.get_processor_name()
        info_dict["OEM System"] = self.get_system_model()
        info_dict["Physical Memory"] = self.get_physical_memory()
        info_dict["BIOS Version"] = self.get_bios_version()
        info_dict["CPU Max Clock"] = self.get_max_clock_speed()
        info_dict["Windows Power Setting"] = self.get_windows_power_setting()
        if "AMD" in info_dict["Processor"]:
            device_names = [
                "NPU Compute Accelerator Device",
                "AMD-OpenCL User Mode Driver",
            ]
            info_dict["Driver Versions"] = {
                device_name: self.get_driver_version(device_name)
                or "DEVICE NOT FOUND"
                for device_name in device_names
            }
            info_dict["NPU Power Mode"] = self.get_npu_power_mode()
        return info_dict
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class WSLSystemInfo(SystemInfo):
    """Class used to access system information in WSL"""

    @staticmethod
    def get_system_model() -> str:
        """
        Retrieves the model of the computer system by calling wmic through
        the Windows host's powershell.exe.

        Returns:
            str: Model of the computer system, or "ERROR - <exception>" if the
            command cannot be run.
        """
        try:
            raw = subprocess.check_output(
                'powershell.exe -Command "wmic computersystem get model"',
                shell=True,
            ).decode(errors="replace")
            # wmic prints a "Model" header line followed by the value; the
            # output may contain \r characters and blank lines, so keep only
            # non-empty stripped lines and drop the header. Splitting on the
            # header text itself would truncate model names containing "Model".
            lines = [line.strip() for line in raw.splitlines() if line.strip()]
            if lines and lines[0] == "Model":
                lines = lines[1:]
            return " ".join(lines)
        except Exception as e:  # pylint: disable=broad-except
            return f"ERROR - {e}"

    def get_dict(self) -> dict:
        """
        Retrieves all the system information into a dictionary

        Returns:
            dict: System information
        """
        info_dict = super().get_dict()
        info_dict["OEM System"] = self.get_system_model()
        return info_dict
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class LinuxSystemInfo(SystemInfo):
    """Class used to access system information in Linux"""

    @staticmethod
    def get_processor_name() -> str:
        """
        Retrieves the name of the processor from ``lscpu``.

        Returns:
            str: Name of the processor, "Processor information not found." if
            lscpu prints no "Model name:" line, or "ERROR - <exception>" if
            lscpu cannot be run.
        """
        try:
            cpu_info = subprocess.check_output(["lscpu"]).decode(errors="replace")
        except Exception as e:  # pylint: disable=broad-except
            return f"ERROR - {e}"

        for line in cpu_info.splitlines():
            if "Model name:" in line:
                # Split only on the first ':' so names containing ':' survive
                return line.split(":", 1)[1].strip()
        # Previously fell through and returned None despite the str annotation
        return "Processor information not found."

    @staticmethod
    def get_system_model() -> str:
        """
        Retrieves the model of the computer system via dmidecode.

        Returns:
            str: Model of the computer system, a message when sudo needs a
            password (or the command fails), or "ERROR - <exception>".
        """
        try:
            oem_info = (
                subprocess.check_output(
                    ["sudo", "-n", "dmidecode", "-s", "system-product-name"],
                    stderr=subprocess.DEVNULL,
                )
                .decode(errors="replace")
                .strip()
                .replace("\n", " ")
            )
            return oem_info
        except subprocess.CalledProcessError:
            # `sudo -n` exits non-zero instead of prompting when a password
            # is required
            return "Unable to get oem info - password required"
        except Exception as e:  # pylint: disable=broad-except
            return f"ERROR - {e}"

    @staticmethod
    def get_physical_memory() -> str:
        """
        Retrieves the physical memory of the computer system from ``free -m``.

        Returns:
            str: Total memory in GB (rounded to 2 decimals), or
            "ERROR - <exception>" if it cannot be determined.
        """
        try:
            free_output = subprocess.check_output(["free", "-m"]).decode()
            # Second line, second column: the total (in MiB, because of -m)
            total_mib = free_output.split("\n")[1].split()[1]
            mem_info_gb = round(int(total_mib) / 1024, 2)
            return f"{mem_info_gb} GB"
        except Exception as e:  # pylint: disable=broad-except
            return f"ERROR - {e}"

    def get_dict(self) -> dict:
        """
        Retrieves all the system information into a dictionary

        Returns:
            dict: System information
        """
        info_dict = super().get_dict()
        info_dict["Processor"] = self.get_processor_name()
        info_dict["OEM System"] = self.get_system_model()
        info_dict["Physical Memory"] = self.get_physical_memory()
        return info_dict
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class UnsupportedOSSystemInfo(SystemInfo):
    """Fallback collector for operating systems without a dedicated class."""

    def get_dict(self):
        """
        Returns the OS-independent system information plus an error marker.

        Returns:
            dict: Base system information with an added "Error" entry.
        """
        system_info = super().get_dict()
        system_info.update({"Error": "UNSUPPORTED OS"})
        return system_info
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def get_system_info() -> SystemInfo:
    """
    Creates the appropriate SystemInfo object based on the operating system.

    Returns:
        A subclass of SystemInfo for the current operating system: Windows,
        WSL, native Linux, or a fallback for anything else.
    """
    os_type = platform.system()
    if os_type == "Windows":
        return WindowsSystemInfo()
    if os_type == "Linux":
        # WSL has to be handled differently compared to native Linux.
        # WSL1 kernels report "...-Microsoft" while WSL2 kernels report
        # "...-microsoft-standard-WSL2", so the check must ignore case.
        if "microsoft" in platform.release().lower():
            return WSLSystemInfo()
        return LinuxSystemInfo()
    return UnsupportedOSSystemInfo()
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def get_system_info_dict() -> dict:
    """
    Collects the current machine's system information as a dictionary.

    Returns:
        dict: The result of ``get_dict()`` on the OS-appropriate SystemInfo.
    """
    collector = get_system_info()
    return collector.get_dict()
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
390
|
+
# Modifications Copyright (c) 2025 AMD
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Helper functions for dealing with tensors
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import copy
|
|
7
|
+
import torch
|
|
8
|
+
import numpy as np
|
|
9
|
+
import lemonade.common.exceptions as exp
|
|
10
|
+
import lemonade.common.build as build
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def check_shapes_and_dtypes(
    inputs, expected_shapes, expected_dtypes, expect_downcast=False, raise_error=True
):
    """
    Check whether ``inputs`` match the shapes and dtypes a model was built for.

    Args:
        inputs: Model inputs; their shapes and dtypes are obtained from
            ``build.get_shapes_and_dtypes``.
        expected_shapes: Shapes recorded when the model was built.
        expected_dtypes: Dtypes recorded when the model was built.
        expect_downcast: If True, a later tool will convert the inputs, so
            "float32" is compared as "float16" and "int64" as "int32".
        raise_error: If True, raise ``exp.Error`` on a mismatch. A shape
            mismatch is reported in preference to a dtype mismatch.

    Returns:
        tuple: ``(input_shapes_changed, input_dtypes_changed)``.
    """
    actual_shapes, actual_dtypes = build.get_shapes_and_dtypes(inputs)

    # Compare against the dtypes the inputs will have after a later downcast
    if expect_downcast:
        for input_name, dtype in actual_dtypes.items():
            if dtype == "float32":
                actual_dtypes[input_name] = "float16"
            elif dtype == "int64":
                actual_dtypes[input_name] = "int32"

    shapes_differ = expected_shapes != actual_shapes
    dtypes_differ = expected_dtypes != actual_dtypes

    if not raise_error:
        return shapes_differ, dtypes_differ

    if shapes_differ:
        msg = f"""
        Model built to always take input of shape
        {expected_shapes} but got {actual_shapes}
        """
        raise exp.Error(msg)
    if dtypes_differ:
        msg = f"""
        Model built to always take input of types
        {expected_dtypes} but got {actual_dtypes}
        """
        raise exp.Error(msg)

    return shapes_differ, dtypes_differ
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def save_inputs(inputs, inputs_file, input_dtypes=None, downcast=True):
    """
    Convert model inputs to NumPy, optionally downcast them, and save them.

    ``inputs`` is deep-copied first, so the caller's objects are not modified.
    In each dict, entries whose value is None are dropped. Values with a
    ``dtype`` attribute are processed as follows:
      - torch tensors are moved to CPU, detached and converted to NumPy;
      - when ``downcast`` is True, a value is cast to ``input_dtypes[name]`` if
        that entry exists and is not None; otherwise float32/float64 become
        float16 and int64 becomes int32.
    Values without a ``dtype`` are kept unchanged.

    Args:
        inputs: List of dicts mapping input names to values.
        inputs_file: Path passed to ``np.save``. Any existing file at exactly
            this path is removed first.
        input_dtypes: Optional dict mapping input names to target dtypes.
            Inputs missing from it (or mapped to None) use the default rules.
        downcast: Whether to apply any dtype conversion.

    Returns:
        list: The converted list of dicts that was saved.
    """
    # Deep copy so that the caller's inputs are never modified in place
    inputs_converted = copy.deepcopy(inputs)

    for idx, sample in enumerate(inputs_converted):
        converted = {}
        for name, value in sample.items():
            if value is None:
                # Drop unset optional inputs entirely
                continue
            if hasattr(value, "dtype"):
                if torch.is_tensor(value):
                    value = value.cpu().detach().numpy()
                if downcast:
                    # .get() so that a partial input_dtypes mapping falls back
                    # to the default rules instead of raising KeyError
                    target_dtype = (
                        input_dtypes.get(name) if input_dtypes is not None else None
                    )
                    if target_dtype is not None:
                        value = value.astype(target_dtype)
                    elif value.dtype in (np.float32, np.float64):
                        value = value.astype("float16")
                    elif value.dtype == np.int64:
                        value = value.astype("int32")
            converted[name] = value
        inputs_converted[idx] = converted

    # Save model inputs to file for later profiling
    if os.path.isfile(inputs_file):
        os.remove(inputs_file)
    np.save(inputs_file, inputs_converted)

    return inputs_converted
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
83
|
+
# Modifications Copyright (c) 2025 AMD
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def create_test_dir(
    key: str,
    base_dir: str = os.path.dirname(os.path.abspath(__file__)),
):
    """
    Prepare fresh cache and corpus directories for a test.

    Both paths live under ``<base_dir>/generated``. Any existing cache
    directory for ``key`` and any existing shared test corpus are deleted;
    the corpus directory is then recreated empty, while the cache directory
    is left absent.

    Returns:
        tuple: ``(cache_dir, corpus_dir)``
    """
    generated_dir = os.path.join(base_dir, "generated")
    cache_dir = os.path.join(generated_dir, f"{key}_cache_dir")
    corpus_dir = os.path.join(generated_dir, "test_corpus")

    # Start from a clean slate: remove leftovers from previous runs
    for stale_dir in (cache_dir, corpus_dir):
        if os.path.isdir(stale_dir):
            shutil.rmtree(stale_dir)

    # Only the corpus directory is recreated here
    os.makedirs(corpus_dir, exist_ok=True)

    return cache_dir, corpus_dir
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def strip_dot_py(test_script_file: str) -> str:
    """
    Return ``test_script_file`` without its file extension.

    Only the final extension is removed, so dots in directory names or
    earlier in the file name are preserved, e.g. ``"./test_x.py"`` becomes
    ``"./test_x"`` and ``"my.test.py"`` becomes ``"my.test"``.
    """
    return os.path.splitext(test_script_file)[0]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
28
|
+
# Modifications Copyright (c) 2025 AMD
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .profiler import Profiler
|